Commit 06dab7b4 authored by Armin Bacher

Delete GPT-2-Small-train.py

parent c9bcc6b6
import time
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, DataCollatorForLanguageModeling
from datasets import load_dataset
from torch.utils.data import DataLoader
# --- Settings ---
BATCH_SIZE = 16  # larger batch size for better GPU utilization
SEQ_LEN = 1024   # FlashAttention-2 benefits from longer sequences
NUM_STEPS = 100
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MIXED_PRECISION = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
# GPT-2 Small model parameters
NUM_LAYERS = 12
HIDDEN_SIZE = 768
dataset = load_dataset("wikitext", "wikitext-103-v1")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
def tokenize_function(examples):
    tokens = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=SEQ_LEN)
    tokens["labels"] = tokens["input_ids"].copy()  # labels are a copy of input_ids; GPT2LMHeadModel shifts them internally
    return tokens
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
dataloader = DataLoader(tokenized_datasets["train"], batch_size=BATCH_SIZE, collate_fn=data_collator)
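
# Optional sanity check of the data pipeline above (purely illustrative, can be removed).
# Each collated batch is a dict of tensors of shape (BATCH_SIZE, SEQ_LEN). With mlm=False the
# collator rebuilds "labels" from "input_ids" and masks pad positions with -100 so they are
# ignored by the loss (note: pad_token == eos_token here, so eos positions are masked as well).
sample_batch = next(iter(dataloader))
print({k: tuple(v.shape) for k, v in sample_batch.items()})
# expected: {'input_ids': (16, 1024), 'attention_mask': (16, 1024), 'labels': (16, 1024)}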
def compute_flops(batch_size, seq_len, num_layers, hidden_size):
    # Rough analytic estimate of FLOPs per training step (models each layer as 6 * hidden_size**2 FLOPs per token).
    return 6 * num_layers * (hidden_size ** 2) * seq_len * batch_size
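
# Worked example with the settings above (assumption: this is only a rough estimate; the common
# "6 * n_params * tokens" rule of thumb for training FLOPs would give a considerably larger number,
# since GPT-2 Small has roughly 12 * NUM_LAYERS * HIDDEN_SIZE**2 ≈ 85M non-embedding parameters,
# versus the NUM_LAYERS * HIDDEN_SIZE**2 term modeled here):
#   compute_flops(16, 1024, 12, 768) = 6 * 12 * 768**2 * 1024 * 16 ≈ 6.96e11 FLOPs ≈ 0.70 TFLOPs per step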
def benchmark_training(model, dataloader, num_steps=NUM_STEPS):
    model.to(DEVICE)
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
    if DEVICE == "cuda":
        torch.cuda.empty_cache()  # clear the GPU cache before the run
        torch.cuda.synchronize()
    start_time = time.time()
    for step, batch in enumerate(dataloader):
        if step >= num_steps:
            break
        # Move the batch to the target device; input_ids/labels stay integer tensors,
        # the model itself already runs in MIXED_PRECISION (set via torch_dtype at load time).
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    if DEVICE == "cuda":
        torch.cuda.synchronize()
    total_time = time.time() - start_time
    tokens_per_second = (num_steps * BATCH_SIZE * SEQ_LEN) / total_time
    flops_per_step = compute_flops(BATCH_SIZE, SEQ_LEN, NUM_LAYERS, HIDDEN_SIZE)
    # FLOPs/s = FLOPs per step * steps per second, where steps/s = tokens/s / tokens per step
    tflops_per_sec = (flops_per_step * (tokens_per_second / (BATCH_SIZE * SEQ_LEN))) / 1e12
    return tokens_per_second, tflops_per_sec
results = {}
for attn_type in ["standard", "flash1", "flash2"]:
    print(f"Testing GPT-2 Small with {attn_type} attention...")
    if attn_type == "standard":
        # Request the eager attention path explicitly so the baseline is well defined
        # (newer transformers versions may otherwise default to SDPA).
        model = GPT2LMHeadModel.from_pretrained("gpt2", attn_implementation="eager", torch_dtype=MIXED_PRECISION)
    elif attn_type == "flash1":
        # transformers exposes no FlashAttention-1 backend ("flash_attention" is not a valid
        # attn_implementation); PyTorch's SDPA kernels are used here as the closest stand-in.
        model = GPT2LMHeadModel.from_pretrained("gpt2", attn_implementation="sdpa", torch_dtype=MIXED_PRECISION)
    elif attn_type == "flash2":
        model = GPT2LMHeadModel.from_pretrained("gpt2", attn_implementation="flash_attention_2", torch_dtype=MIXED_PRECISION)
    tokens_per_sec, tflops_per_sec = benchmark_training(model, dataloader)
    results[attn_type] = (tokens_per_sec, tflops_per_sec)
    print(f"{attn_type} attention: {tokens_per_sec:.2f} tokens/s, {tflops_per_sec:.2f} TFLOPS")
print("Endergebnisse:")
for attn_type, (speed, tflops) in results.items():
print(f"{attn_type.capitalize()} Attention: {speed:.2f} Tokens/Sekunde, {tflops:.2f} TFLOPS/s")
"""
Ergebnisse
BATCH_SIZE = 8
SEQ_LEN = 512
NUM_STEPS = 100
Starte Job GPT2-train-0001
^MMap: 0%| | 0/4358 [00:00<?, ? examples/s]^MMap: 23%|██▎ | 1000/4358 [00:00<00:02, 1>^MMap: 0%| | 0/1801350 [00:00<?, ? examples/s]^MMap: 0%| | 1000/1801350 [00:00<22>^MMap: 0%| | 0/3760 [00:00<?, ? examples/s]^MMap: 27%|██▋ | 1000/3760 [00:00<00:01, 1>You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move th>Teste GPT-2 Small mit standard Attention...
standard Attention: 66882.94 Tokens pro Sekunde
Teste GPT-2 Small mit flash2 Attention...
flash2 Attention: 41922.98 Tokens pro Sekunde
Endergebnisse:
Standard Attention: 66882.94 Tokens pro Sekunde
Flash2 Attention: 41922.98 Tokens pro Sekunde
Job abgeschlossen
"""