# %% [markdown]
# # FlashAttention-2 replication (GPT2-small on Wikitext-2)
#
# Goal: compare the training speed of the attention back ends "torch",
# "flash" and "flash2". The focus is exclusively on speed, not on model quality.

# %% [markdown]
# ## Imports & setup

# %%
import gc
import logging
import os
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from flash_attn.models import gpt
from flash_attn.models.gpt import GPTLMHeadModel
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    GPT2Config,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)

# Everything a reader might want to tune lives here instead of being repeated
# as magic numbers in several cells.
SEED = 42
MAX_SEQ_LEN = 128          # every example is padded/truncated to this length
NUM_TRAIN_SAMPLES = 2000   # size of the Wikitext-2 training subset
BATCH_SIZE = 4             # per-device training batch size
ATTENTION_IMPLS = ("torch", "flash", "flash2")


def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


set_seed(SEED)

# %% [markdown]
# ## Fallback for ColumnParallelLinear

# %%
# If the flash_attn GPT module exposes no usable `ColumnParallelLinear` class
# (attribute missing or not a type), install an empty stand-in class.
# NOTE(review): presumably this is needed because the optional fused-dense
# import failed and the module then uses the name in `isinstance` checks --
# confirm against the installed flash_attn version.
if not hasattr(gpt, "ColumnParallelLinear") or not isinstance(gpt.ColumnParallelLinear, type):

    class ColumnParallelLinear(nn.Module):
        """Inert placeholder; accepts any constructor arguments and has no layers."""

        def __init__(self, *args, **kwargs):
            super().__init__()

    gpt.ColumnParallelLinear = ColumnParallelLinear

# %% [markdown]
# ## Custom trainer

# %%
class FlashTrainer(Trainer):
    """Trainer that computes the causal-LM loss itself from the model's logits."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the next-token cross-entropy loss for one batch.

        Args:
            model: Model called with every batch entry except ``labels`` and
                ``attention_mask``; its first output is used as the logits.
            inputs: Batch dict from the data collator. It is not modified.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted for Trainer API compatibility; unused.

        Returns:
            The scalar loss, or ``(loss, outputs)`` if ``return_outputs`` is set.
        """
        labels = inputs["labels"]
        # Build a filtered copy rather than popping keys, so the caller's batch
        # dict stays intact.
        model_inputs = {
            key: value
            for key, value in inputs.items()
            if key not in ("labels", "attention_mask")
        }
        outputs = model(**model_inputs)
        logits = outputs[0]

        # Position t predicts token t + 1: drop the last logit and the first label.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        loss = F.cross_entropy(
            # Compute the loss in fp32 even when the forward pass ran in half precision.
            shift_logits.view(-1, shift_logits.size(-1)).float(),
            shift_labels.view(-1),
            ignore_index=-100,
        )
        return (loss, outputs) if return_outputs else loss

# %% [markdown]
# ## Timing callback

# %%
class TimingCallback(TrainerCallback):
    """Measure the wall-clock time of the training loop only.

    The clock starts in ``on_train_begin`` and stops in ``on_train_end``, so
    Trainer construction is excluded. CUDA is synchronized at both points so
    that queued GPU work is included in the measurement.

    Attributes set after training:
        elapsed: Seconds between train begin and train end (NaN before that).
        step_times: ``time.perf_counter()`` reading at the end of each step.
    """

    def __init__(self, seq_len=MAX_SEQ_LEN, verbose=True):
        """
        Args:
            seq_len: Tokens per sequence, used for the tokens/s report.
            verbose: Print a summary in ``on_train_end`` when true.
        """
        self.seq_len = seq_len
        self.verbose = verbose
        self.start_time = None
        self.step_times = []
        self.elapsed = float("nan")

    @staticmethod
    def _sync():
        """Block until all queued CUDA work has finished (no-op without CUDA)."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def on_train_begin(self, args, state, control, **kwargs):
        self._sync()
        self.step_times = []
        self.start_time = time.perf_counter()

    def on_step_end(self, args, state, control, **kwargs):
        self.step_times.append(time.perf_counter())

    def on_train_end(self, args, state, control, **kwargs):
        self._sync()
        self.elapsed = time.perf_counter() - self.start_time
        if not self.verbose:
            return
        elapsed = self.elapsed
        steps = len(self.step_times)
        avg_step = elapsed / steps if steps else float("nan")
        # Tokens consumed per optimizer step across all devices and
        # gradient-accumulation micro-batches.
        tokens_per_step = (
            args.train_batch_size
            * args.gradient_accumulation_steps
            * args.world_size
            * self.seq_len
        )
        print(f"Training runtime: {elapsed:.2f}s")
        print(f"Steps: {steps} | Zeit/Step: {avg_step:.3f}s | Tokens/s: {tokens_per_step/avg_step:.2f}")

# %% [markdown]
# ## Model factory

# %%
def get_gpt2_model(attention_impl="torch"):
    """Build a randomly initialized GPT2-small model for one attention back end.

    Args:
        attention_impl: One of ``"torch"``, ``"flash"`` or ``"flash2"``.
            ``"torch"`` disables the fused FlashAttention kernel; the other two
            enable it.

    Returns:
        A ``GPTLMHeadModel`` built from the GPT2-small configuration.

    Raises:
        ValueError: If ``attention_impl`` is not a known back end.

    NOTE(review): an installed ``flash_attn`` package presumably ships a single
    kernel generation, so "flash" and "flash2" take the same code path here.
    Benchmarking FlashAttention-1 would need a separate environment with
    flash-attn 1.x -- confirm.
    """
    if attention_impl not in ATTENTION_IMPLS:
        raise ValueError(
            f"Unknown attention_impl {attention_impl!r}; expected one of {ATTENTION_IMPLS}"
        )

    config = GPT2Config(
        n_layer=12, n_head=12, n_embd=768, vocab_size=50257,
        n_positions=1024, resid_pdrop=0.1, embd_pdrop=0.1, attn_pdrop=0.1,
        layer_norm_epsilon=1e-5
    )
    # Kept as a record of the requested back end only. This attribute was the
    # sole place the back end used to be stored, yet all three runs behaved
    # identically, i.e. it had no effect on the model.
    # NOTE(review): the earlier dict also requested rotary embeddings, which
    # likewise had no effect. They are deliberately not enabled here so the
    # benchmarked architecture stays unchanged.
    config.attention_config = {"attn_impl": attention_impl}
    # This is the switch that selects the fused kernel in flash_attn's GPT model.
    config.use_flash_attn = attention_impl != "torch"
    return GPTLMHeadModel(config)

# %% [markdown]
# ## Tokenizer & dataset

# %%
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 has no padding token; reuse EOS so fixed-length padding works.
tokenizer.pad_token = tokenizer.eos_token

dataset = load_dataset("wikitext", "wikitext-2-raw-v1")


def tokenize(example):
    """Tokenize a batch of texts to exactly ``MAX_SEQ_LEN`` tokens each."""
    return tokenizer(
        example["text"], truncation=True, padding="max_length", max_length=MAX_SEQ_LEN
    )


tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=["text"])
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# %% [markdown]
# ## Counting model parameters

# %%
def count_model_params(model):
    """Return the number of trainable parameters of ``model``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def estimate_training_flops(num_params, n_layer, hidden_dim, seq_len, num_sequences):
    """Estimate the forward+backward FLOPs for training on ``num_sequences`` sequences.

    Uses the per-sequence estimate
    ``6 * seq_len * num_params + 12 * n_layer * hidden_dim * seq_len**2``
    and scales it by the number of sequences processed.
    """
    flops_per_sequence = (
        6 * seq_len * num_params
        + 12 * n_layer * hidden_dim * seq_len * seq_len
    )
    return flops_per_sequence * num_sequences

# %% [markdown]
# ## Training function

# %%
results = []


def train_model(attention_impl="torch"):
    """Train GPT2-small for one epoch and record throughput numbers.

    Prints a summary and appends a record to the module-level ``results`` list.

    Args:
        attention_impl: Attention back end, passed to ``get_gpt2_model``.

    Returns:
        The record appended to ``results`` (keys: impl, runtime, steps,
        step_time, tokens_per_s, tflops_per_s).
    """
    # Re-seed so every back end starts from the same initial weights.
    set_seed(SEED)
    model = get_gpt2_model(attention_impl)

    # FlashAttention kernels need half-precision inputs. The same precision is
    # used for the "torch" baseline so the comparison stays like-for-like.
    use_cuda = torch.cuda.is_available()
    use_bf16 = use_cuda and torch.cuda.is_bf16_supported()

    train_args = TrainingArguments(
        output_dir=f"./gpt2_{attention_impl}",
        overwrite_output_dir=True,
        per_device_train_batch_size=BATCH_SIZE,
        num_train_epochs=1,
        logging_steps=999999,  # disables intermediate logging
        report_to="none",
        save_strategy="no",
        remove_unused_columns=False,
        seed=SEED,
        bf16=use_bf16,
        fp16=use_cuda and not use_bf16,
    )

    train_subset = tokenized_dataset["train"].select(range(NUM_TRAIN_SAMPLES))
    timing = TimingCallback(seq_len=MAX_SEQ_LEN, verbose=False)

    # The deprecated `tokenizer=` argument is omitted: nothing is saved and an
    # explicit data collator is supplied.
    trainer = FlashTrainer(
        model=model,
        args=train_args,
        train_dataset=train_subset,
        data_collator=data_collator,
        callbacks=[timing],
    )

    trainer.train()

    # Throughput / FLOPs calculation
    elapsed = timing.elapsed
    steps = trainer.state.global_step  # optimizer steps actually performed
    avg_step = elapsed / steps if steps else float("nan")

    num_sequences = len(train_subset)  # exactly one epoch over the subset
    tokens_per_s = (
        num_sequences * MAX_SEQ_LEN / elapsed if elapsed > 0 else float("nan")
    )

    flops_total = estimate_training_flops(
        num_params=count_model_params(model),
        n_layer=model.config.n_layer,
        hidden_dim=model.config.n_embd,
        seq_len=MAX_SEQ_LEN,
        num_sequences=num_sequences,
    )
    tflops_per_s = flops_total / (elapsed * 1e12) if elapsed > 0 else float("nan")

    print(f"\n⚙️ Ergebnisse für {attention_impl} Attention")
    print(f"-------------------------------")
    print(f"Training runtime: {elapsed:.2f}s")
    print(f"Steps: {steps} | Zeit/Step: {avg_step:.3f}s")
    print(f"Tokens/s: {tokens_per_s:.2f}")
    print(f"TFLOPs/s: {tflops_per_s:.3f}")

    record = {
        "impl": attention_impl,
        "runtime": elapsed,
        "steps": steps,
        "step_time": avg_step,
        "tokens_per_s": tokens_per_s,
        "tflops_per_s": tflops_per_s
    }
    results.append(record)
    return record

# %% [markdown]
# ## Run the training

# %%
# Run the comparison
for impl in ATTENTION_IMPLS:
    print(f"\n====== Training mit {impl} Attention ======")
    train_model(impl)
    # Release the previous model and optimizer state before the next run so
    # earlier runs do not add memory pressure to later ones.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()