diff --git a/Testing/training_setup.ipynb b/Testing/training_setup.ipynb
deleted file mode 100644
index 35d8d6f43a6d57462be78c8d7a9cfeaaec1c76fc..0000000000000000000000000000000000000000
--- a/Testing/training_setup.ipynb
+++ /dev/null
@@ -1,1103 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# FlashAttention-2 Replikation (GPT2-small auf Wikitext-2)\n",
-    "# Ziel: Vergleich von Trainingsgeschwindigkeit für \"torch\", \"flash\", \"flash2\"\n",
-    "# Achtung: Fokus liegt ausschließlich auf Speed, nicht auf Modellgüte"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Imports & Setup"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from transformers import GPT2Config, AutoTokenizer, TrainingArguments, DataCollatorForLanguageModeling\n",
-    "from transformers import TrainerCallback\n",
-    "from datasets import load_dataset\n",
-    "from flash_attn.models.gpt import GPTLMHeadModel\n",
-    "from transformers import Trainer\n",
-    "import torch.nn.functional as F\n",
-    "import torch\n",
-    "import time\n",
-    "import os\n",
-    "import logging\n",
-    "import random, numpy as np\n",
-    "\n",
-    "def set_seed(seed=42):\n",
-    "    random.seed(seed)\n",
-    "    np.random.seed(seed)\n",
-    "    torch.manual_seed(seed)\n",
-    "    torch.cuda.manual_seed_all(seed)\n",
-    "\n",
-    "set_seed(42)\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Fallback für ColumnParallelLinear"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from flash_attn.models import gpt\n",
-    "if not hasattr(gpt, \"ColumnParallelLinear\") or not isinstance(gpt.ColumnParallelLinear, type):\n",
-    "    import torch.nn as nn\n",
-    "    class ColumnParallelLinear(nn.Module):\n",
-    "        def __init__(self, *args, **kwargs):\n",
-    "            super().__init__()\n",
-    "    gpt.ColumnParallelLinear = ColumnParallelLinear\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Custom Trainer"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "class FlashTrainer(Trainer):\n",
-    "    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):\n",
-    "        labels = inputs.pop(\"labels\")\n",
-    "        inputs.pop(\"attention_mask\", None)\n",
-    "        outputs = model(**inputs)\n",
-    "        logits = outputs[0]\n",
-    "\n",
-    "        shift_logits = logits[..., :-1, :].contiguous()\n",
-    "        shift_labels = labels[..., 1:].contiguous()\n",
-    "\n",
-    "        loss = F.cross_entropy(\n",
-    "            shift_logits.view(-1, shift_logits.size(-1)),\n",
-    "            shift_labels.view(-1),\n",
-    "            ignore_index=-100,\n",
-    "        )\n",
-    "        return (loss, outputs) if return_outputs else loss\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Timing Callback"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "class TimingCallback(TrainerCallback):\n",
-    "    def on_train_begin(self, args, state, control, **kwargs):\n",
-    "        self.start_time = time.time()\n",
-    "        self.step_times = []\n",
-    "\n",
-    "    def on_step_end(self, args, state, control, **kwargs):\n",
-    "        self.step_times.append(time.time())\n",
-    "\n",
-    "    def on_train_end(self, args, state, control, **kwargs):\n",
-    "        elapsed = time.time() - self.start_time\n",
-    "        steps = len(self.step_times)\n",
-    "        avg_step = elapsed / steps if steps else float('nan')\n",
-    "        tokens_per_step = args.per_device_train_batch_size * 128  # seq len\n",
-    "        print(f\"Training runtime: {elapsed:.2f}s\")\n",
-    "        print(f\"Steps: {steps} | Zeit/Step: {avg_step:.3f}s | Tokens/s: {tokens_per_step/avg_step:.2f}\")\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Modell-Factory"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def get_gpt2_model(attention_impl=\"torch\"):\n",
-    "    config = GPT2Config(\n",
-    "        n_layer=12, n_head=12, n_embd=768, vocab_size=50257,\n",
-    "        n_positions=1024, resid_pdrop=0.1, embd_pdrop=0.1, attn_pdrop=0.1,\n",
-    "        layer_norm_epsilon=1e-5\n",
-    "    )\n",
-    "    config.attention_config = {\n",
-    "        \"attn_impl\": attention_impl,\n",
-    "        \"alibi\": False,\n",
-    "        \"rope\": True,\n",
-    "        \"rope_theta\": 10000.0,\n",
-    "        \"use_flash_rotary\": True\n",
-    "    }\n",
-    "    return GPTLMHeadModel(config)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Tokenizer & Dataset"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "139181d887964b47b85758d59b4834de",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Map:   0%|          | 0/4358 [00:00<?, ? examples/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "7e41177ce0a74d318bf48d92842cfc7d",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Map:   0%|          | 0/36718 [00:00<?, ? examples/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "16ac6c7c6bf24121a0b2d6337e7b89ba",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Map:   0%|          | 0/3760 [00:00<?, ? examples/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
-   "source": [
-    "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
-    "tokenizer.pad_token = tokenizer.eos_token\n",
-    "\n",
-    "dataset = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\")\n",
-    "\n",
-    "def tokenize(example):\n",
-    "    return tokenizer(\n",
-    "        example[\"text\"], truncation=True, padding=\"max_length\", max_length=128\n",
-    "    )\n",
-    "\n",
-    "tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=[\"text\"])\n",
-    "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Modellparameter zählen"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def count_model_params(model):\n",
-    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Trainingsfunktion"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "results = []\n",
-    "\n",
-    "def train_model(attention_impl=\"torch\"):\n",
-    "    model = get_gpt2_model(attention_impl)\n",
-    "    train_args = TrainingArguments(\n",
-    "        output_dir=f\"./gpt2_{attention_impl}\",\n",
-    "        overwrite_output_dir=True,\n",
-    "        per_device_train_batch_size=4,\n",
-    "        num_train_epochs=1,\n",
-    "        logging_steps=999999,  # deaktiviert Zwischen-Logging\n",
-    "        report_to=\"none\",\n",
-    "        save_strategy=\"no\",\n",
-    "        remove_unused_columns=False\n",
-    "    )\n",
-    "\n",
-    "    # Timer Start\n",
-    "    start_time = time.time()\n",
-    "\n",
-    "    trainer = FlashTrainer(\n",
-    "        model=model,\n",
-    "        args=train_args,\n",
-    "        train_dataset=tokenized_dataset[\"train\"].select(range(2000)),\n",
-    "        data_collator=data_collator,\n",
-    "        tokenizer=tokenizer\n",
-    "    )\n",
-    "\n",
-    "    trainer.train()\n",
-    "\n",
-    "    # FLOPs Berechnung\n",
-    "    elapsed = time.time() - start_time\n",
-    "    num_params = count_model_params(model)\n",
-    "    seq_len = 128\n",
-    "    n_layer = model.config.n_layer\n",
-    "    hidden_dim = model.config.n_embd\n",
-    "\n",
-    "    steps = int(2000 / train_args.per_device_train_batch_size)\n",
-    "    avg_step = elapsed / steps if steps else float('nan')\n",
-    "    tokens_per_step = train_args.per_device_train_batch_size * seq_len\n",
-    "\n",
-    "    flops_total = (\n",
-    "        6 * seq_len * num_params\n",
-    "        + 12 * n_layer * hidden_dim * seq_len * seq_len\n",
-    "    )\n",
-    "\n",
-    "    tflops_per_s = flops_total / (elapsed * 1e12)\n",
-    "\n",
-    "    print(f\"\\n⚙️  Ergebnisse für {attention_impl} Attention\")\n",
-    "    print(f\"-------------------------------\")\n",
-    "    print(f\"Training runtime: {elapsed:.2f}s\")\n",
-    "    print(f\"Steps: {steps} | Zeit/Step: {avg_step:.3f}s\")\n",
-    "    print(f\"Tokens/s: {tokens_per_step / avg_step:.2f}\")\n",
-    "    print(f\"TFLOPs/s: {tflops_per_s:.3f}\")\n",
-    "\n",
-    "    results.append({\n",
-    "        \"impl\": attention_impl,\n",
-    "        \"runtime\": elapsed,\n",
-    "        \"steps\": steps,\n",
-    "        \"step_time\": avg_step,\n",
-    "        \"tokens_per_s\": tokens_per_step / avg_step,\n",
-    "        \"tflops_per_s\": tflops_per_s\n",
-    "    })"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Training ausführen"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\n",
-      "====== Training mit torch Attention ======\n"
-     ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/tmp/ipykernel_103146/772411891.py:14: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `FlashTrainer.__init__`. Use `processing_class` instead.\n",
-      "  trainer = FlashTrainer(\n"
-     ]
-    },
-    {
-     "data": {
-      "text/html": [
-       "\n",
-       "    <div>\n",
-       "      \n",
-       "      <progress value='500' max='500' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
-       "      [500/500 01:10, Epoch 1/1]\n",
-       "    </div>\n",
-       "    <table border=\"1\" class=\"dataframe\">\n",
-       "  <thead>\n",
-       " <tr style=\"text-align: left;\">\n",
-       "      <th>Step</th>\n",
-       "      <th>Training Loss</th>\n",
-       "    </tr>\n",
-       "  </thead>\n",
-       "  <tbody>\n",
-       "    <tr>\n",
-       "      <td>10</td>\n",
-       "      <td>9.853900</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>20</td>\n",
-       "      <td>9.238400</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>30</td>\n",
-       "      <td>8.887500</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>40</td>\n",
-       "      <td>8.335400</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>50</td>\n",
-       "      <td>7.903500</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>60</td>\n",
-       "      <td>8.377200</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>70</td>\n",
-       "      <td>7.905800</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>80</td>\n",
-       "      <td>7.842000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>90</td>\n",
-       "      <td>7.483400</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>100</td>\n",
-       "      <td>7.373200</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>110</td>\n",
-       "      <td>7.223200</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>120</td>\n",
-       "      <td>6.112400</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>130</td>\n",
-       "      <td>7.025000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>140</td>\n",
-       "      <td>6.774200</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>150</td>\n",
-       "      <td>6.816700</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>160</td>\n",
-       "      <td>6.402700</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>170</td>\n",
-       "      <td>7.248800</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>180</td>\n",
-       "      <td>7.165800</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>190</td>\n",
-       "      <td>6.733900</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>200</td>\n",
-       "      <td>6.543300</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>210</td>\n",
-       "      <td>6.040700</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>220</td>\n",
-       "      <td>7.140300</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>230</td>\n",
-       "      <td>6.716900</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>240</td>\n",
-       "      <td>6.968700</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>250</td>\n",
-       "      <td>6.511500</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>260</td>\n",
-       "      <td>6.789300</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>270</td>\n",
-       "      <td>6.436600</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>280</td>\n",
-       "      <td>6.637300</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>290</td>\n",
-       "      <td>7.051100</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>300</td>\n",
-       "      <td>6.926900</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>310</td>\n",
-       "      <td>6.794500</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>320</td>\n",
-       "      <td>6.902200</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>330</td>\n",
-       "      <td>6.413500</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>340</td>\n",
-       "      <td>6.618800</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>350</td>\n",
-       "      <td>6.836900</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>360</td>\n",
-       "      <td>6.214700</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>370</td>\n",
-       "      <td>6.639100</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>380</td>\n",
-       "      <td>6.771800</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>390</td>\n",
-       "      <td>6.499500</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>400</td>\n",
-       "      <td>6.771600</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>410</td>\n",
-       "      <td>6.790300</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>420</td>\n",
-       "      <td>6.679800</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>430</td>\n",
-       "      <td>6.848100</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>440</td>\n",
-       "      <td>6.703200</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>450</td>\n",
-       "      <td>6.816400</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>460</td>\n",
-       "      <td>6.456600</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>470</td>\n",
-       "      <td>6.723100</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>480</td>\n",
-       "      <td>6.211500</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>490</td>\n",
-       "      <td>6.745100</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>500</td>\n",
-       "      <td>6.307300</td>\n",
-       "    </tr>\n",
-       "  </tbody>\n",
-       "</table><p>"
-      ],
-      "text/plain": [
-       "<IPython.core.display.HTML object>"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Training runtime: 70.81 seconds\n",
-      "\n",
-      "====== Training mit flash Attention ======\n"
-     ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/tmp/ipykernel_103146/772411891.py:14: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `FlashTrainer.__init__`. Use `processing_class` instead.\n",
-      "  trainer = FlashTrainer(\n"
-     ]
-    },
-    {
-     "data": {
-      "text/html": [
-       "\n",
-       "    <div>\n",
-       "      \n",
-       "      <progress value='500' max='500' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
-       "      [500/500 01:11, Epoch 1/1]\n",
-       "    </div>\n",
-       "    <table border=\"1\" class=\"dataframe\">\n",
-       "  <thead>\n",
-       " <tr style=\"text-align: left;\">\n",
-       "      <th>Step</th>\n",
-       "      <th>Training Loss</th>\n",
-       "    </tr>\n",
-       "  </thead>\n",
-       "  <tbody>\n",
-       "    <tr>\n",
-       "      <td>10</td>\n",
-       "      <td>9.790800</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>20</td>\n",
-       "      <td>9.218500</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>30</td>\n",
-       "      <td>8.901400</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>40</td>\n",
-       "      <td>8.293000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>50</td>\n",
-       "      <td>7.922500</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>60</td>\n",
-       "      <td>8.341000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>70</td>\n",
-       "      <td>7.855600</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>80</td>\n",
-       "      <td>7.782000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>90</td>\n",
-       "      <td>7.498900</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>100</td>\n",
-       "      <td>7.365200</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>110</td>\n",
-       "      <td>7.251100</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>120</td>\n",
-       "      <td>6.080400</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>130</td>\n",
-       "      <td>6.999900</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>140</td>\n",
-       "      <td>6.777400</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>150</td>\n",
-       "      <td>6.834300</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>160</td>\n",
-       "      <td>6.397900</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>170</td>\n",
-       "      <td>7.255600</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>180</td>\n",
-       "      <td>7.165900</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>190</td>\n",
-       "      <td>6.768000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>200</td>\n",
-       "      <td>6.512000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>210</td>\n",
-       "      <td>6.040000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>220</td>\n",
-       "      <td>7.139500</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>230</td>\n",
-       "      <td>6.729700</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>240</td>\n",
-       "      <td>6.983100</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>250</td>\n",
-       "      <td>6.522900</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>260</td>\n",
-       "      <td>6.821300</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>270</td>\n",
-       "      <td>6.405800</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>280</td>\n",
-       "      <td>6.652300</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>290</td>\n",
-       "      <td>7.041700</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>300</td>\n",
-       "      <td>6.929000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>310</td>\n",
-       "      <td>6.797900</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>320</td>\n",
-       "      <td>6.891200</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>330</td>\n",
-       "      <td>6.383600</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>340</td>\n",
-       "      <td>6.655400</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>350</td>\n",
-       "      <td>6.830800</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>360</td>\n",
-       "      <td>6.199900</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>370</td>\n",
-       "      <td>6.595300</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>380</td>\n",
-       "      <td>6.784300</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>390</td>\n",
-       "      <td>6.497400</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>400</td>\n",
-       "      <td>6.793800</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>410</td>\n",
-       "      <td>6.775600</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>420</td>\n",
-       "      <td>6.673100</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>430</td>\n",
-       "      <td>6.853500</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>440</td>\n",
-       "      <td>6.701600</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>450</td>\n",
-       "      <td>6.797000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>460</td>\n",
-       "      <td>6.456400</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>470</td>\n",
-       "      <td>6.769700</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>480</td>\n",
-       "      <td>6.221300</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>490</td>\n",
-       "      <td>6.734100</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>500</td>\n",
-       "      <td>6.292700</td>\n",
-       "    </tr>\n",
-       "  </tbody>\n",
-       "</table><p>"
-      ],
-      "text/plain": [
-       "<IPython.core.display.HTML object>"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Training runtime: 71.79 seconds\n",
-      "\n",
-      "====== Training mit flash2 Attention ======\n"
-     ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/tmp/ipykernel_103146/772411891.py:14: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `FlashTrainer.__init__`. Use `processing_class` instead.\n",
-      "  trainer = FlashTrainer(\n"
-     ]
-    },
-    {
-     "data": {
-      "text/html": [
-       "\n",
-       "    <div>\n",
-       "      \n",
-       "      <progress value='500' max='500' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
-       "      [500/500 01:12, Epoch 1/1]\n",
-       "    </div>\n",
-       "    <table border=\"1\" class=\"dataframe\">\n",
-       "  <thead>\n",
-       " <tr style=\"text-align: left;\">\n",
-       "      <th>Step</th>\n",
-       "      <th>Training Loss</th>\n",
-       "    </tr>\n",
-       "  </thead>\n",
-       "  <tbody>\n",
-       "    <tr>\n",
-       "      <td>10</td>\n",
-       "      <td>9.790800</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>20</td>\n",
-       "      <td>9.218500</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>30</td>\n",
-       "      <td>8.901400</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>40</td>\n",
-       "      <td>8.293000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>50</td>\n",
-       "      <td>7.922500</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>60</td>\n",
-       "      <td>8.341000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>70</td>\n",
-       "      <td>7.855600</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>80</td>\n",
-       "      <td>7.782000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>90</td>\n",
-       "      <td>7.498900</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>100</td>\n",
-       "      <td>7.365200</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>110</td>\n",
-       "      <td>7.251100</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>120</td>\n",
-       "      <td>6.080400</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>130</td>\n",
-       "      <td>6.999900</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>140</td>\n",
-       "      <td>6.777400</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>150</td>\n",
-       "      <td>6.834300</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>160</td>\n",
-       "      <td>6.397900</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>170</td>\n",
-       "      <td>7.255600</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>180</td>\n",
-       "      <td>7.165900</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>190</td>\n",
-       "      <td>6.768000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>200</td>\n",
-       "      <td>6.512000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>210</td>\n",
-       "      <td>6.040000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>220</td>\n",
-       "      <td>7.139500</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>230</td>\n",
-       "      <td>6.729700</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>240</td>\n",
-       "      <td>6.983100</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>250</td>\n",
-       "      <td>6.522900</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>260</td>\n",
-       "      <td>6.821300</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>270</td>\n",
-       "      <td>6.405800</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>280</td>\n",
-       "      <td>6.652300</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>290</td>\n",
-       "      <td>7.041700</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>300</td>\n",
-       "      <td>6.929000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>310</td>\n",
-       "      <td>6.797900</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>320</td>\n",
-       "      <td>6.891200</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>330</td>\n",
-       "      <td>6.383600</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>340</td>\n",
-       "      <td>6.655400</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>350</td>\n",
-       "      <td>6.830800</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>360</td>\n",
-       "      <td>6.199900</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>370</td>\n",
-       "      <td>6.595300</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>380</td>\n",
-       "      <td>6.784300</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>390</td>\n",
-       "      <td>6.497400</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>400</td>\n",
-       "      <td>6.793800</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>410</td>\n",
-       "      <td>6.775600</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>420</td>\n",
-       "      <td>6.673100</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>430</td>\n",
-       "      <td>6.853500</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>440</td>\n",
-       "      <td>6.701600</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>450</td>\n",
-       "      <td>6.797000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>460</td>\n",
-       "      <td>6.456400</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>470</td>\n",
-       "      <td>6.769700</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>480</td>\n",
-       "      <td>6.221300</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>490</td>\n",
-       "      <td>6.734100</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <td>500</td>\n",
-       "      <td>6.292700</td>\n",
-       "    </tr>\n",
-       "  </tbody>\n",
-       "</table><p>"
-      ],
-      "text/plain": [
-       "<IPython.core.display.HTML object>"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Training runtime: 72.39 seconds\n"
-     ]
-    }
-   ],
-   "source": [
-    "# Vergleich durchführen\n",
-    "for impl in [\"torch\", \"flash\", \"flash2\"]:\n",
-    "    print(f\"\\n====== Training mit {impl} Attention ======\")\n",
-    "    train_model(impl)"
-   ]
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python (venv1)",
-   "language": "python",
-   "name": "venv1"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.11.8"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}