From cf3564c67ea51bde9ce695829dcc05987bce4301 Mon Sep 17 00:00:00 2001
From: Armin Bacher <armin.bacher@student.uni-halle.de>
Date: Mon, 31 Mar 2025 21:34:39 +0000
Subject: [PATCH] Add GPT2_TrainingBenchmark.py

---
 Benchmark_Training/GPT2_TrainingBenchmark.py | 233 +++++++++++++++++++
 1 file changed, 233 insertions(+)
 create mode 100644 Benchmark_Training/GPT2_TrainingBenchmark.py

diff --git a/Benchmark_Training/GPT2_TrainingBenchmark.py b/Benchmark_Training/GPT2_TrainingBenchmark.py
new file mode 100644
index 0000000..1f1e72b
--- /dev/null
+++ b/Benchmark_Training/GPT2_TrainingBenchmark.py
@@ -0,0 +1,233 @@
+# GPT2 benchmark: replication of the FlashAttention-2 paper setup on A100 GPUs
+# ---------------------------------------------------------------
+# This script benchmarks training of GPT2-Medium on the WikiText-103 dataset
+# with different attention implementations (torch, flash, flash2).
+# Goal: compare runtime, memory usage, and FLOPs, with a focus on FlashAttention-2.
+
+import os
+import sys
+import time
+import torch
+import random
+import numpy as np
+import torch.nn.functional as F
+from datasets import load_dataset
+from transformers import GPT2Config, AutoTokenizer, TrainingArguments, DataCollatorForLanguageModeling, Trainer
+from accelerate import Accelerator
+from flash_attn.models.gpt import GPTLMHeadModel
+
+# ----------------------------------------
+# 1. Setup & Configuration
+# ----------------------------------------
+
+# Set the CUDA allocator configuration to reduce memory fragmentation
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+# Initialize DDP compatibility via Accelerate
+accelerator = Accelerator()
+
+# Seed for reproducibility
+def set_seed(seed=42):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+set_seed(42)
+
+# Provide a dummy ColumnParallelLinear when it is missing (workaround for flash_attn)
+from flash_attn.models import gpt
+if not hasattr(gpt, "ColumnParallelLinear") or not isinstance(gpt.ColumnParallelLinear, type):
+    import torch.nn as nn
+    class ColumnParallelLinear(nn.Module):
+        def __init__(self, *args, **kwargs):
+            super().__init__()
+    gpt.ColumnParallelLinear = ColumnParallelLinear
+
+
+# ----------------------------------------
+# 2. Model Definition
+# ----------------------------------------
+
+# Build the GPT2 model
+def get_gpt2_model(attention_impl="torch"):
+    config = GPT2Config(
+        n_layer=24, n_head=16, n_embd=1024, vocab_size=50257,  # GPT2-Medium
+        n_positions=1024,
+        resid_pdrop=0.1, embd_pdrop=0.1, attn_pdrop=0.1,
+        layer_norm_epsilon=1e-5
+    )
+    config.attention_config = {
+        "attn_impl": attention_impl,
+        "alibi": False,
+        "rope": True,
+        "rope_theta": 10000.0,
+        "use_flash_rotary": True
+    }
+    return GPTLMHeadModel(config)
+
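+# Optional sanity check (a sketch, not part of the benchmark run): this config
+# should come out at roughly the 355M parameters of GPT2-Medium, e.g.
+#   m = get_gpt2_model("torch")
+#   print(f"{sum(p.numel() for p in m.parameters()) / 1e6:.1f}M parameters")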
+
+# ----------------------------------------
+# 3. Tokenizer & Data
+# ----------------------------------------
+
+seq_len = 1024
+per_device_batch_size = 8
+target_steps = 10_000
+
+# Load the tokenizer and set the padding token
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+tokenizer.pad_token = tokenizer.eos_token
+
+# Load the dataset
+raw_dataset = load_dataset("wikitext", "wikitext-103-v1")
+global_batch_size = per_device_batch_size * torch.cuda.device_count()
+dataset_size = target_steps * global_batch_size
+
+# Tokenization
+def tokenize(example):
+    return tokenizer(example["text"], truncation=True, padding="max_length", max_length=seq_len)
+
+dataset = raw_dataset["train"].select(range(dataset_size))
+tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=["text"])
+
+# Data collator
+data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
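+# (With mlm=False the collator copies input_ids into labels and masks padding
+# tokens with -100, which matches the ignore_index used in the loss below.)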
+
+
+# ----------------------------------------
+# 4. Custom Trainer with GPT2-Specific Loss
+# ----------------------------------------
+
+# Custom Trainer: the flash_attn GPTLMHeadModel returns raw logits (it takes no
+# labels argument), so the shifted causal LM loss is computed manually
+class FlashTrainer(Trainer):
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+        labels = inputs.pop("labels")
+        inputs.pop("attention_mask", None)
+        outputs = model(**inputs)
+        logits = outputs[0]
+
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        loss = F.cross_entropy(
+            shift_logits.view(-1, shift_logits.size(-1)),
+            shift_labels.view(-1),
+            ignore_index=-100,
+        )
+        return (loss, outputs) if return_outputs else loss
+
+
+# ----------------------------------------
+# 5. Training & Benchmarking
+# ----------------------------------------
+
+def count_model_params(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+# Aggregate memory statistics across all visible GPUs
+def get_memory_summary():
+    total_allocated = 0
+    total_reserved = 0
+    peak_allocated = 0
+    for device in range(torch.cuda.device_count()):
+        total_allocated += torch.cuda.memory_allocated(device)
+        total_reserved += torch.cuda.memory_reserved(device)
+        peak_allocated += torch.cuda.max_memory_allocated(device)
+    return total_allocated, total_reserved, peak_allocated
+
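+# Note: these counters come from PyTorch's process-local caching allocator, so
+# under DDP (one process per GPU) the sums mostly reflect the calling rank's own
+# device rather than a true total across all GPUs.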
+
+def train_model(attention_impl="torch"):
+    model = get_gpt2_model(attention_impl)
+    config = model.config
+    model = accelerator.prepare(model)
+
+    # Capture memory statistics before training
+    torch.cuda.empty_cache()
+    # Reset peak statistics on every visible device, not only the current one
+    for device in range(torch.cuda.device_count()):
+        torch.cuda.reset_peak_memory_stats(device)
+    mem_before_alloc, _, _ = get_memory_summary()
+
+
+    train_args = TrainingArguments(
+        output_dir=f"./gpt2_{attention_impl}_a100",
+        overwrite_output_dir=True,
+        per_device_train_batch_size=per_device_batch_size,
+        num_train_epochs=1,
+        logging_steps=999999,
+        report_to="none",
+        save_strategy="no",
+        remove_unused_columns=False,
+        fp16=True,
+        dataloader_pin_memory=True,
+        dataloader_num_workers=4,
+        ddp_find_unused_parameters=False,
+    )
+
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    global_batch_size = per_device_batch_size * world_size
+
+    start_time = time.time()
+
+    trainer = FlashTrainer(
+        model=model,
+        args=train_args,
+        train_dataset=tokenized_dataset,
+        data_collator=data_collator,
+        tokenizer=tokenizer
+    )
+
+    trainer.train()
+
+    # Capture memory statistics after training
+    mem_after_alloc, mem_after_reserved, peak_alloc = get_memory_summary()
+
+    elapsed = time.time() - start_time
+    num_params = count_model_params(model)
+    n_layer = config.n_layer
+    hidden_dim = config.n_embd
+    steps = int(len(tokenized_dataset) / per_device_batch_size / world_size)
+    avg_step = elapsed / steps if steps else float('nan')
+    tokens_per_step = per_device_batch_size * seq_len
+
+    # FLOP estimate per sequence (forward + backward): the common 6 * params * tokens
+    # approximation plus the quadratic attention term 12 * n_layer * d * seq_len^2
+    flops_per_seq = (
+        6 * seq_len * num_params + 12 * n_layer * hidden_dim * seq_len * seq_len
+    )
+    flops_total = flops_per_seq * per_device_batch_size * steps
+    tflops_per_s = flops_total / (elapsed * 1e12)
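+    # Rough magnitudes for this config (a sanity check, assuming the ~355M
+    # parameters of GPT2-Medium): 6 * 1024 * 3.55e8 ≈ 2.2e12 plus
+    # 12 * 24 * 1024 * 1024^2 ≈ 0.3e12, i.e. ~2.5e12 FLOPs per sequence and
+    # ~2.0e13 FLOPs per step at per-device batch size 8.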
+
+    # Logging
+    output_path = f"benchmark_2GPU_embd{config.n_embd}_seq{seq_len}_bs{per_device_batch_size}.txt"
+    is_main_process = int(os.environ.get("RANK", 0)) == 0
+
+    if is_main_process:
+        with open(output_path, "a") as f:
+            f.write("# FlashAttention Benchmark Ergebnisse\n")
+            f.write(f"Modell: GPT2 | Layers: {config.n_layer} | n_head: {config.n_head} | Embedding Dim: {config.n_embd}\n")
+            f.write(f"Sequence Length: {config.n_positions} | Batch Size: {train_args.per_device_train_batch_size} | Effective Batch Size (global): {global_batch_size} | FP16: {train_args.fp16}\n\n")
+
+            f.write(f"=== {attention_impl.upper()} ===\n")
+            f.write(f"Runtime: {elapsed:.2f}s | Steps: {steps} | Step Time: {avg_step:.4f}s\n")
+            f.write(f"Tokens/s: {tokens_per_step / avg_step:.2f} | TFLOPs/s: {tflops_per_s:.3f}\n")
+
+            f.write(f"MemAlloc Before: {mem_before_alloc / 1024**2:.2f} MiB | MemAlloc After: {mem_after_alloc / 1024**2:.2f} MiB | MemoReserved After: {mem_after_reserved / 1024**2:.2f} MiB\n")
+            f.write(f"Peak MemAlloc (all GPUs): {peak_alloc / 1024**2:.2f} MiB\n\n")
+
+
+# ----------------------------------------
+# 6. CLI Entry Point
+# ----------------------------------------
+
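+# Example invocation (a sketch; adjust launcher and GPU count to the local setup,
+# e.g. for two A100s):
+#   torchrun --nproc_per_node=2 GPT2_TrainingBenchmark.py flash2
+# or, equivalently, via Accelerate:
+#   accelerate launch --num_processes 2 GPT2_TrainingBenchmark.py flash2
+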
+if __name__ == "__main__":
+    valid_impls = ["torch", "flash", "flash2"]
+    if len(sys.argv) < 2:
+        print("❌ Bitte Attention-Variante als Argument angeben (z. B. torch / flash / flash2)")
+        sys.exit(1)
+
+    attention_impl = sys.argv[1].lower()
+    if attention_impl not in valid_impls:
+        print(f"❌ Ungültige Attention-Variante: '{attention_impl}'")
+        sys.exit(1)
+
+    print(f"\n🚀 Starte Benchmark für Attention-Variante: {attention_impl.upper()}")
+    train_model(attention_impl)
\ No newline at end of file
-- 
GitLab