From 1d888d709ef57f3d1a1d200ae3f16758bfdd1673 Mon Sep 17 00:00:00 2001
From: Armin Bacher <armin.bacher@student.uni-halle.de>
Date: Mon, 31 Mar 2025 21:51:51 +0000
Subject: [PATCH] Delete GPT-2-Small.ipynb

---
 Testing/GPT-2-Small.ipynb | 527 --------------------------------------
 1 file changed, 527 deletions(-)
 delete mode 100644 Testing/GPT-2-Small.ipynb

diff --git a/Testing/GPT-2-Small.ipynb b/Testing/GPT-2-Small.ipynb
deleted file mode 100644
index 7f5f051..0000000
--- a/Testing/GPT-2-Small.ipynb
+++ /dev/null
@@ -1,527 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Importe"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "--- Mini-Testsetup für A2000 ---\n",
-      "DEVICE: cuda, BATCH_SIZE: 4, SEQ_LEN: 256, NUM_STEPS: 10\n"
-     ]
-    }
-   ],
-   "source": [
-    "# Cell 1: Imports und Basiseinstellungen\n",
-    "import time\n",
-    "import torch\n",
-    "import torch.nn as nn\n",
-    "from transformers import GPT2LMHeadModel, GPT2Tokenizer, DataCollatorForLanguageModeling\n",
-    "from datasets import load_dataset\n",
-    "from torch.utils.data import DataLoader\n",
-    "import matplotlib.pyplot as plt\n",
-    "\n",
-    "# Kleine Test-Einstellungen für A2000\n",
-    "BATCH_SIZE = 4\n",
-    "SEQ_LEN = 256\n",
-    "NUM_STEPS = 10\n",
-    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
-    "MIXED_PRECISION = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16\n",
-    "\n",
-    "print(\"--- Mini-Testsetup für A2000 ---\")\n",
-    "print(f\"DEVICE: {DEVICE}, BATCH_SIZE: {BATCH_SIZE}, SEQ_LEN: {SEQ_LEN}, NUM_STEPS: {NUM_STEPS}\")\n"
-   ]
-  },
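-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Optional device check (illustrative sketch): the `MIXED_PRECISION` choice above depends on whether the GPU supports bf16, so it can help to print the detected device properties once."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Illustrative sketch: confirm which GPU is visible and whether bf16 is supported.\n",
-    "# These properties determine the DEVICE and MIXED_PRECISION settings above.\n",
-    "if torch.cuda.is_available():\n",
-    "    props = torch.cuda.get_device_properties(0)\n",
-    "    print(f\"GPU: {props.name}, VRAM: {props.total_memory / 1e9:.1f} GB\")\n",
-    "    print(f\"bf16 supported: {torch.cuda.is_bf16_supported()}\")\n",
-    "else:\n",
-    "    print(\"No CUDA device available, falling back to CPU\")\n",
-    "print(f\"Selected mixed-precision dtype: {MIXED_PRECISION}\")"
-   ]
-  },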
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Gleiche Trainingsbedingungen sicherstellen"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import random\n",
-    "import numpy as np\n",
-    "import torch\n",
-    "\n",
-    "def set_seed(seed=42):\n",
-    "    random.seed(seed)\n",
-    "    np.random.seed(seed)\n",
-    "    torch.manual_seed(seed)\n",
-    "    torch.cuda.manual_seed_all(seed)\n",
-    "\n",
-    "    # Für vollständige Reproduzierbarkeit (optional):\n",
-    "    torch.backends.cudnn.deterministic = True\n",
-    "    torch.backends.cudnn.benchmark = False"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "set_seed(42)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Datensatz + Tokenizer vorbereiten"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Cell 2: Dataset & Tokenizer vorbereiten\n",
-    "dataset = load_dataset(\"wikitext\", \"wikitext-103-v1\", split=\"train[:1%]\")  # Nur 1% für Speed\n",
-    "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
-    "tokenizer.pad_token = tokenizer.eos_token\n",
-    "\n",
-    "def tokenize_function(examples):\n",
-    "    tokens = tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True, max_length=SEQ_LEN)\n",
-    "    tokens[\"attention_mask\"] = torch.tensor(tokens[\"attention_mask\"], dtype=torch.long)\n",
-    "    tokens[\"labels\"] = tokens[\"input_ids\"].copy()\n",
-    "    return tokens\n",
-    "\n",
-    "tokenized_dataset = dataset.map(tokenize_function, batched=True)\n",
-    "tokenized_dataset.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])\n",
-    "\n",
-    "def custom_collate_fn(batch):\n",
-    "    input_ids = torch.stack([example[\"input_ids\"] for example in batch])\n",
-    "    attention_mask = torch.stack([\n",
-    "        example[\"attention_mask\"] if example[\"attention_mask\"].dim() == 1 else example[\"attention_mask\"].squeeze()\n",
-    "        for example in batch\n",
-    "    ])\n",
-    "    labels = torch.stack([example[\"labels\"] for example in batch])\n",
-    "    return {\n",
-    "        \"input_ids\": input_ids,\n",
-    "        \"attention_mask\": attention_mask,\n",
-    "        \"labels\": labels,\n",
-    "    }\n",
-    "\n",
-    "dataloader = DataLoader(tokenized_dataset, batch_size=BATCH_SIZE, collate_fn=custom_collate_fn)\n"
-   ]
-  },
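-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Optional sanity check (illustrative sketch): pull a single batch from the `dataloader` and confirm it has the `[BATCH_SIZE, SEQ_LEN]` layout that the training loop below expects."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Illustrative sketch: inspect one batch from the dataloader.\n",
-    "# Every tensor should have shape [BATCH_SIZE, SEQ_LEN].\n",
-    "sample_batch = next(iter(dataloader))\n",
-    "for key, value in sample_batch.items():\n",
-    "    print(f\"{key}: shape={tuple(value.shape)}, dtype={value.dtype}\")\n",
-    "assert sample_batch[\"input_ids\"].shape == (BATCH_SIZE, SEQ_LEN)"
-   ]
-  },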
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "FLOP Berechnung"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Cell 3: FLOP-Schätzung wie im Paper\n",
-    "def compute_flops(batch_size, seq_len, num_layers, hidden_size, num_params):\n",
-    "    flops_weight_input = 6 * seq_len * num_params\n",
-    "    flops_attention = 12 * num_layers * hidden_size * seq_len ** 2\n",
-    "    return (flops_weight_input + flops_attention) * batch_size\n"
-   ]
-  },
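-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Worked example (illustrative sketch): plugging the GPT-2 small configuration (12 layers, hidden size 768, roughly 124M parameters) and the settings above into `compute_flops` gives the estimated FLOPs per training step."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Illustrative sketch: FLOP estimate per step for GPT-2 small\n",
-    "# (12 layers, hidden size 768, ~124M parameters) at BATCH_SIZE=4, SEQ_LEN=256.\n",
-    "example_flops = compute_flops(\n",
-    "    batch_size=BATCH_SIZE,\n",
-    "    seq_len=SEQ_LEN,\n",
-    "    num_layers=12,\n",
-    "    hidden_size=768,\n",
-    "    num_params=124_000_000,  # approximate parameter count of GPT-2 small\n",
-    ")\n",
-    "print(f\"Estimated FLOPs per step: {example_flops:.3e}\")\n",
-    "print(f\"That is about {example_flops / 1e12:.2f} TFLOPs per step\")"
-   ]
-  },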
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Trainings- und Benchmark-Funktion"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 19,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Cell 6: Trainingsloop mit integrierter Batchprüfung\n",
-    "\n",
-    "def benchmark_training(model, dataloader, num_steps=10):\n",
-    "    model.train()\n",
-    "    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)\n",
-    "    losses = []\n",
-    "    times_fwd = []\n",
-    "    times_bwd = []\n",
-    "\n",
-    "    def check_batch(batch):\n",
-    "        assert isinstance(batch, dict), \"Batch ist kein dict\"\n",
-    "        for key in [\"input_ids\", \"attention_mask\", \"labels\"]:\n",
-    "            assert key in batch, f\"Fehlender Key: {key}\"\n",
-    "            assert batch[key].dim() == 2, f\"{key} hat {batch[key].dim()} Dimensionen (erwartet: 2)\"\n",
-    "            B, S = batch[key].shape\n",
-    "            assert S == SEQ_LEN, f\"{key} hat falsche Länge {S}, erwartet {SEQ_LEN}\"\n",
-    "        return True\n",
-    "\n",
-    "    for step, batch in enumerate(dataloader):\n",
-    "        if step >= num_steps:\n",
-    "            break\n",
-    "\n",
-    "        try:\n",
-    "            check_batch(batch)  # 🔍 automatisch prüfen\n",
-    "        except AssertionError as e:\n",
-    "            print(f\"❌ Fehler im Batch (Step {step}): {e}\")\n",
-    "            break\n",
-    "\n",
-    "        batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
-    "\n",
-    "        torch.cuda.synchronize()\n",
-    "        start_fwd = time.time()\n",
-    "        outputs = model(**batch)\n",
-    "        loss = outputs.loss\n",
-    "        torch.cuda.synchronize()\n",
-    "        end_fwd = time.time()\n",
-    "\n",
-    "        start_bwd = time.time()\n",
-    "        loss.backward()\n",
-    "        optimizer.step()\n",
-    "        optimizer.zero_grad()\n",
-    "        torch.cuda.synchronize()\n",
-    "        end_bwd = time.time()\n",
-    "\n",
-    "        losses.append(loss.item())\n",
-    "        times_fwd.append(end_fwd - start_fwd)\n",
-    "        times_bwd.append(end_bwd - start_bwd)\n",
-    "\n",
-    "        print(f\"Step {step+1}/{num_steps} - Loss: {loss.item():.4f}\")\n",
-    "\n",
-    "    # Kennzahlen berechnen\n",
-    "    total_tokens = num_steps * BATCH_SIZE * SEQ_LEN\n",
-    "    total_time = sum(times_fwd) + sum(times_bwd)\n",
-    "    tokens_per_sec = total_tokens / total_time\n",
-    "    avg_fwd = sum(times_fwd) / len(times_fwd)\n",
-    "    avg_bwd = sum(times_bwd) / len(times_bwd)\n",
-    "\n",
-    "    tflops = (\n",
-    "        24 * num_steps * BATCH_SIZE * SEQ_LEN * model.config.n_embd * model.config.n_layer\n",
-    "        / (total_time * 1e12)\n",
-    "    )\n",
-    "\n",
-    "    return losses, tokens_per_sec, tflops, avg_fwd, avg_bwd\n"
-   ]
-  },
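-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Note: the inline TFLOPs expression in the loop above (`24 * num_steps * BATCH_SIZE * SEQ_LEN * n_embd * n_layer`) is several orders of magnitude smaller than the estimate from `compute_flops`, which is why the comparison run further down prints `TFLOPs/sec: 0.00`. The following cross-check is an illustrative sketch only (GPT-2 small constants, ~124M parameters assumed)."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Illustrative cross-check (sketch): compare the inline per-step estimate used in\n",
-    "# benchmark_training above with the compute_flops helper from Cell 3.\n",
-    "N_LAYER, N_EMBD, N_PARAMS = 12, 768, 124_000_000  # GPT-2 small, approximate\n",
-    "\n",
-    "inline_estimate = 24 * 1 * BATCH_SIZE * SEQ_LEN * N_EMBD * N_LAYER\n",
-    "paper_estimate = compute_flops(BATCH_SIZE, SEQ_LEN, N_LAYER, N_EMBD, N_PARAMS)\n",
-    "\n",
-    "print(f\"Inline estimate per step: {inline_estimate:.3e} FLOPs\")\n",
-    "print(f\"compute_flops per step:   {paper_estimate:.3e} FLOPs\")\n",
-    "print(f\"Ratio: {paper_estimate / inline_estimate:.0f}x\")"
-   ]
-  },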
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Cell 4: Trainings- und Benchmark-Funktion\n",
-    "def benchmark_training(model, dataloader, num_steps=NUM_STEPS):\n",
-    "    model.to(DEVICE)\n",
-    "    model.train()\n",
-    "    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)\n",
-    "\n",
-    "    total_forward_time = 0\n",
-    "    total_backward_time = 0\n",
-    "    loss_history = []\n",
-    "\n",
-    "    torch.cuda.synchronize()\n",
-    "    start_time = time.time()\n",
-    "\n",
-    "    for step, batch in enumerate(dataloader):\n",
-    "        if step >= num_steps:\n",
-    "            break\n",
-    "        batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
-    "\n",
-    "        torch.cuda.synchronize()\n",
-    "        start_fwd = time.time()\n",
-    "        outputs = model(**batch)\n",
-    "        loss = outputs.loss\n",
-    "        torch.cuda.synchronize()\n",
-    "        total_forward_time += time.time() - start_fwd\n",
-    "\n",
-    "        torch.cuda.synchronize()\n",
-    "        start_bwd = time.time()\n",
-    "        loss.backward()\n",
-    "        torch.cuda.synchronize()\n",
-    "        total_backward_time += time.time() - start_bwd\n",
-    "\n",
-    "        optimizer.step()\n",
-    "        optimizer.zero_grad()\n",
-    "        loss_history.append(loss.item())\n",
-    "\n",
-    "        print(f\"Step {step+1}/{num_steps} - Loss: {loss.item():.4f}\")\n",
-    "\n",
-    "    total_time = time.time() - start_time\n",
-    "    tokens_per_sec = (num_steps * BATCH_SIZE * SEQ_LEN) / total_time\n",
-    "    flops_per_step = compute_flops(BATCH_SIZE, SEQ_LEN, 12, 768, model.num_parameters())\n",
-    "    tflops_per_sec = (flops_per_step * (tokens_per_sec / (BATCH_SIZE * SEQ_LEN))) / 1e12\n",
-    "\n",
-    "    return loss_history, tokens_per_sec, tflops_per_sec, total_forward_time / num_steps, total_backward_time / num_steps\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Modell laden"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 23,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Cell 5: Modell laden mit FlashAttention-2 Patch (oder Standard)\n",
-    "\n",
-    "def patch_gpt2_with_flash_attention(model, causal=True):\n",
-    "    \"\"\"\n",
-    "    Replaces GPT2 Attention modules with FlashSelfAttention v2 (unpadded mode with attention_mask).\n",
-    "    \"\"\"\n",
-    "    from flash_attn.modules.mha import FlashSelfAttention\n",
-    "\n",
-    "    def custom_forward(self, hidden_states, layer_past=None, attention_mask=None,\n",
-    "                   head_mask=None, use_cache=False, output_attentions=False):\n",
-    "        B, S, _ = hidden_states.size()\n",
-    "        device = hidden_states.device\n",
-    "\n",
-    "        # === QKV erstellen ===\n",
-    "        qkv = self.c_attn(hidden_states)  # [B, S, 3*hidden]\n",
-    "        qkv = qkv.view(B, S, 3, self.num_heads, self.head_dim)  # [B, S, 3, H, D]\n",
-    "        qkv = qkv.permute(0, 2, 1, 3, 4).contiguous()  # [B, 3, S, H, D]\n",
-    "\n",
-    "        # === Attention-Maske korrekt verarbeiten ===\n",
-    "        if attention_mask is not None:\n",
-    "            if attention_mask.dim() == 4:\n",
-    "                attention_mask = attention_mask[:, 0, 0, :]  # [B, S]\n",
-    "            elif attention_mask.dim() != 2:\n",
-    "                raise ValueError(f\"⚠️ attention_mask hat ungültige Form: {attention_mask.shape} (erwartet [B, S] oder [B, 1, S, S])\")\n",
-    "            seqlens = attention_mask.sum(dim=1, dtype=torch.int32)  # [B]\n",
-    "        else:\n",
-    "            attention_mask = torch.ones((B, S), dtype=torch.int32, device=device)\n",
-    "            seqlens = torch.full((B,), S, dtype=torch.int32, device=device)\n",
-    "\n",
-    "        # === cu_seqlens ===\n",
-    "        cu_seqlens = torch.zeros(B + 1, dtype=torch.int32, device=device)\n",
-    "        cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)\n",
-    "\n",
-    "        # === QKV packen ===\n",
-    "        qkv = qkv.permute(0, 2, 1, 3, 4).contiguous()        # [B, S, 3, H, D]\n",
-    "        qkv = qkv.view(B * S, 3, self.num_heads, self.head_dim)  # [B*S, 3, H, D]\n",
-    "        qkv_packed = qkv[attention_mask.view(-1).bool()]     # [L, 3, H, D]\n",
-    "\n",
-    "        # === FlashAttention aufrufen ===\n",
-    "        output = self.flash(\n",
-    "            qkv_packed,\n",
-    "            cu_seqlens=cu_seqlens,\n",
-    "            max_seqlen=S,\n",
-    "            unpadded=True\n",
-    "        )  # [total_tokens, H, D]\n",
-    "\n",
-    "        # === Re-padden ===\n",
-    "        output_padded = torch.zeros(B * S, self.num_heads, self.head_dim, device=device)\n",
-    "        output_padded[attention_mask.view(-1).bool()] = output\n",
-    "        output_padded = output_padded.view(B, S, self.num_heads, self.head_dim)\n",
-    "\n",
-    "        # === zurück in GPT-2 Form ===\n",
-    "        attn_output = output_padded.transpose(1, 2).contiguous().view(B, S, -1)\n",
-    "        attn_output = self.c_proj(attn_output)\n",
-    "        attn_output = self.resid_dropout(attn_output)\n",
-    "\n",
-    "        return attn_output, None\n",
-    "\n",
-    "\n",
-    "    print(f\"🔧 Patching {len(model.transformer.h)} Transformer-Blöcke mit FlashSelfAttention (unpadded, masked)\")\n",
-    "\n",
-    "    for block in model.transformer.h:\n",
-    "        attn = block.attn\n",
-    "        flash = FlashSelfAttention(causal=causal).to(model.device)\n",
-    "        attn.flash = flash\n",
-    "        attn.forward = custom_forward.__get__(attn, attn.__class__)\n",
-    "\n",
-    "    print(\"✅ FlashAttention-2 gepatcht mit Maskierung & unpadded-Modus\")\n",
-    "    return model\n",
-    "\n",
-    "\n",
-    "def load_model(attn_type=\"standard\"):\n",
-    "    model = GPT2LMHeadModel.from_pretrained(\"gpt2\", torch_dtype=MIXED_PRECISION).to(DEVICE)\n",
-    "\n",
-    "    if attn_type == \"flash2\":\n",
-    "        model = patch_gpt2_with_flash_attention(model)\n",
-    "\n",
-    "    return model"
-   ]
-  },
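-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "The unpadding logic in `custom_forward` is easier to follow with dummy tensors. The sketch below (illustration only, dummy sizes, no FlashAttention call) shows how a padded `[B, S, 3, H, D]` QKV tensor is filtered down to `[total_valid_tokens, 3, H, D]` with a 0/1 token mask, and how `cu_seqlens` marks the per-sequence boundaries."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Illustrative sketch of the unpadded QKV layout used by the patch above.\n",
-    "# Dummy sizes only; FlashAttention itself is not called here.\n",
-    "B, S, H, D = 2, 6, 3, 4\n",
-    "qkv_demo = torch.randn(B, S, 3, H, D)\n",
-    "\n",
-    "# 0/1 token mask: first sequence has 6 valid tokens, second has 4.\n",
-    "mask_demo = torch.tensor([[1, 1, 1, 1, 1, 1],\n",
-    "                          [1, 1, 1, 1, 0, 0]], dtype=torch.int32)\n",
-    "\n",
-    "seqlens_demo = mask_demo.sum(dim=1, dtype=torch.int32)    # tensor([6, 4])\n",
-    "cu_seqlens_demo = torch.zeros(B + 1, dtype=torch.int32)\n",
-    "cu_seqlens_demo[1:] = torch.cumsum(seqlens_demo, dim=0)   # -> [0, 6, 10]\n",
-    "\n",
-    "qkv_flat = qkv_demo.reshape(B * S, 3, H, D)                # [B*S, 3, H, D]\n",
-    "qkv_packed_demo = qkv_flat[mask_demo.reshape(-1).bool()]   # [10, 3, H, D]\n",
-    "\n",
-    "print(\"seqlens:\", seqlens_demo.tolist())\n",
-    "print(\"cu_seqlens:\", cu_seqlens_demo.tolist())\n",
-    "print(\"packed QKV shape:\", tuple(qkv_packed_demo.shape))"
-   ]
-  },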
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "manueller Test FA-2 implementierung"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Cell 6: Mini-Training starten und auswerten\n",
-    "#model = load_model()\n",
-    "model = load_model(attn_type=\"flash2\")\n",
-    "print(\"Verwendete Attention:\", model.transformer.h[0].attn.flash.__class__)\n",
-    "losses, tok_sec, tflops, avg_fwd, avg_bwd = benchmark_training(model, dataloader)\n",
-    "\n",
-    "print(\"\\n--- Ergebnisse ---\")\n",
-    "print(f\"Tokens/sec: {tok_sec:.2f}\")\n",
-    "print(f\"TFLOPs/sec: {tflops:.2f}\")\n",
-    "print(f\"Ø Forward-Zeit: {avg_fwd:.4f}s\")\n",
-    "print(f\"Ø Backward-Zeit: {avg_bwd:.4f}s\")\n",
-    "\n",
-    "# Optional: Loss-Verlauf plotten\n",
-    "plt.plot(losses)\n",
-    "plt.title(\"Loss-Verlauf\")\n",
-    "plt.xlabel(\"Steps\")\n",
-    "plt.ylabel(\"Loss\")\n",
-    "plt.grid(True)\n",
-    "plt.show()\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Vergleichsfunktion"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 24,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def run_comparison(attn_types=(\"standard\", \"flash2\"), steps=100):\n",
-    "    results = {}\n",
-    "\n",
-    "    for attn in attn_types:\n",
-    "        print(f\"\\n🔁 Benchmarking '{attn}' Attention\")\n",
-    "\n",
-    "        set_seed(42)  # sicherstellen, dass beide Runs identisch starten\n",
-    "        model = load_model(attn_type=attn)\n",
-    "        losses, tok_sec, tflops, avg_fwd, avg_bwd = benchmark_training(model, dataloader, num_steps=steps)\n",
-    "\n",
-    "        results[attn] = {\n",
-    "            \"tokens/sec\": tok_sec,\n",
-    "            \"tflops/sec\": tflops,\n",
-    "            \"avg_forward_time_s\": avg_fwd,\n",
-    "            \"avg_backward_time_s\": avg_bwd,\n",
-    "            \"final_loss\": losses[-1],\n",
-    "            \"loss_curve\": losses,\n",
-    "        }\n",
-    "\n",
-    "        print(f\"📊 {attn.upper()} | Tokens/sec: {tok_sec:.2f} | TFLOPs/sec: {tflops:.2f} | \"\n",
-    "              f\"FWD: {avg_fwd:.4f}s | BWD: {avg_bwd:.4f}s | Final Loss: {losses[-1]:.4f}\")\n",
-    "\n",
-    "    return results"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 25,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\n",
-      "🔁 Benchmarking 'standard' Attention\n",
-      "Step 1/10 - Loss: 8.4946\n",
-      "Step 2/10 - Loss: 1.4220\n",
-      "Step 3/10 - Loss: 3.3182\n",
-      "Step 4/10 - Loss: 1.1913\n",
-      "Step 5/10 - Loss: 2.3714\n",
-      "Step 6/10 - Loss: 1.8165\n",
-      "Step 7/10 - Loss: 1.0560\n",
-      "Step 8/10 - Loss: 1.5705\n",
-      "Step 9/10 - Loss: 0.4295\n",
-      "Step 10/10 - Loss: 1.8563\n",
-      "📊 STANDARD | Tokens/sec: 10885.90 | TFLOPs/sec: 0.00 | FWD: 0.0271s | BWD: 0.0670s | Final Loss: 1.8563\n",
-      "\n",
-      "🔁 Benchmarking 'flash2' Attention\n",
-      "🔧 Patching 12 Transformer-Blöcke mit FlashSelfAttention (unpadded, masked)\n",
-      "✅ FlashAttention-2 gepatcht mit Maskierung & unpadded-Modus\n"
-     ]
-    },
-    {
-     "ename": "RuntimeError",
-     "evalue": "view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
-      "\u001b[31mRuntimeError\u001b[39m                              Traceback (most recent call last)",
-      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[25]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m results = \u001b[43mrun_comparison\u001b[49m\u001b[43m(\u001b[49m\u001b[43mattn_types\u001b[49m\u001b[43m=\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mstandard\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mflash2\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msteps\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m10\u001b[39;49m\u001b[43m)\u001b[49m\n",
-      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[24]\u001b[39m\u001b[32m, line 9\u001b[39m, in \u001b[36mrun_comparison\u001b[39m\u001b[34m(attn_types, steps)\u001b[39m\n\u001b[32m      7\u001b[39m set_seed(\u001b[32m42\u001b[39m)  \u001b[38;5;66;03m# sicherstellen, dass beide Runs identisch starten\u001b[39;00m\n\u001b[32m      8\u001b[39m model = load_model(attn_type=attn)\n\u001b[32m----> \u001b[39m\u001b[32m9\u001b[39m losses, tok_sec, tflops, avg_fwd, avg_bwd = \u001b[43mbenchmark_training\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_steps\u001b[49m\u001b[43m=\u001b[49m\u001b[43msteps\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m     11\u001b[39m results[attn] = {\n\u001b[32m     12\u001b[39m     \u001b[33m\"\u001b[39m\u001b[33mtokens/sec\u001b[39m\u001b[33m\"\u001b[39m: tok_sec,\n\u001b[32m     13\u001b[39m     \u001b[33m\"\u001b[39m\u001b[33mtflops/sec\u001b[39m\u001b[33m\"\u001b[39m: tflops,\n\u001b[32m   (...)\u001b[39m\u001b[32m     17\u001b[39m     \u001b[33m\"\u001b[39m\u001b[33mloss_curve\u001b[39m\u001b[33m\"\u001b[39m: losses,\n\u001b[32m     18\u001b[39m }\n\u001b[32m     20\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m📊 \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mattn.upper()\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m | Tokens/sec: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtok_sec\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.2f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m | TFLOPs/sec: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtflops\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.2f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m | \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m     21\u001b[39m       \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mFWD: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mavg_fwd\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33ms | BWD: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mavg_bwd\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33ms | Final Loss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mlosses[-\u001b[32m1\u001b[39m]\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n",
-      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[19]\u001b[39m\u001b[32m, line 33\u001b[39m, in \u001b[36mbenchmark_training\u001b[39m\u001b[34m(model, dataloader, num_steps)\u001b[39m\n\u001b[32m     31\u001b[39m torch.cuda.synchronize()\n\u001b[32m     32\u001b[39m start_fwd = time.time()\n\u001b[32m---> \u001b[39m\u001b[32m33\u001b[39m outputs = \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m     34\u001b[39m loss = outputs.loss\n\u001b[32m     35\u001b[39m torch.cuda.synchronize()\n",
-      "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m   1737\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m   1738\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1739\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
-      "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m   1745\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m   1746\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m   1747\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m   1748\u001b[39m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m   1749\u001b[39m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1750\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   1752\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m   1753\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
-      "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/transformers/models/gpt2/modeling_gpt2.py:1062\u001b[39m, in \u001b[36mGPT2LMHeadModel.forward\u001b[39m\u001b[34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs)\u001b[39m\n\u001b[32m   1054\u001b[39m \u001b[38;5;250m\u001b[39m\u001b[33mr\u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m   1055\u001b[39m \u001b[33;03mlabels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\u001b[39;00m\n\u001b[32m   1056\u001b[39m \u001b[33;03m    Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set\u001b[39;00m\n\u001b[32m   1057\u001b[39m \u001b[33;03m    `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`\u001b[39;00m\n\u001b[32m   1058\u001b[39m \u001b[33;03m    are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`\u001b[39;00m\n\u001b[32m   1059\u001b[39m \u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m   1060\u001b[39m return_dict = return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m.config.use_return_dict\n\u001b[32m-> \u001b[39m\u001b[32m1062\u001b[39m transformer_outputs = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mtransformer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m   1063\u001b[39m \u001b[43m    \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   1064\u001b[39m \u001b[43m    \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[43m=\u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   1065\u001b[39m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   1066\u001b[39m \u001b[43m    \u001b[49m\u001b[43mtoken_type_ids\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtoken_type_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   1067\u001b[39m \u001b[43m    \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[43m=\u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   1068\u001b[39m \u001b[43m    \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   1069\u001b[39m \u001b[43m    \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[43m=\u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   1070\u001b[39m \u001b[43m    \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m=\u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   1071\u001b[39m \u001b[43m    \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   1072\u001b[39m \u001b[43m    \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m=\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   1073\u001b[39m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m=\u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   1074\u001b[39m \u001b[43m    
\u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m=\u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   1075\u001b[39m \u001b[43m    \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[43m=\u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   1076\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   1077\u001b[39m hidden_states = transformer_outputs[\u001b[32m0\u001b[39m]\n\u001b[32m   1079\u001b[39m \u001b[38;5;66;03m# Set device for model parallelism\u001b[39;00m\n",
-      "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m   1737\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m   1738\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1739\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
-      "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m   1745\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m   1746\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m   1747\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m   1748\u001b[39m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m   1749\u001b[39m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1750\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   1752\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m   1753\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
-      "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/transformers/models/gpt2/modeling_gpt2.py:922\u001b[39m, in \u001b[36mGPT2Model.forward\u001b[39m\u001b[34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[39m\n\u001b[32m    910\u001b[39m     outputs = \u001b[38;5;28mself\u001b[39m._gradient_checkpointing_func(\n\u001b[32m    911\u001b[39m         block.\u001b[34m__call__\u001b[39m,\n\u001b[32m    912\u001b[39m         hidden_states,\n\u001b[32m   (...)\u001b[39m\u001b[32m    919\u001b[39m         output_attentions,\n\u001b[32m    920\u001b[39m     )\n\u001b[32m    921\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m922\u001b[39m     outputs = \u001b[43mblock\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m    923\u001b[39m \u001b[43m        \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    924\u001b[39m \u001b[43m        \u001b[49m\u001b[43mlayer_past\u001b[49m\u001b[43m=\u001b[49m\u001b[43mlayer_past\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    925\u001b[39m \u001b[43m        \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    926\u001b[39m \u001b[43m        \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    927\u001b[39m \u001b[43m        \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m=\u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    928\u001b[39m \u001b[43m        \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    929\u001b[39m \u001b[43m        \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m=\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    930\u001b[39m \u001b[43m        \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m=\u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    931\u001b[39m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    933\u001b[39m hidden_states = outputs[\u001b[32m0\u001b[39m]\n\u001b[32m    934\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m use_cache \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n",
-      "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m   1737\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m   1738\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1739\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
-      "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m   1745\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m   1746\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m   1747\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m   1748\u001b[39m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m   1749\u001b[39m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1750\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   1752\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m   1753\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
-      "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/transformers/models/gpt2/modeling_gpt2.py:404\u001b[39m, in \u001b[36mGPT2Block.forward\u001b[39m\u001b[34m(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)\u001b[39m\n\u001b[32m    402\u001b[39m residual = hidden_states\n\u001b[32m    403\u001b[39m hidden_states = \u001b[38;5;28mself\u001b[39m.ln_1(hidden_states)\n\u001b[32m--> \u001b[39m\u001b[32m404\u001b[39m attn_outputs = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mattn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m    405\u001b[39m \u001b[43m    \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    406\u001b[39m \u001b[43m    \u001b[49m\u001b[43mlayer_past\u001b[49m\u001b[43m=\u001b[49m\u001b[43mlayer_past\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    407\u001b[39m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    408\u001b[39m \u001b[43m    \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    409\u001b[39m \u001b[43m    \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m=\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    410\u001b[39m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m=\u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    411\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    412\u001b[39m attn_output = attn_outputs[\u001b[32m0\u001b[39m]  \u001b[38;5;66;03m# output_attn: a, present, (attentions)\u001b[39;00m\n\u001b[32m    413\u001b[39m outputs = attn_outputs[\u001b[32m1\u001b[39m:]\n",
-      "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m   1737\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m   1738\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1739\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
-      "\u001b[36mFile \u001b[39m\u001b[32m~/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m   1745\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m   1746\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m   1747\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m   1748\u001b[39m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m   1749\u001b[39m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1750\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   1752\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m   1753\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
-      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[23]\u001b[39m\u001b[32m, line 37\u001b[39m, in \u001b[36mpatch_gpt2_with_flash_attention.<locals>.custom_forward\u001b[39m\u001b[34m(self, hidden_states, layer_past, attention_mask, head_mask, use_cache, output_attentions)\u001b[39m\n\u001b[32m     35\u001b[39m qkv = qkv.permute(\u001b[32m0\u001b[39m, \u001b[32m2\u001b[39m, \u001b[32m1\u001b[39m, \u001b[32m3\u001b[39m, \u001b[32m4\u001b[39m).contiguous()        \u001b[38;5;66;03m# [B, S, 3, H, D]\u001b[39;00m\n\u001b[32m     36\u001b[39m qkv = qkv.view(B * S, \u001b[32m3\u001b[39m, \u001b[38;5;28mself\u001b[39m.num_heads, \u001b[38;5;28mself\u001b[39m.head_dim)  \u001b[38;5;66;03m# [B*S, 3, H, D]\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m37\u001b[39m qkv_packed = qkv[\u001b[43mattention_mask\u001b[49m\u001b[43m.\u001b[49m\u001b[43mview\u001b[49m\u001b[43m(\u001b[49m\u001b[43m-\u001b[49m\u001b[32;43m1\u001b[39;49m\u001b[43m)\u001b[49m.bool()]     \u001b[38;5;66;03m# [L, 3, H, D]\u001b[39;00m\n\u001b[32m     39\u001b[39m \u001b[38;5;66;03m# === FlashAttention aufrufen ===\u001b[39;00m\n\u001b[32m     40\u001b[39m output = \u001b[38;5;28mself\u001b[39m.flash(\n\u001b[32m     41\u001b[39m     qkv_packed,\n\u001b[32m     42\u001b[39m     cu_seqlens=cu_seqlens,\n\u001b[32m     43\u001b[39m     max_seqlen=S,\n\u001b[32m     44\u001b[39m     unpadded=\u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[32m     45\u001b[39m )  \u001b[38;5;66;03m# [total_tokens, H, D]\u001b[39;00m\n",
-      "\u001b[31mRuntimeError\u001b[39m: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead."
-     ]
-    }
-   ],
-   "source": [
-    "results = run_comparison(attn_types=(\"standard\", \"flash2\"), steps=10)\n"
-   ]
-  },
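-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "The `RuntimeError` above comes from `attention_mask.view(-1)` in `custom_forward`: slicing a `[B, 1, S, S]` mask with `attention_mask[:, 0, 0, :]` yields a non-contiguous `[B, S]` view, and `view(-1)` requires a compatible memory layout. `reshape(-1)` (or `.contiguous().view(-1)`) avoids this particular error. The minimal reproduction below is an illustrative sketch, independent of the model; note also that, depending on the transformers version, the mask reaching the block may already be an additive float mask rather than 0/1, which would need separate handling."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Minimal reproduction sketch of the RuntimeError above (no model needed).\n",
-    "B, S = 4, 256\n",
-    "mask_4d = torch.ones(B, 1, S, S)   # shape the patched forward expects: [B, 1, S, S]\n",
-    "mask_2d = mask_4d[:, 0, 0, :]      # [B, S], but non-contiguous (row stride is S*S)\n",
-    "\n",
-    "print(\"contiguous:\", mask_2d.is_contiguous())  # False\n",
-    "\n",
-    "try:\n",
-    "    mask_2d.view(-1)               # fails: view size incompatible with size/stride\n",
-    "except RuntimeError as err:\n",
-    "    print(\"view(-1) failed:\", err)\n",
-    "\n",
-    "flat = mask_2d.reshape(-1)         # works: reshape copies when necessary\n",
-    "print(\"reshape(-1) shape:\", tuple(flat.shape))"
-   ]
-  },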
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import matplotlib.pyplot as plt\n",
-    "\n",
-    "plt.figure(figsize=(10,5))\n",
-    "for attn_type, data in results.items():\n",
-    "    plt.plot(data[\"loss_curve\"], label=attn_type)\n",
-    "\n",
-    "plt.xlabel(\"Training Step\")\n",
-    "plt.ylabel(\"Loss\")\n",
-    "plt.title(\"Loss-Verlauf: standard vs flash2\")\n",
-    "plt.legend()\n",
-    "plt.grid(True)\n",
-    "plt.show()"
-   ]
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
-   "language": "python",
-   "name": "python3"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
-- 
GitLab