diff --git a/Testing/GPT-2-Small.ipynb b/Testing/GPT-2-Small.ipynb deleted file mode 100644 index 7f5f0515320f38b40ccfca9c1aba52eee735b790..0000000000000000000000000000000000000000 --- a/Testing/GPT-2-Small.ipynb +++ /dev/null @@ -1,527 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Importe" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "--- Mini-Testsetup für A2000 ---\n", - "DEVICE: cuda, BATCH_SIZE: 4, SEQ_LEN: 256, NUM_STEPS: 10\n" - ] - } - ], - "source": [ - "# Cell 1: Imports und Basiseinstellungen\n", - "import time\n", - "import torch\n", - "import torch.nn as nn\n", - "from transformers import GPT2LMHeadModel, GPT2Tokenizer, DataCollatorForLanguageModeling\n", - "from datasets import load_dataset\n", - "from torch.utils.data import DataLoader\n", - "import matplotlib.pyplot as plt\n", - "\n", - "# Kleine Test-Einstellungen für A2000\n", - "BATCH_SIZE = 4\n", - "SEQ_LEN = 256\n", - "NUM_STEPS = 10\n", - "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "MIXED_PRECISION = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16\n", - "\n", - "print(\"--- Mini-Testsetup für A2000 ---\")\n", - "print(f\"DEVICE: {DEVICE}, BATCH_SIZE: {BATCH_SIZE}, SEQ_LEN: {SEQ_LEN}, NUM_STEPS: {NUM_STEPS}\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Gleiche Trainingsbedingungen sicherstellen" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import random\n", - "import numpy as np\n", - "import torch\n", - "\n", - "def set_seed(seed=42):\n", - " random.seed(seed)\n", - " np.random.seed(seed)\n", - " torch.manual_seed(seed)\n", - " torch.cuda.manual_seed_all(seed)\n", - "\n", - " # Für vollständige Reproduzierbarkeit (optional):\n", - " torch.backends.cudnn.deterministic = True\n", - " torch.backends.cudnn.benchmark = False" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "set_seed(42)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Datensatz + Tokenizer vorbereiten" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "# Cell 2: Dataset & Tokenizer vorbereiten\n", - "dataset = load_dataset(\"wikitext\", \"wikitext-103-v1\", split=\"train[:1%]\") # Nur 1% für Speed\n", - "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n", - "tokenizer.pad_token = tokenizer.eos_token\n", - "\n", - "def tokenize_function(examples):\n", - " tokens = tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True, max_length=SEQ_LEN)\n", - " tokens[\"attention_mask\"] = torch.tensor(tokens[\"attention_mask\"], dtype=torch.long)\n", - " tokens[\"labels\"] = tokens[\"input_ids\"].copy()\n", - " return tokens\n", - "\n", - "tokenized_dataset = dataset.map(tokenize_function, batched=True)\n", - "tokenized_dataset.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])\n", - "\n", - "def custom_collate_fn(batch):\n", - " input_ids = torch.stack([example[\"input_ids\"] for example in batch])\n", - " attention_mask = torch.stack([\n", - " example[\"attention_mask\"] if example[\"attention_mask\"].dim() == 1 else example[\"attention_mask\"].squeeze()\n", - " for example in batch\n", - " ])\n", - " labels = torch.stack([example[\"labels\"] for example in batch])\n", - " 
return {\n", - " \"input_ids\": input_ids,\n", - " \"attention_mask\": attention_mask,\n", - " \"labels\": labels,\n", - " }\n", - "\n", - "dataloader = DataLoader(tokenized_dataset, batch_size=BATCH_SIZE, collate_fn=custom_collate_fn)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "FLOP Berechnung" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "# Cell 3: FLOP-Schätzung wie im Paper\n", - "def compute_flops(batch_size, seq_len, num_layers, hidden_size, num_params):\n", - " flops_weight_input = 6 * seq_len * num_params\n", - " flops_attention = 12 * num_layers * hidden_size * seq_len ** 2\n", - " return (flops_weight_input + flops_attention) * batch_size\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Trainings- und Benchmark-Funktion" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "# Cell 6: Trainingsloop mit integrierter Batchprüfung\n", - "\n", - "def benchmark_training(model, dataloader, num_steps=10):\n", - " model.train()\n", - " optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)\n", - " losses = []\n", - " times_fwd = []\n", - " times_bwd = []\n", - "\n", - " def check_batch(batch):\n", - " assert isinstance(batch, dict), \"Batch ist kein dict\"\n", - " for key in [\"input_ids\", \"attention_mask\", \"labels\"]:\n", - " assert key in batch, f\"Fehlender Key: {key}\"\n", - " assert batch[key].dim() == 2, f\"{key} hat {batch[key].dim()} Dimensionen (erwartet: 2)\"\n", - " B, S = batch[key].shape\n", - " assert S == SEQ_LEN, f\"{key} hat falsche Länge {S}, erwartet {SEQ_LEN}\"\n", - " return True\n", - "\n", - " for step, batch in enumerate(dataloader):\n", - " if step >= num_steps:\n", - " break\n", - "\n", - " try:\n", - " check_batch(batch) # 🔍 automatisch prüfen\n", - " except AssertionError as e:\n", - " print(f\"❌ Fehler im Batch (Step {step}): {e}\")\n", - " break\n", - "\n", - " batch = {k: v.to(DEVICE) for k, v in batch.items()}\n", - "\n", - " torch.cuda.synchronize()\n", - " start_fwd = time.time()\n", - " outputs = model(**batch)\n", - " loss = outputs.loss\n", - " torch.cuda.synchronize()\n", - " end_fwd = time.time()\n", - "\n", - " start_bwd = time.time()\n", - " loss.backward()\n", - " optimizer.step()\n", - " optimizer.zero_grad()\n", - " torch.cuda.synchronize()\n", - " end_bwd = time.time()\n", - "\n", - " losses.append(loss.item())\n", - " times_fwd.append(end_fwd - start_fwd)\n", - " times_bwd.append(end_bwd - start_bwd)\n", - "\n", - " print(f\"Step {step+1}/{num_steps} - Loss: {loss.item():.4f}\")\n", - "\n", - " # Kennzahlen berechnen\n", - " total_tokens = num_steps * BATCH_SIZE * SEQ_LEN\n", - " total_time = sum(times_fwd) + sum(times_bwd)\n", - " tokens_per_sec = total_tokens / total_time\n", - " avg_fwd = sum(times_fwd) / len(times_fwd)\n", - " avg_bwd = sum(times_bwd) / len(times_bwd)\n", - "\n", - " tflops = (\n", - " 24 * num_steps * BATCH_SIZE * SEQ_LEN * model.config.n_embd * model.config.n_layer\n", - " / (total_time * 1e12)\n", - " )\n", - "\n", - " return losses, tokens_per_sec, tflops, avg_fwd, avg_bwd\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Cell 4: Trainings- und Benchmark-Funktion\n", - "def benchmark_training(model, dataloader, num_steps=NUM_STEPS):\n", - " model.to(DEVICE)\n", - " model.train()\n", - " optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)\n", - "\n", - 
" total_forward_time = 0\n", - " total_backward_time = 0\n", - " loss_history = []\n", - "\n", - " torch.cuda.synchronize()\n", - " start_time = time.time()\n", - "\n", - " for step, batch in enumerate(dataloader):\n", - " if step >= num_steps:\n", - " break\n", - " batch = {k: v.to(DEVICE) for k, v in batch.items()}\n", - "\n", - " torch.cuda.synchronize()\n", - " start_fwd = time.time()\n", - " outputs = model(**batch)\n", - " loss = outputs.loss\n", - " torch.cuda.synchronize()\n", - " total_forward_time += time.time() - start_fwd\n", - "\n", - " torch.cuda.synchronize()\n", - " start_bwd = time.time()\n", - " loss.backward()\n", - " torch.cuda.synchronize()\n", - " total_backward_time += time.time() - start_bwd\n", - "\n", - " optimizer.step()\n", - " optimizer.zero_grad()\n", - " loss_history.append(loss.item())\n", - "\n", - " print(f\"Step {step+1}/{num_steps} - Loss: {loss.item():.4f}\")\n", - "\n", - " total_time = time.time() - start_time\n", - " tokens_per_sec = (num_steps * BATCH_SIZE * SEQ_LEN) / total_time\n", - " flops_per_step = compute_flops(BATCH_SIZE, SEQ_LEN, 12, 768, model.num_parameters())\n", - " tflops_per_sec = (flops_per_step * (tokens_per_sec / (BATCH_SIZE * SEQ_LEN))) / 1e12\n", - "\n", - " return loss_history, tokens_per_sec, tflops_per_sec, total_forward_time / num_steps, total_backward_time / num_steps\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Modell laden" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "# Cell 5: Modell laden mit FlashAttention-2 Patch (oder Standard)\n", - "\n", - "def patch_gpt2_with_flash_attention(model, causal=True):\n", - " \"\"\"\n", - " Replaces GPT2 Attention modules with FlashSelfAttention v2 (unpadded mode with attention_mask).\n", - " \"\"\"\n", - " from flash_attn.modules.mha import FlashSelfAttention\n", - "\n", - " def custom_forward(self, hidden_states, layer_past=None, attention_mask=None,\n", - " head_mask=None, use_cache=False, output_attentions=False):\n", - " B, S, _ = hidden_states.size()\n", - " device = hidden_states.device\n", - "\n", - " # === QKV erstellen ===\n", - " qkv = self.c_attn(hidden_states) # [B, S, 3*hidden]\n", - " qkv = qkv.view(B, S, 3, self.num_heads, self.head_dim) # [B, S, 3, H, D]\n", - " qkv = qkv.permute(0, 2, 1, 3, 4).contiguous() # [B, 3, S, H, D]\n", - "\n", - " # === Attention-Maske korrekt verarbeiten ===\n", - " if attention_mask is not None:\n", - " if attention_mask.dim() == 4:\n", - " attention_mask = attention_mask[:, 0, 0, :] # [B, S]\n", - " elif attention_mask.dim() != 2:\n", - " raise ValueError(f\"⚠️ attention_mask hat ungültige Form: {attention_mask.shape} (erwartet [B, S] oder [B, 1, S, S])\")\n", - " seqlens = attention_mask.sum(dim=1, dtype=torch.int32) # [B]\n", - " else:\n", - " attention_mask = torch.ones((B, S), dtype=torch.int32, device=device)\n", - " seqlens = torch.full((B,), S, dtype=torch.int32, device=device)\n", - "\n", - " # === cu_seqlens ===\n", - " cu_seqlens = torch.zeros(B + 1, dtype=torch.int32, device=device)\n", - " cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)\n", - "\n", - " # === QKV packen ===\n", - " qkv = qkv.permute(0, 2, 1, 3, 4).contiguous() # [B, S, 3, H, D]\n", - " qkv = qkv.view(B * S, 3, self.num_heads, self.head_dim) # [B*S, 3, H, D]\n", - " qkv_packed = qkv[attention_mask.view(-1).bool()] # [L, 3, H, D]\n", - "\n", - " # === FlashAttention aufrufen ===\n", - " output = self.flash(\n", - " qkv_packed,\n", - " cu_seqlens=cu_seqlens,\n", - " 
max_seqlen=S,\n", - " unpadded=True\n", - " ) # [total_tokens, H, D]\n", - "\n", - " # === Re-padden ===\n", - " output_padded = torch.zeros(B * S, self.num_heads, self.head_dim, device=device)\n", - " output_padded[attention_mask.view(-1).bool()] = output\n", - " output_padded = output_padded.view(B, S, self.num_heads, self.head_dim)\n", - "\n", - " # === zurück in GPT-2 Form ===\n", - " attn_output = output_padded.transpose(1, 2).contiguous().view(B, S, -1)\n", - " attn_output = self.c_proj(attn_output)\n", - " attn_output = self.resid_dropout(attn_output)\n", - "\n", - " return attn_output, None\n", - "\n", - "\n", - " print(f\"🔧 Patching {len(model.transformer.h)} Transformer-Blöcke mit FlashSelfAttention (unpadded, masked)\")\n", - "\n", - " for block in model.transformer.h:\n", - " attn = block.attn\n", - " flash = FlashSelfAttention(causal=causal).to(model.device)\n", - " attn.flash = flash\n", - " attn.forward = custom_forward.__get__(attn, attn.__class__)\n", - "\n", - " print(\"✅ FlashAttention-2 gepatcht mit Maskierung & unpadded-Modus\")\n", - " return model\n", - "\n", - "\n", - "def load_model(attn_type=\"standard\"):\n", - " model = GPT2LMHeadModel.from_pretrained(\"gpt2\", torch_dtype=MIXED_PRECISION).to(DEVICE)\n", - "\n", - " if attn_type == \"flash2\":\n", - " model = patch_gpt2_with_flash_attention(model)\n", - "\n", - " return model" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "manueller Test FA-2 implementierung" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Cell 6: Mini-Training starten und auswerten\n", - "#model = load_model()\n", - "model = load_model(attn_type=\"flash2\")\n", - "print(\"Verwendete Attention:\", model.transformer.h[0].attn.flash.__class__)\n", - "losses, tok_sec, tflops, avg_fwd, avg_bwd = benchmark_training(model, dataloader)\n", - "\n", - "print(\"\\n--- Ergebnisse ---\")\n", - "print(f\"Tokens/sec: {tok_sec:.2f}\")\n", - "print(f\"TFLOPs/sec: {tflops:.2f}\")\n", - "print(f\"Ø Forward-Zeit: {avg_fwd:.4f}s\")\n", - "print(f\"Ø Backward-Zeit: {avg_bwd:.4f}s\")\n", - "\n", - "# Optional: Loss-Verlauf plotten\n", - "plt.plot(losses)\n", - "plt.title(\"Loss-Verlauf\")\n", - "plt.xlabel(\"Steps\")\n", - "plt.ylabel(\"Loss\")\n", - "plt.grid(True)\n", - "plt.show()\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Vergleichsfunktion" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [], - "source": [ - "def run_comparison(attn_types=(\"standard\", \"flash2\"), steps=100):\n", - " results = {}\n", - "\n", - " for attn in attn_types:\n", - " print(f\"\\n🔁 Benchmarking '{attn}' Attention\")\n", - "\n", - " set_seed(42) # sicherstellen, dass beide Runs identisch starten\n", - " model = load_model(attn_type=attn)\n", - " losses, tok_sec, tflops, avg_fwd, avg_bwd = benchmark_training(model, dataloader, num_steps=steps)\n", - "\n", - " results[attn] = {\n", - " \"tokens/sec\": tok_sec,\n", - " \"tflops/sec\": tflops,\n", - " \"avg_forward_time_s\": avg_fwd,\n", - " \"avg_backward_time_s\": avg_bwd,\n", - " \"final_loss\": losses[-1],\n", - " \"loss_curve\": losses,\n", - " }\n", - "\n", - " print(f\"📊 {attn.upper()} | Tokens/sec: {tok_sec:.2f} | TFLOPs/sec: {tflops:.2f} | \"\n", - " f\"FWD: {avg_fwd:.4f}s | BWD: {avg_bwd:.4f}s | Final Loss: {losses[-1]:.4f}\")\n", - "\n", - " return results" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [ - { - 
"name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "🔁 Benchmarking 'standard' Attention\n", - "Step 1/10 - Loss: 8.4946\n", - "Step 2/10 - Loss: 1.4220\n", - "Step 3/10 - Loss: 3.3182\n", - "Step 4/10 - Loss: 1.1913\n", - "Step 5/10 - Loss: 2.3714\n", - "Step 6/10 - Loss: 1.8165\n", - "Step 7/10 - Loss: 1.0560\n", - "Step 8/10 - Loss: 1.5705\n", - "Step 9/10 - Loss: 0.4295\n", - "Step 10/10 - Loss: 1.8563\n", - "📊 STANDARD | Tokens/sec: 10885.90 | TFLOPs/sec: 0.00 | FWD: 0.0271s | BWD: 0.0670s | Final Loss: 1.8563\n", - "\n", - "🔁 Benchmarking 'flash2' Attention\n", - "🔧 Patching 12 Transformer-Blöcke mit FlashSelfAttention (unpadded, masked)\n", - "✅ FlashAttention-2 gepatcht mit Maskierung & unpadded-Modus\n" - ] - }, - { - "ename": "RuntimeError", - "evalue": "view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[25]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m results = \u001b[43mrun_comparison\u001b[49m\u001b[43m(\u001b[49m\u001b[43mattn_types\u001b[49m\u001b[43m=\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mstandard\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mflash2\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msteps\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m10\u001b[39;49m\u001b[43m)\u001b[49m\n", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[24]\u001b[39m\u001b[32m, line 9\u001b[39m, in \u001b[36mrun_comparison\u001b[39m\u001b[34m(attn_types, steps)\u001b[39m\n\u001b[32m 7\u001b[39m set_seed(\u001b[32m42\u001b[39m) \u001b[38;5;66;03m# sicherstellen, dass beide Runs identisch starten\u001b[39;00m\n\u001b[32m 8\u001b[39m model = load_model(attn_type=attn)\n\u001b[32m----> \u001b[39m\u001b[32m9\u001b[39m losses, tok_sec, tflops, avg_fwd, avg_bwd = \u001b[43mbenchmark_training\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_steps\u001b[49m\u001b[43m=\u001b[49m\u001b[43msteps\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 11\u001b[39m results[attn] = {\n\u001b[32m 12\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mtokens/sec\u001b[39m\u001b[33m\"\u001b[39m: tok_sec,\n\u001b[32m 13\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mtflops/sec\u001b[39m\u001b[33m\"\u001b[39m: tflops,\n\u001b[32m (...)\u001b[39m\u001b[32m 17\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mloss_curve\u001b[39m\u001b[33m\"\u001b[39m: losses,\n\u001b[32m 18\u001b[39m }\n\u001b[32m 20\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m📊 \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mattn.upper()\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m | Tokens/sec: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtok_sec\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.2f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m | TFLOPs/sec: 
\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtflops\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.2f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m | \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 21\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mFWD: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mavg_fwd\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33ms | BWD: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mavg_bwd\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33ms | Final Loss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mlosses[-\u001b[32m1\u001b[39m]\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[19]\u001b[39m\u001b[32m, line 33\u001b[39m, in \u001b[36mbenchmark_training\u001b[39m\u001b[34m(model, dataloader, num_steps)\u001b[39m\n\u001b[32m 31\u001b[39m torch.cuda.synchronize()\n\u001b[32m 32\u001b[39m start_fwd = time.time()\n\u001b[32m---> \u001b[39m\u001b[32m33\u001b[39m outputs = \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 34\u001b[39m loss = outputs.loss\n\u001b[32m 35\u001b[39m torch.cuda.synchronize()\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1737\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1738\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1739\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1745\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1746\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1747\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1748\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1749\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1750\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1752\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1753\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/transformers/models/gpt2/modeling_gpt2.py:1062\u001b[39m, in \u001b[36mGPT2LMHeadModel.forward\u001b[39m\u001b[34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs)\u001b[39m\n\u001b[32m 1054\u001b[39m \u001b[38;5;250m\u001b[39m\u001b[33mr\u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 1055\u001b[39m \u001b[33;03mlabels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\u001b[39;00m\n\u001b[32m 1056\u001b[39m \u001b[33;03m Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set\u001b[39;00m\n\u001b[32m 1057\u001b[39m \u001b[33;03m `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`\u001b[39;00m\n\u001b[32m 1058\u001b[39m \u001b[33;03m are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`\u001b[39;00m\n\u001b[32m 1059\u001b[39m \u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 1060\u001b[39m return_dict = return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m.config.use_return_dict\n\u001b[32m-> \u001b[39m\u001b[32m1062\u001b[39m transformer_outputs = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mtransformer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 1063\u001b[39m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1064\u001b[39m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[43m=\u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1065\u001b[39m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1066\u001b[39m \u001b[43m \u001b[49m\u001b[43mtoken_type_ids\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtoken_type_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1067\u001b[39m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[43m=\u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1068\u001b[39m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1069\u001b[39m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[43m=\u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1070\u001b[39m \u001b[43m \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m=\u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1071\u001b[39m \u001b[43m \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1072\u001b[39m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m=\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1073\u001b[39m \u001b[43m 
\u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m=\u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1074\u001b[39m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m=\u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1075\u001b[39m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[43m=\u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1076\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1077\u001b[39m hidden_states = transformer_outputs[\u001b[32m0\u001b[39m]\n\u001b[32m 1079\u001b[39m \u001b[38;5;66;03m# Set device for model parallelism\u001b[39;00m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1737\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1738\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1739\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1745\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1746\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1747\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1748\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1749\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1750\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1752\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1753\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/transformers/models/gpt2/modeling_gpt2.py:922\u001b[39m, in \u001b[36mGPT2Model.forward\u001b[39m\u001b[34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[39m\n\u001b[32m 
910\u001b[39m outputs = \u001b[38;5;28mself\u001b[39m._gradient_checkpointing_func(\n\u001b[32m 911\u001b[39m block.\u001b[34m__call__\u001b[39m,\n\u001b[32m 912\u001b[39m hidden_states,\n\u001b[32m (...)\u001b[39m\u001b[32m 919\u001b[39m output_attentions,\n\u001b[32m 920\u001b[39m )\n\u001b[32m 921\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m922\u001b[39m outputs = \u001b[43mblock\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 923\u001b[39m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 924\u001b[39m \u001b[43m \u001b[49m\u001b[43mlayer_past\u001b[49m\u001b[43m=\u001b[49m\u001b[43mlayer_past\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 925\u001b[39m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 926\u001b[39m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 927\u001b[39m \u001b[43m \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m=\u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 928\u001b[39m \u001b[43m \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 929\u001b[39m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m=\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 930\u001b[39m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m=\u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 931\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 933\u001b[39m hidden_states = outputs[\u001b[32m0\u001b[39m]\n\u001b[32m 934\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m use_cache \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1737\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1738\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1739\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1745\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1746\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1747\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1748\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1749\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1750\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1752\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1753\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/transformers/models/gpt2/modeling_gpt2.py:404\u001b[39m, in \u001b[36mGPT2Block.forward\u001b[39m\u001b[34m(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)\u001b[39m\n\u001b[32m 402\u001b[39m residual = hidden_states\n\u001b[32m 403\u001b[39m hidden_states = \u001b[38;5;28mself\u001b[39m.ln_1(hidden_states)\n\u001b[32m--> \u001b[39m\u001b[32m404\u001b[39m attn_outputs = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mattn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 405\u001b[39m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 406\u001b[39m \u001b[43m \u001b[49m\u001b[43mlayer_past\u001b[49m\u001b[43m=\u001b[49m\u001b[43mlayer_past\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 407\u001b[39m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 408\u001b[39m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 409\u001b[39m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m=\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 410\u001b[39m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m=\u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 411\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 412\u001b[39m attn_output = attn_outputs[\u001b[32m0\u001b[39m] \u001b[38;5;66;03m# output_attn: a, present, (attentions)\u001b[39;00m\n\u001b[32m 413\u001b[39m outputs = attn_outputs[\u001b[32m1\u001b[39m:]\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1737\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1738\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1739\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[36mFile 
\u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1745\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1746\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1747\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1748\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1749\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1750\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1752\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1753\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[23]\u001b[39m\u001b[32m, line 37\u001b[39m, in \u001b[36mpatch_gpt2_with_flash_attention.<locals>.custom_forward\u001b[39m\u001b[34m(self, hidden_states, layer_past, attention_mask, head_mask, use_cache, output_attentions)\u001b[39m\n\u001b[32m 35\u001b[39m qkv = qkv.permute(\u001b[32m0\u001b[39m, \u001b[32m2\u001b[39m, \u001b[32m1\u001b[39m, \u001b[32m3\u001b[39m, \u001b[32m4\u001b[39m).contiguous() \u001b[38;5;66;03m# [B, S, 3, H, D]\u001b[39;00m\n\u001b[32m 36\u001b[39m qkv = qkv.view(B * S, \u001b[32m3\u001b[39m, \u001b[38;5;28mself\u001b[39m.num_heads, \u001b[38;5;28mself\u001b[39m.head_dim) \u001b[38;5;66;03m# [B*S, 3, H, D]\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m37\u001b[39m qkv_packed = qkv[\u001b[43mattention_mask\u001b[49m\u001b[43m.\u001b[49m\u001b[43mview\u001b[49m\u001b[43m(\u001b[49m\u001b[43m-\u001b[49m\u001b[32;43m1\u001b[39;49m\u001b[43m)\u001b[49m.bool()] \u001b[38;5;66;03m# [L, 3, H, D]\u001b[39;00m\n\u001b[32m 39\u001b[39m \u001b[38;5;66;03m# === FlashAttention aufrufen ===\u001b[39;00m\n\u001b[32m 40\u001b[39m output = \u001b[38;5;28mself\u001b[39m.flash(\n\u001b[32m 41\u001b[39m qkv_packed,\n\u001b[32m 42\u001b[39m cu_seqlens=cu_seqlens,\n\u001b[32m 43\u001b[39m max_seqlen=S,\n\u001b[32m 44\u001b[39m unpadded=\u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[32m 45\u001b[39m ) \u001b[38;5;66;03m# [total_tokens, H, D]\u001b[39;00m\n", - "\u001b[31mRuntimeError\u001b[39m: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead." 
- ] - } - ], - "source": [ - "results = run_comparison(attn_types=(\"standard\", \"flash2\"), steps=10)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "plt.figure(figsize=(10,5))\n", - "for attn_type, data in results.items():\n", - " plt.plot(data[\"loss_curve\"], label=attn_type)\n", - "\n", - "plt.xlabel(\"Training Step\")\n", - "plt.ylabel(\"Loss\")\n", - "plt.title(\"Loss-Verlauf: standard vs flash2\")\n", - "plt.legend()\n", - "plt.grid(True)\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -}
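
Note on the RuntimeError captured in the notebook output above: the traceback points at `attention_mask.view(-1)` inside the patched `custom_forward`, and the error message itself recommends `.reshape(...)`. Slicing GPT-2's 4-D attention mask back down to `[B, S]` with `[:, 0, 0, :]` leaves a non-contiguous tensor, which `.view(-1)` cannot flatten. The snippet below is only a minimal, self-contained sketch of that failure mode and the `.reshape(-1)` workaround, with illustrative shapes; it is not a drop-in replacement for the FlashAttention patch in the deleted notebook.

    import torch

    B, S = 4, 256
    mask4d = torch.ones(B, 1, S, S)   # shaped like the 4-D mask GPT-2 hands to each block
    mask2d = mask4d[:, 0, 0, :]       # [B, S], but rows sit S*S elements apart -> non-contiguous

    try:
        mask2d.view(-1)               # raises: "view size is not compatible with input tensor's size and stride"
    except RuntimeError as err:
        print("view failed:", err)

    flat = mask2d.reshape(-1)         # reshape copies when necessary and succeeds
    print(flat.shape)                 # torch.Size([1024])

In the notebook's `custom_forward`, switching `attention_mask.view(-1)` to `attention_mask.reshape(-1)` (or calling `.contiguous()` after the `[:, 0, 0, :]` slice) would get past this particular error; whether the surrounding unpadded QKV packing then feeds FlashAttention correct masks is a separate question.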