Commit bc7e25df authored by Armin Bacher

Delete GPT2L_A100.py

parent b87a8d6b
# flashattn2_benchmark_a100.py
# Replication of the FlashAttention-2 paper - benchmarking on an A100 GPU cluster
from transformers import GPT2Config, AutoTokenizer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset
from flash_attn.models.gpt import GPTLMHeadModel
from transformers import Trainer
import torch.nn.functional as F
import torch
import time
import os
import random, numpy as np
# Fixed seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(42)
# Dummy ColumnParallelLinear if not available: flash_attn.models.gpt only exposes
# a real ColumnParallelLinear when its tensor-parallel dependencies are installed;
# otherwise the attribute may be missing or None, which breaks isinstance checks.
# A stub class keeps the single-GPU code path working.
from flash_attn.models import gpt
if not hasattr(gpt, "ColumnParallelLinear") or not isinstance(gpt.ColumnParallelLinear, type):
    import torch.nn as nn

    class ColumnParallelLinear(nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__()

    gpt.ColumnParallelLinear = ColumnParallelLinear
# Custom Trainer: computes the causal-LM loss manually because the flash_attn
# GPTLMHeadModel returns logits only (no loss).
class FlashTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        # The flash_attn GPT model does not take an attention_mask argument, so drop it.
        inputs.pop("attention_mask", None)
        outputs = model(**inputs)
        logits = outputs[0]
        # Shift so that tokens < n predict token n (standard causal LM objective).
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
        )
        return (loss, outputs) if return_outputs else loss
# Create the GPT-2 model (for the full benchmark, optionally GPT2-XL or larger)
def get_gpt2_model(attention_impl="torch"):
    config = GPT2Config(
        n_layer=24, n_head=20, n_embd=1280, vocab_size=50257,  # GPT2-Large width, 24 layers
        n_positions=1024,
        resid_pdrop=0.1, embd_pdrop=0.1, attn_pdrop=0.1,
        layer_norm_epsilon=1e-5
    )
    # Attention backend selection ("torch", "flash", "flash2") plus rotary-embedding settings
    config.attention_config = {
        "attn_impl": attention_impl,
        "alibi": False,
        "rope": True,
        "rope_theta": 10000.0,
        "use_flash_rotary": True
    }
    return GPTLMHeadModel(config)
# Prepare tokenizer & dataset
seq_len = 1024
batch_size = 4
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; reuse EOS
raw_dataset = load_dataset("wikitext", "wikitext-103-raw-v1")
def tokenize(example):
    return tokenizer(example["text"], truncation=True, padding="max_length", max_length=seq_len)
dataset = raw_dataset["train"].select(range(1024))  # small sample for profiling
tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=["text"])
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
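# With mlm=False the collator copies input_ids into "labels" (pad tokens become -100),
# which FlashTrainer.compute_loss pops and shifts for next-token prediction.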
def count_model_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
def train_model(attention_impl="torch"):
    model = get_gpt2_model(attention_impl)
    train_args = TrainingArguments(
        output_dir=f"./gpt2_{attention_impl}_a100",
        overwrite_output_dir=True,
        per_device_train_batch_size=batch_size,
        num_train_epochs=1,
        logging_steps=999999,  # effectively disables intermediate logging during the benchmark
        report_to="none",
        save_strategy="no",
        remove_unused_columns=False,
        fp16=True  # sensible on A100 / matches the paper's mixed-precision setup
    )
    start_time = time.time()
    trainer = FlashTrainer(
        model=model,
        args=train_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer
    )
    trainer.train()
    elapsed = time.time() - start_time
    num_params = count_model_params(model)
    n_layer = model.config.n_layer
    hidden_dim = model.config.n_embd
    steps = int(len(tokenized_dataset) / batch_size)
    avg_step = elapsed / steps if steps else float('nan')
    tokens_per_step = batch_size * seq_len
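    # Approximate FLOPs per sequence: 6 * num_params per token for the
    # forward + backward pass through the weights, plus roughly
    # 12 * n_layer * hidden_dim * seq_len^2 for the attention matmuls
    # (QK^T and attention-times-V, forward and backward).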
    flops_per_step = (
        6 * seq_len * num_params + 12 * n_layer * hidden_dim * seq_len * seq_len
    )
    flops_total = flops_per_step * batch_size * steps
    tflops_per_s = flops_total / (elapsed * 1e12)
    output_path = f"benchmark_results_seq{seq_len}_bs{batch_size}.txt"
    header_written = os.path.exists(output_path)  # file already exists -> header was written on a previous run
    with open(output_path, "a") as f:
        if not header_written:
            f.write("# FlashAttention benchmark results\n")
            f.write(f"Model: GPT2 | Layers: {model.config.n_layer} | Embedding Dim: {model.config.n_embd}\n")
            f.write(f"Sequence Length: {model.config.n_positions} | Batch Size: {train_args.per_device_train_batch_size} | FP16: {train_args.fp16}\n\n")
        f.write(f"=== {attention_impl.upper()} ===\n")
        f.write(f"Runtime: {elapsed:.2f}s | Steps: {steps} | Step Time: {avg_step:.4f}s\n")
        f.write(f"Tokens/s: {tokens_per_step / avg_step:.2f} | TFLOPs/s: {tflops_per_s:.3f}\n\n")
# Run the benchmark for each attention implementation
for impl in ["torch", "flash", "flash2"]:
    train_model(impl)
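# Each run appends one "=== IMPL ===" block to benchmark_results_seq1024_bs4.txt,
# so the TORCH, FLASH and FLASH2 results can be compared side by side.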