From bf0363e740aa7a0f963c1171eea475365b37824d Mon Sep 17 00:00:00 2001
From: Matthias Keck <matthias.keck@student.uni-halle.de>
Date: Mon, 31 Mar 2025 20:30:47 +0000
Subject: [PATCH] Add attention benchmarking code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 Attention-Benchmarking/bechmark_with_mem.py | 146 ++++++++++++++++++++
 1 file changed, 146 insertions(+)
 create mode 100644 Attention-Benchmarking/bechmark_with_mem.py

diff --git a/Attention-Benchmarking/bechmark_with_mem.py b/Attention-Benchmarking/bechmark_with_mem.py
new file mode 100644
index 0000000..ba654e1
--- /dev/null
+++ b/Attention-Benchmarking/bechmark_with_mem.py
@@ -0,0 +1,146 @@
+
+# Standard library and deep-learning imports
+import pickle
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange  # for concise tensor reshaping
+from flash_attn.utils.benchmark import benchmark_fwd_bwd
+from flash_attn import flash_attn_qkvpacked_func  # FlashAttention-2 kernel (packed QKV)
+
+try:
+    import xformers.ops as xops
+except ImportError:
+    xops = None
+
+# Estimate the FLOP count of attention for the given mode
+def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
+    assert mode in ["fwd", "bwd", "fwd_bwd"]
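+    # Two matmuls (QK^T and attention @ V) of 2*b*s^2*h*d FLOPs each; causal masking halves the count.
+    # Backward is counted as 2.5x forward, and fwd+bwd as 3.5x forward.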
+    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
+    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
+
+# Compute TFLOPs/s, i.e. how efficiently the GPU is working
+def efficiency(flop, time):
+    return (flop / time / 10**12) if not math.isnan(time) else 0.0
+
+# Return currently allocated and peak GPU memory in MB
+def log_memory():
+    return torch.cuda.memory_allocated() / (1024 ** 2), torch.cuda.max_memory_allocated() / (1024 ** 2)
+
+# Run the benchmark (forward + backward) and additionally measure memory usage
+def benchmark_with_memory(func, *args, **kwargs):
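+    # Returns (mean fwd time, mean bwd time, allocated-memory delta in MB, peak memory in MB)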
+    torch.cuda.reset_peak_memory_stats()
+    mem_before, _ = log_memory()
+
+    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
+
+    mem_after, mem_peak = log_memory()
+    return time_f[1].mean, time_b[1].mean, mem_after - mem_before, mem_peak
+
+# Attention implemented in plain PyTorch
+def attention_pytorch(qkv, dropout_p=0.0, causal=True):
+    batch_size, seqlen, _, nheads, d = qkv.shape
+    q, k, v = qkv.unbind(dim=2)  # extract Q, K, V
+
+    # Reshape for batched matrix multiplication
+    q = rearrange(q, 'b t h d -> (b h) t d')
+    k = rearrange(k, 'b s h d -> (b h) d s')
+    softmax_scale = 1.0 / math.sqrt(d)
+    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
+
+    # Q @ K^T (with scaling)
+    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
+                       '(b h) t s -> b h t s', h=nheads)
+    if causal:
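+        # Additive mask: large negative values above the diagonal suppress future positions after the softmax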
+        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
+        scores = scores + causal_mask.to(dtype=scores.dtype)
+
+    # Softmax + Dropout
+    attention = torch.softmax(scores, dim=-1)
+    attention_drop = F.dropout(attention, dropout_p)
+
+    # Compute the output
+    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
+    return output.to(dtype=qkv.dtype)
+
+# Global settings
+repeats = 30
+device = 'cuda'
+dtype = torch.float16
+
+# Batch sizes and sequence lengths
+bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
+
+# Causal masking off/on
+causal_vals = [False, True]
+
+# Head dimensions 64 / 128
+headdim_vals = [64, 128]
+dim = 2048
+dropout_p = 0.0
+
+# Attention implementations
+methods = ["Flash2", "Pytorch"] + (["xformers.f"] if xops is not None else [])
+
+results = {}
+
+# Loop over all configurations
+for causal in causal_vals:
+    for headdim in headdim_vals:
+        for batch_size, seqlen in bs_seqlen_vals:
+            config = (causal, headdim, batch_size, seqlen)
+            nheads = dim // headdim
+
+            # Initialize the packed QKV tensor
+            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
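+            # Packed layout (batch, seqlen, 3, nheads, headdim), as expected by flash_attn_qkvpacked_func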
+
+            results[config] = {}
+
+            # Flash2
+            f, b, mem_used, mem_peak = benchmark_with_memory(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats)
+            results[config]["Flash2"] = {"fwd_time": f, "bwd_time": b, "mem_used_MB": mem_used, "mem_peak_MB": mem_peak}
+
+
+            # PyTorch Attention
+            try:
+                qkv = qkv.detach().requires_grad_(True)
+                f, b, mem_used, mem_peak = benchmark_with_memory(
+                    attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats
+                )
+            except Exception:
+                f, b, mem_used, mem_peak = float('nan'), float('nan'), float('nan'), float('nan')
+
+            results[config]["Pytorch"] = {"fwd_time": f, "bwd_time": b, "mem_used_MB": mem_used, "mem_peak_MB": mem_peak}
+
+            # xFormers
+            if xops is not None:
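+                # xFormers' memory_efficient_attention takes separate (unpacked) Q, K, V tensors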
+                q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
+                f, b, mem_used, mem_peak = benchmark_with_memory(
+                    xops.memory_efficient_attention, q, k, v,
+                    attn_bias=xops.LowerTriangularMask() if causal else None,
+                    op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp), repeats=repeats
+                )
+                results[config]["xformers.f"] = {"fwd_time": f, "bwd_time": b, "mem_used_MB": mem_used, "mem_peak_MB": mem_peak}
+
+            # Compute and print results
+            print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
+            for method in methods:
+                if method in results[config]:
+                    entry = results[config][method]
+                    fwd_tflops = efficiency(flops(batch_size, seqlen, headdim, nheads, causal, "fwd"), entry["fwd_time"])
+                    bwd_tflops = efficiency(flops(batch_size, seqlen, headdim, nheads, causal, "bwd"), entry["bwd_time"])
+                    fwd_bwd_tflops = efficiency(flops(batch_size, seqlen, headdim, nheads, causal, "fwd_bwd"), entry["fwd_time"] + entry["bwd_time"])
+                    print(f"{method} fwd: {fwd_tflops:.2f} TFLOPs/s, bwd: {bwd_tflops:.2f} TFLOPs/s, fwd+bwd: {fwd_bwd_tflops:.2f} TFLOPs/s, mem: {entry['mem_used_MB']:.2f} MB (peak {entry['mem_peak_MB']:.2f} MB)")
+
+# Save results
+with open('flash2_attn_results.pkl', 'wb') as fp:
+    pickle.dump(results, fp, protocol=pickle.HIGHEST_PROTOCOL)
+
-- 
GitLab