Commit f3c213d1 authored by Armin Bacher

DDP compatibility

parent e06f56fa
@@ -10,6 +10,16 @@ import torch
import time
import os
import random, numpy as np
import sys
# --------------------------------------
# Make DDP-compatible
from transformers import Trainer
from accelerate import Accelerator
accelerator = Accelerator()
# CUDA memory optimization (reduce fragmentation)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Fixed seeds
def set_seed(seed=42):
@@ -47,7 +57,8 @@ class FlashTrainer(Trainer):
        )
        return (loss, outputs) if return_outputs else loss
# Create GPT2 model (for the benchmark, possibly GPT2-XL / larger)
# Create GPT2 model
def get_gpt2_model(attention_impl="torch"):
    config = GPT2Config(
        n_layer=24, n_head=16, n_embd=1024, vocab_size=50257,  # GPT2-Medium
@@ -66,7 +77,6 @@ def get_gpt2_model(attention_impl="torch"):
# Prepare tokenizer & dataset
seq_len = 2048
batch_size = 2
per_device_batch_size = 1
tokenizer = AutoTokenizer.from_pretrained("gpt2")
@@ -86,6 +96,8 @@ def count_model_params(model):
def train_model(attention_impl="torch"):
    model = get_gpt2_model(attention_impl)
    model = accelerator.prepare(model)
    train_args = TrainingArguments(
        output_dir=f"./gpt2_{attention_impl}_a100",
        overwrite_output_dir=True,
@@ -98,17 +110,16 @@ def train_model(attention_impl="torch"):
        fp16=True,
        dataloader_pin_memory=True,
        dataloader_num_workers=4,
        gradient_accumulation_steps=1,
        ddp_find_unused_parameters=False,
        fp16_backend="amp",
        optim="adamw_torch",
        evaluation_strategy="no"
    )
    # Number of processes/GPU devices (determined automatically via DDP)
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    global_batch_size = per_device_batch_size * world_size
    print(f"\n Benchmark configuration for {attention_impl.upper()}")
    print(f" Model: GPT2 | Layers: {model.config.n_layer} | Embedding Dim: {model.config.n_embd}")
    print(f" Sequence Length: {model.config.n_positions} | Batch Size: {per_device_batch_size} | Global Batch: {global_batch_size} | FP16: {train_args.fp16}")
    start_time = time.time()
    trainer = FlashTrainer(
@@ -126,17 +137,17 @@ def train_model(attention_impl="torch"):
    n_layer = model.config.n_layer
    hidden_dim = model.config.n_embd
    steps = int(len(tokenized_dataset) / batch_size)
    steps = int(len(tokenized_dataset) / per_device_batch_size / world_size)
    avg_step = elapsed / steps if steps else float('nan')
    tokens_per_step = batch_size * seq_len
    tokens_per_step = per_device_batch_size * seq_len
    flops_per_step = (
        6 * seq_len * num_params + 12 * n_layer * hidden_dim * seq_len * seq_len
    )
    flops_total = flops_per_step * batch_size * steps
    flops_total = flops_per_step * per_device_batch_size * steps
    tflops_per_s = flops_total / (elapsed * 1e12)
    output_path = f"benchmark_results_seq{seq_len}_bs{batch_size}.txt"
    output_path = f"benchmark_results_seq{seq_len}_bs{per_device_batch_size}.txt"
    header_written = os.path.exists(output_path)
    with open(output_path, "a") as f:
@@ -150,21 +161,14 @@ def train_model(attention_impl="torch"):
        f.write(f"Tokens/s: {tokens_per_step / avg_step:.2f} | TFLOPs/s: {tflops_per_s:.3f}\n\n")
if __name__ == "__main__":
    import sys
    valid_impls = ["torch", "flash", "flash2"]
    # Expects e.g. "flash2" as the argument
    if len(sys.argv) < 2:
        print("❌ Please pass an attention variant as an argument (e.g. torch / flash / flash2)")
        print("🔁 Example: python GPT2M_MultiA100.py flash2")
        sys.exit(1)
    attention_impl = sys.argv[1].lower()
    if attention_impl not in valid_impls:
        print(f"❌ Invalid attention variant: '{attention_impl}'")
        print(f"✅ Valid options: {valid_impls}")
        sys.exit(1)
    print(f"\n🚀 Starting benchmark for attention variant: {attention_impl.upper()}")
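For reference, the FLOPs and throughput estimate used in train_model() can be reproduced standalone. The following is a minimal sketch, not part of the commit; the parameter count of roughly 355M for this GPT2-Medium configuration and the 0.5 s average step time are illustrative assumptions, not measured values.

# Standalone sketch of the benchmark's per-step estimate (assumed values).
seq_len = 2048
per_device_batch_size = 1
n_layer, hidden_dim = 24, 1024        # GPT2-Medium config from the script
num_params = 355_000_000              # assumed parameter count for GPT2-Medium

# Same formula as in train_model(): dense term (6 * tokens * params)
# plus the quadratic attention term (12 * layers * hidden * seq_len^2).
flops_per_step = 6 * seq_len * num_params + 12 * n_layer * hidden_dim * seq_len * seq_len
tokens_per_step = per_device_batch_size * seq_len

avg_step = 0.5                        # assumed seconds per optimizer step
print(f"Tokens/s: {tokens_per_step / avg_step:.2f}")                                  # 4096.00
print(f"TFLOPs/s: {flops_per_step * per_device_batch_size / (avg_step * 1e12):.3f}")  # ~11.2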