Commit cf3564c6 authored by Armin Bacher

Add GPT2_TrainingBenchmark.py

parent 16a320df
# GPT2 benchmark: replication of the FlashAttention-2 paper on A100 GPUs
# ---------------------------------------------------------------
# This script benchmarks training of GPT2-Medium on the WikiText-103 dataset
# with different attention implementations (torch, flash, flash2).
# Goal: compare runtime, memory usage, and FLOPs with a focus on FlashAttention-2.
import os
import sys
import time
import torch
import random
import numpy as np
import torch.nn.functional as F
from datasets import load_dataset
from transformers import GPT2Config, AutoTokenizer, TrainingArguments, DataCollatorForLanguageModeling, Trainer
from accelerate import Accelerator
from flash_attn.models.gpt import GPTLMHeadModel
# ----------------------------------------
# 1. Preparation & configuration
# ----------------------------------------
# Set the CUDA allocator configuration to avoid memory fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Initialize DDP compatibility via Accelerate
accelerator = Accelerator()
# Seed for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(42)
# Dummy ColumnParallelLinear stub if not available (workaround for FlashAttention)
from flash_attn.models import gpt
if not hasattr(gpt, "ColumnParallelLinear") or not isinstance(gpt.ColumnParallelLinear, type):
    import torch.nn as nn

    class ColumnParallelLinear(nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__()

    gpt.ColumnParallelLinear = ColumnParallelLinear
# ----------------------------------------
# 2. Model definition
# ----------------------------------------
# Create the GPT2 model
def get_gpt2_model(attention_impl="torch"):
    config = GPT2Config(
        n_layer=24, n_head=16, n_embd=1024, vocab_size=50257,  # GPT2-Medium
        n_positions=1024,
        resid_pdrop=0.1, embd_pdrop=0.1, attn_pdrop=0.1,
        layer_norm_epsilon=1e-5
    )
    config.attention_config = {
        "attn_impl": attention_impl,
        "alibi": False,
        "rope": True,
        "rope_theta": 10000.0,
        "use_flash_rotary": True
    }
    return GPTLMHeadModel(config)
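# Quick sanity check (illustrative only, not invoked by the benchmark): a
# GPT2-Medium-sized config (24 layers, 16 heads, d_model=1024) should come out
# at roughly 350M trainable parameters. Run manually if desired:
#   m = get_gpt2_model("torch")
#   print(sum(p.numel() for p in m.parameters()) / 1e6, "M parameters")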
# ----------------------------------------
# 3. Tokenizer & data
# ----------------------------------------
seq_len = 1024
per_device_batch_size = 8
target_steps = 10_000
# Load the tokenizer and set the padding token
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
# Load the dataset
raw_dataset = load_dataset("wikitext", "wikitext-103-v1")
global_batch_size = per_device_batch_size * torch.cuda.device_count()
dataset_size = target_steps * global_batch_size
# Tokenization
def tokenize(example):
    return tokenizer(example["text"], truncation=True, padding="max_length", max_length=seq_len)
dataset = raw_dataset["train"].select(range(dataset_size))
tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=["text"])
# Data collator (causal LM, so no masked-LM objective)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
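# Optional sanity check (illustrative only, not part of the benchmark): with
# truncation and max-length padding every sample has exactly seq_len tokens.
#   sample = tokenized_dataset[0]
#   assert len(sample["input_ids"]) == seq_len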
# ----------------------------------------
# 4. Custom Trainer with GPT2-specific loss
# ----------------------------------------
# Custom Trainer: computes the causal-LM (next-token) loss from the raw logits
class FlashTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        inputs.pop("attention_mask", None)
        outputs = model(**inputs)
        logits = outputs[0]
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
        )
        return (loss, outputs) if return_outputs else loss
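# Illustration of the shift-by-one alignment used above (comments only):
# for token ids [t0, t1, t2, t3], shift_logits keeps the predictions made at
# positions 0..2 and shift_labels holds [t1, t2, t3], so each position is
# trained to predict the *next* token (standard causal LM objective).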
# ----------------------------------------
# 5. Training & benchmarking
# ----------------------------------------
def count_model_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
# Aggregate memory statistics across all visible GPUs
def get_memory_summary():
    total_allocated = 0
    total_reserved = 0
    peak_allocated = 0
    for device in range(torch.cuda.device_count()):
        total_allocated += torch.cuda.memory_allocated(device)
        total_reserved += torch.cuda.memory_reserved(device)
        peak_allocated += torch.cuda.max_memory_allocated(device)
    return total_allocated, total_reserved, peak_allocated
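# Note (assumption about the intended semantics): torch.cuda.memory_* counters
# are per-process, so under DDP with one process per GPU the sum reported by
# rank 0 mainly reflects rank 0's own device; other ranks' usage is not included.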
def train_model(attention_impl="torch"):
    model = get_gpt2_model(attention_impl)
    config = model.config
    model = accelerator.prepare(model)
    # Capture memory statistics before training
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    mem_before_alloc, _, _ = get_memory_summary()
    train_args = TrainingArguments(
        output_dir=f"./gpt2_{attention_impl}_a100",
        overwrite_output_dir=True,
        per_device_train_batch_size=per_device_batch_size,
        num_train_epochs=1,
        logging_steps=999999,  # effectively disables intermediate logging
        report_to="none",
        save_strategy="no",
        remove_unused_columns=False,
        fp16=True,
        dataloader_pin_memory=True,
        dataloader_num_workers=4,
        ddp_find_unused_parameters=False,
    )
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    global_batch_size = per_device_batch_size * world_size
    start_time = time.time()
    trainer = FlashTrainer(
        model=model,
        args=train_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer
    )
    trainer.train()
    # Capture memory statistics after training
    mem_after_alloc, mem_after_reserved, peak_alloc = get_memory_summary()
    elapsed = time.time() - start_time
    num_params = count_model_params(model)
    n_layer = config.n_layer
    hidden_dim = config.n_embd
    steps = int(len(tokenized_dataset) / per_device_batch_size / world_size)
    avg_step = elapsed / steps if steps else float('nan')
    tokens_per_step = per_device_batch_size * seq_len
    # FLOP estimate per sequence (forward + backward): dense matmuls plus attention term
    flops_per_step = (
        6 * seq_len * num_params + 12 * n_layer * hidden_dim * seq_len * seq_len
    )
    flops_total = flops_per_step * per_device_batch_size * steps
    tflops_per_s = flops_total / (elapsed * 1e12)
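    # Derivation sketch for the estimate above (standard rule of thumb, not exact):
    #   - dense matmuls cost ~6 FLOPs per parameter per token for forward + backward,
    #     giving 6 * num_params * seq_len per sequence
    #   - the attention score/value matmuls add ~12 * n_layer * hidden_dim * seq_len^2
    #     per sequence (quadratic in sequence length)
    # flops_total and tflops_per_s are therefore per-device figures; multiply by
    # world_size for an aggregate number across all GPUs.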
    # Logging
    output_path = f"benchmark_2GPU_embd{config.n_embd}_seq{seq_len}_bs{per_device_batch_size}.txt"
    is_main_process = int(os.environ.get("RANK", 0)) == 0
    if is_main_process:
        with open(output_path, "a") as f:
            f.write("# FlashAttention Benchmark Results\n")
            f.write(f"Model: GPT2 | Layers: {config.n_layer} | n_head: {config.n_head} | Embedding Dim: {config.n_embd}\n")
            f.write(f"Sequence Length: {config.n_positions} | Batch Size: {train_args.per_device_train_batch_size} | Effective Batch Size (global): {global_batch_size} | FP16: {train_args.fp16}\n\n")
            f.write(f"=== {attention_impl.upper()} ===\n")
            f.write(f"Runtime: {elapsed:.2f}s | Steps: {steps} | Step Time: {avg_step:.4f}s\n")
            f.write(f"Tokens/s: {tokens_per_step / avg_step:.2f} | TFLOPs/s: {tflops_per_s:.3f}\n")
            f.write(f"MemAlloc Before: {mem_before_alloc / 1024**2:.2f} MiB | MemAlloc After: {mem_after_alloc / 1024**2:.2f} MiB | MemReserved After: {mem_after_reserved / 1024**2:.2f} MiB\n")
            f.write(f"Peak MemAlloc (all GPUs): {peak_alloc / 1024**2:.2f} MiB\n\n")
# ----------------------------------------
# 6. CLI Entry Point
# ----------------------------------------
if __name__ == "__main__":
    valid_impls = ["torch", "flash", "flash2"]
    if len(sys.argv) < 2:
        print("❌ Please pass the attention variant as an argument (e.g. torch / flash / flash2)")
        sys.exit(1)
    attention_impl = sys.argv[1].lower()
    if attention_impl not in valid_impls:
        print(f"❌ Invalid attention variant: '{attention_impl}'")
        sys.exit(1)
    print(f"\n🚀 Starting benchmark for attention variant: {attention_impl.upper()}")
    train_model(attention_impl)
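# Example launch (assumption: the script is driven by torchrun or 🤗 Accelerate,
# which set the WORLD_SIZE/RANK variables read above; the exact launcher is not
# part of this commit):
#   torchrun --nproc_per_node=2 GPT2_TrainingBenchmark.py flash2
#   accelerate launch --num_processes 2 GPT2_TrainingBenchmark.py flash2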