Commit a4c6ff78 authored by Armin Bacher
Delete GPT-2-Small-1k-opt-v2.py

parent 712a9a0f
import time
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, DataCollatorForLanguageModeling
from datasets import load_dataset
from torch.utils.data import DataLoader
# --- Settings ---
BATCH_SIZE = 32
# GPT-2's learned position embeddings cover at most 1024 tokens (n_positions),
# so the sequence length is capped at 1024 for the stock "gpt2" checkpoint.
SEQ_LEN = 1024
NUM_STEPS = 1000
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MIXED_PRECISION = (
    torch.bfloat16
    if DEVICE == "cuda" and torch.cuda.is_bf16_supported()
    else torch.float16
)
# GPT-2 Small model parameters
NUM_LAYERS = 12
HIDDEN_SIZE = 768
NUM_HEADS = 12
HEAD_DIM = HIDDEN_SIZE // NUM_HEADS  # 64; kept for reference, not used below
# Load dataset
dataset = load_dataset("wikitext", "wikitext-103-v1", streaming=False)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    tokens = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=SEQ_LEN)
    tokens["labels"] = tokens["input_ids"].copy()
    return tokens

tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=4, desc="Tokenizing...")
tokenized_datasets.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
dataloader = DataLoader(tokenized_datasets["train"], batch_size=BATCH_SIZE, collate_fn=data_collator)
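
# Optional sanity check (illustrative sketch, not part of the benchmark): the collator
# should yield [BATCH_SIZE, SEQ_LEN] tensors for input_ids, attention_mask and labels,
# with padding positions in labels set to -100 (note that pad_token == eos_token here,
# so EOS positions are masked as well).
# first_batch = next(iter(dataloader))
# print({k: tuple(v.shape) for k, v in first_batch.items()})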
# FLOP estimate (based on the paper formula: 6 * N FLOPs per token for the weight
# matmuls, forward + backward, plus a quadratic attention term)
def compute_flops(batch_size, seq_len, num_layers, hidden_size, num_params):
    flops_weight_input = 6 * seq_len * num_params                    # weight matmuls, fwd + bwd
    flops_attention = 12 * num_layers * hidden_size * seq_len ** 2   # attention score/value matmuls, fwd + bwd
    return (flops_weight_input + flops_attention) * batch_size
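
# Rough worked example with the settings above (assumes the stock GPT-2 small
# checkpoint with ~124.4M parameters):
#   6 * 1024 * 124.4e6        ≈ 7.6e11 FLOPs for the weight matmuls
#   12 * 12 * 768 * 1024**2   ≈ 1.2e11 FLOPs for attention
#   (7.6e11 + 1.2e11) * 32    ≈ 2.8e13 FLOPs per training step (~28 TFLOPs)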
# Deterministic cuDNN behaviour (helps when debugging CUDA errors)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Benchmark function with precise timing (assumes a CUDA device; the timings rely
# on torch.cuda.synchronize() to exclude asynchronous kernel launches)
def benchmark_training(model, dataloader, num_steps=NUM_STEPS):
    model.to(DEVICE)
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
    torch.cuda.synchronize()
    start_time_total = time.time()
    total_forward_time = 0.0
    total_backward_time = 0.0
    steps_done = 0
    for step, batch in enumerate(dataloader):
        if step >= num_steps:
            break
        try:
            batch = {k: v.to(DEVICE, non_blocking=True) for k, v in batch.items()}
            torch.cuda.synchronize()
            start_fwd = time.time()
            loss = model(**batch).loss
            torch.cuda.synchronize()
            total_forward_time += time.time() - start_fwd
            start_bwd = time.time()
            loss.backward()
            torch.cuda.synchronize()
            total_backward_time += time.time() - start_bwd
            optimizer.step()
            optimizer.zero_grad()
            steps_done += 1
        except RuntimeError as e:
            print(f"RuntimeError detected: {e}")
            continue
    total_time = time.time() - start_time_total
    # Only count steps that actually completed, so skipped batches do not inflate throughput
    tokens_per_second = (steps_done * BATCH_SIZE * SEQ_LEN) / total_time
    flops_per_step = compute_flops(BATCH_SIZE, SEQ_LEN, NUM_LAYERS, HIDDEN_SIZE, model.num_parameters())
    tflops_per_sec = (flops_per_step * (tokens_per_second / (BATCH_SIZE * SEQ_LEN))) / 1e12
    return tokens_per_second, tflops_per_sec, total_forward_time / steps_done, total_backward_time / steps_done
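
# Note on the TFLOPS conversion above: tokens_per_second / (BATCH_SIZE * SEQ_LEN) is
# simply the number of optimizer steps per second, so the expression reduces to
# flops_per_step * steps_per_second / 1e12, i.e. achieved model TFLOPs per second.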
# Integrate FlashAttention-2.
# GPT-2 in transformers does not use torch.nn.MultiheadAttention, so monkey-patching
# that class would never take effect. Instead, FlashAttention-2 is requested via the
# attn_implementation flag of from_pretrained (this assumes a transformers release
# with FlashAttention-2 support for GPT-2 and the flash-attn package installed).
try:
    import flash_attn  # noqa: F401
    FLASH_ATTN_AVAILABLE = True
except ImportError:
    FLASH_ATTN_AVAILABLE = False
    print("FlashAttention-2 is not installed. Standard attention will be used.")

def load_model(attn_type):
    attn_impl = "eager"
    if attn_type == "flash2" and FLASH_ATTN_AVAILABLE:
        attn_impl = "flash_attention_2"
    # Device placement is handled explicitly in benchmark_training, so no device_map here.
    model = GPT2LMHeadModel.from_pretrained(
        "gpt2",
        torch_dtype=MIXED_PRECISION,
        attn_implementation=attn_impl,
    )
    return model
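
# To verify which attention backend is actually active, one can inspect
# model.config._attn_implementation after loading (assumption: a recent transformers
# release that records the chosen implementation on the config).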
results = {}
for attn_type in ["standard", "flash2"]:
    print(f"Testing GPT-2 Small with {attn_type} attention...")
    model = load_model(attn_type)
    tokens_per_sec, tflops_per_sec, avg_fwd, avg_bwd = benchmark_training(model, dataloader)
    results[attn_type] = (tokens_per_sec, tflops_per_sec, avg_fwd, avg_bwd)
    print(f"{attn_type} attention: {tokens_per_sec:.2f} tokens/s, {tflops_per_sec:.2f} TFLOPS")
    print(f"  Average forward time: {avg_fwd:.4f}s, backward time: {avg_bwd:.4f}s")
    # Free GPU memory before benchmarking the next configuration
    del model
    torch.cuda.empty_cache()
# Print final results
print("--- GPT-2-Small-1k-opt-v2 ---")
print(f"BATCH_SIZE: {BATCH_SIZE}, SEQ_LEN: {SEQ_LEN}, NUM_STEPS: {NUM_STEPS}")
print("Final results:")
for attn_type, (speed, tflops, avg_fwd, avg_bwd) in results.items():
    print(f"{attn_type.capitalize()} attention: {speed:.2f} tokens/s, {tflops:.2f} TFLOPS")
    print(f"  Average forward time: {avg_fwd:.4f}s, backward time: {avg_bwd:.4f}s")