diff --git a/Benchmark_Training/GPT2_TrainingBenchmark.py b/Benchmark_Training/GPT2_TrainingBenchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f1e72b2a3434a371cfcb7858388a1dda630dac7
--- /dev/null
+++ b/Benchmark_Training/GPT2_TrainingBenchmark.py
@@ -0,0 +1,233 @@
+# GPT2 Benchmark: Replication of the FlashAttention-2 Paper on A100 GPUs
+# ----------------------------------------------------------------------
+# This script benchmarks training GPT2-Medium on the WikiText-103 dataset
+# with different attention implementations (torch, flash, flash2).
+# Goal: compare runtime, memory usage, and FLOPs, with a focus on FlashAttention-2.
+
+import os
+import sys
+import time
+import torch
+import random
+import numpy as np
+import torch.nn.functional as F
+from datasets import load_dataset
+from transformers import GPT2Config, AutoTokenizer, TrainingArguments, DataCollatorForLanguageModeling, Trainer
+from accelerate import Accelerator
+from flash_attn.models.gpt import GPTLMHeadModel
+
+# ----------------------------------------
+# 1. Setup & configuration
+# ----------------------------------------
+
+# Configure the CUDA allocator to reduce memory fragmentation
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+# Initialize DDP compatibility via Accelerate
+accelerator = Accelerator()
+
+# Seed everything for reproducibility
+def set_seed(seed=42):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+set_seed(42)
+
+# Dummy ColumnParallelLinear if flash_attn did not provide one
+# (workaround so GPTLMHeadModel can be built without tensor parallelism)
+from flash_attn.models import gpt
+if not hasattr(gpt, "ColumnParallelLinear") or not isinstance(gpt.ColumnParallelLinear, type):
+    import torch.nn as nn
+
+    class ColumnParallelLinear(nn.Module):
+        def __init__(self, *args, **kwargs):
+            super().__init__()
+
+    gpt.ColumnParallelLinear = ColumnParallelLinear
+
+
+# ----------------------------------------
+# 2. Model definition
+# ----------------------------------------
+
+# Build the GPT2-Medium model with the requested attention implementation
+def get_gpt2_model(attention_impl="torch"):
+    config = GPT2Config(
+        n_layer=24, n_head=16, n_embd=1024, vocab_size=50257,  # GPT2-Medium
+        n_positions=1024,
+        resid_pdrop=0.1, embd_pdrop=0.1, attn_pdrop=0.1,
+        layer_norm_epsilon=1e-5
+    )
+    config.attention_config = {
+        "attn_impl": attention_impl,
+        "alibi": False,
+        "rope": True,
+        "rope_theta": 10000.0,
+        "use_flash_rotary": True
+    }
+    return GPTLMHeadModel(config)
+
+
+# ----------------------------------------
+# 3. Tokenizer & data
+# ----------------------------------------
+
+seq_len = 1024
+per_device_batch_size = 8
+target_steps = 10_000
+
+# Load the tokenizer and set the padding token
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+tokenizer.pad_token = tokenizer.eos_token
+
+# Load the dataset and derive how many examples are needed for the target step count
+raw_dataset = load_dataset("wikitext", "wikitext-103-v1")
+global_batch_size = per_device_batch_size * torch.cuda.device_count()
+dataset_size = target_steps * global_batch_size
+
+# Tokenization
+def tokenize(example):
+    return tokenizer(example["text"], truncation=True, padding="max_length", max_length=seq_len)
+
+dataset = raw_dataset["train"].select(range(dataset_size))
+tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=["text"])
+
+# Data collator for causal language modeling
+data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
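+# Optional sanity check (a small sketch using only names defined above): print the
+# derived dataset/step sizes on the main process so a misconfigured run is caught
+# before training starts.
+if accelerator.is_main_process:
+    print(
+        f"Tokenized examples: {len(tokenized_dataset)} | "
+        f"global batch size: {global_batch_size} | "
+        f"expected steps: {len(tokenized_dataset) // max(global_batch_size, 1)}"
+    )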
+
+
+# ----------------------------------------
+# 4. Custom Trainer with GPT2-specific loss
+# ----------------------------------------
+
+# Custom Trainer: flash_attn's GPTLMHeadModel returns raw logits, so the causal LM
+# loss is computed manually.
+class FlashTrainer(Trainer):
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+        labels = inputs.pop("labels")
+        # Drop the attention mask; the flash_attn GPT model does not expect it
+        inputs.pop("attention_mask", None)
+        outputs = model(**inputs)
+        logits = outputs[0]
+
+        # Shift logits and labels for next-token prediction
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        loss = F.cross_entropy(
+            shift_logits.view(-1, shift_logits.size(-1)),
+            shift_labels.view(-1),
+            ignore_index=-100,
+        )
+        return (loss, outputs) if return_outputs else loss
+
+
+# ----------------------------------------
+# 5. Training & Benchmarking
+# ----------------------------------------
+
+def count_model_params(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+# Aggregate memory statistics over all visible GPUs
+def get_memory_summary():
+    total_allocated = 0
+    total_reserved = 0
+    peak_allocated = 0
+    for device in range(torch.cuda.device_count()):
+        total_allocated += torch.cuda.memory_allocated(device)
+        total_reserved += torch.cuda.memory_reserved(device)
+        peak_allocated += torch.cuda.max_memory_allocated(device)
+    return total_allocated, total_reserved, peak_allocated
+
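+
+# Illustrative helper (a sketch of the estimate computed inline in train_model below):
+# the usual per-sequence FLOP approximation of 6 * num_params FLOPs per token for the
+# dense matmuls (2x forward + 4x backward) plus 12 * n_layer * hidden_dim * seq_len^2
+# for the attention score/value matmuls (forward + backward).
+def estimate_flops_per_seq(num_params, n_layer, hidden_dim, seq_len):
+    dense_flops = 6 * seq_len * num_params
+    attn_flops = 12 * n_layer * hidden_dim * seq_len * seq_len
+    return dense_flops + attn_flops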
+
+
+def train_model(attention_impl="torch"):
+    model = get_gpt2_model(attention_impl)
+    config = model.config
+    model = accelerator.prepare(model)
+
+    # Capture memory statistics before training
+    torch.cuda.empty_cache()
+    torch.cuda.reset_peak_memory_stats()
+    mem_before_alloc, _, _ = get_memory_summary()
+
+    train_args = TrainingArguments(
+        output_dir=f"./gpt2_{attention_impl}_a100",
+        overwrite_output_dir=True,
+        per_device_train_batch_size=per_device_batch_size,
+        num_train_epochs=1,
+        logging_steps=999999,  # effectively disable intermediate logging
+        report_to="none",
+        save_strategy="no",
+        remove_unused_columns=False,
+        fp16=True,
+        dataloader_pin_memory=True,
+        dataloader_num_workers=4,
+        ddp_find_unused_parameters=False,
+    )
+
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    global_batch_size = per_device_batch_size * world_size
+
+    start_time = time.time()
+
+    trainer = FlashTrainer(
+        model=model,
+        args=train_args,
+        train_dataset=tokenized_dataset,
+        data_collator=data_collator,
+        tokenizer=tokenizer
+    )
+
+    trainer.train()
+
+    # Capture memory statistics after training
+    mem_after_alloc, mem_after_reserved, peak_alloc = get_memory_summary()
+
+    elapsed = time.time() - start_time
+    num_params = count_model_params(model)
+    n_layer = config.n_layer
+    hidden_dim = config.n_embd
+    steps = int(len(tokenized_dataset) / per_device_batch_size / world_size)
+    avg_step = elapsed / steps if steps else float('nan')
+    tokens_per_step = per_device_batch_size * seq_len  # per GPU
+
+    # FLOP estimate per sequence (forward + backward): dense matmuls + attention term
+    flops_per_seq = (
+        6 * seq_len * num_params + 12 * n_layer * hidden_dim * seq_len * seq_len
+    )
+    flops_total = flops_per_seq * per_device_batch_size * steps
+    tflops_per_s = flops_total / (elapsed * 1e12)  # per-GPU TFLOPs/s
+
+    # Logging
+    output_path = f"benchmark_{world_size}GPU_embd{config.n_embd}_seq{seq_len}_bs{per_device_batch_size}.txt"
+    is_main_process = int(os.environ.get("RANK", 0)) == 0
+
+    if is_main_process:
+        with open(output_path, "a") as f:
+            f.write("# FlashAttention Benchmark Results\n")
+            f.write(f"Model: GPT2 | Layers: {config.n_layer} | n_head: {config.n_head} | Embedding Dim: {config.n_embd}\n")
+            f.write(f"Sequence Length: {config.n_positions} | Batch Size: {train_args.per_device_train_batch_size} | Effective Batch Size (global): {global_batch_size} | FP16: {train_args.fp16}\n\n")
+
+            f.write(f"=== {attention_impl.upper()} ===\n")
+            f.write(f"Runtime: {elapsed:.2f}s | Steps: {steps} | Step Time: {avg_step:.4f}s\n")
+            f.write(f"Tokens/s: {tokens_per_step / avg_step:.2f} | TFLOPs/s: {tflops_per_s:.3f}\n")
+
+            f.write(f"MemAlloc Before: {mem_before_alloc / 1024**2:.2f} MiB | MemAlloc After: {mem_after_alloc / 1024**2:.2f} MiB | MemReserved After: {mem_after_reserved / 1024**2:.2f} MiB\n")
+            f.write(f"Peak MemAlloc (all GPUs): {peak_alloc / 1024**2:.2f} MiB\n\n")
+
+
+# ----------------------------------------
+# 6. CLI Entry Point
+# ----------------------------------------
+
+if __name__ == "__main__":
+    valid_impls = ["torch", "flash", "flash2"]
+    if len(sys.argv) < 2:
+        print("❌ Please pass the attention variant as an argument (e.g. torch / flash / flash2)")
+        sys.exit(1)
+
+    attention_impl = sys.argv[1].lower()
+    if attention_impl not in valid_impls:
+        print(f"❌ Invalid attention variant: '{attention_impl}'")
+        sys.exit(1)
+
+    print(f"\n🚀 Starting benchmark for attention variant: {attention_impl.upper()}")
+    train_model(attention_impl)
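+
+
+# Example launch commands (a usage sketch, assuming a node with two A100s; adjust
+# --nproc_per_node / --num_processes to the available GPU count):
+#   torchrun --nproc_per_node=2 Benchmark_Training/GPT2_TrainingBenchmark.py flash2
+#   accelerate launch --num_processes 2 Benchmark_Training/GPT2_TrainingBenchmark.py torch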