diff --git a/Attention-Benchmarking/bechmark_with_mem.py b/Attention-Benchmarking/bechmark_with_mem.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba654e1a654c889397e63ef7a7d6660880ac3f28
--- /dev/null
+++ b/Attention-Benchmarking/bechmark_with_mem.py
@@ -0,0 +1,140 @@
+# Import standard and deep-learning libraries
+import pickle
+import math
+import torch
+import torch.nn.functional as F
+
+from einops import rearrange  # for concise tensor reshaping
+from flash_attn.utils.benchmark import benchmark_fwd_bwd
+from flash_attn import flash_attn_qkvpacked_func  # FlashAttention-2 function
+
+try:
+    import xformers.ops as xops
+except ImportError:
+    xops = None
+
+# Estimate the FLOP count of attention for the given mode (fwd, bwd, or fwd_bwd)
+def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
+    assert mode in ["fwd", "bwd", "fwd_bwd"]
+    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
+    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
+
+# Convert a FLOP count and a runtime into TFLOPs/s, i.e. how efficiently the GPU is working
+def efficiency(flop, time):
+    return (flop / time / 10**12) if not math.isnan(time) else 0.0
+
+# Return currently allocated and peak GPU memory (in MB)
+def log_memory():
+    return torch.cuda.memory_allocated() / (1024 ** 2), torch.cuda.max_memory_allocated() / (1024 ** 2)
+
+# Run the benchmark (forward + backward) and additionally measure memory usage
+def benchmark_with_memory(func, *args, **kwargs):
+    torch.cuda.reset_peak_memory_stats()
+    mem_before, _ = log_memory()
+
+    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
+
+    mem_after, mem_peak = log_memory()
+    return time_f[1].mean, time_b[1].mean, mem_after - mem_before, mem_peak
+
+# Standard attention in plain PyTorch (reference implementation)
+def attention_pytorch(qkv, dropout_p=0.0, causal=True):
+    batch_size, seqlen, _, nheads, d = qkv.shape
+    q, k, v = qkv.unbind(dim=2)  # split the packed tensor into Q, K, V
+
+    # Reshape for batched matrix multiplication
+    q = rearrange(q, 'b t h d -> (b h) t d')
+    k = rearrange(k, 'b s h d -> (b h) d s')
+    softmax_scale = 1.0 / math.sqrt(d)
+    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
+
+    # Q @ K^T (with scaling)
+    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
+                       '(b h) t s -> b h t s', h=nheads)
+    if causal:
+        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
+        scores = scores + causal_mask.to(dtype=scores.dtype)
+
+    # Softmax + dropout
+    attention = torch.softmax(scores, dim=-1)
+    attention_drop = F.dropout(attention, dropout_p)
+
+    # Compute the output
+    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
+    return output.to(dtype=qkv.dtype)
+
+# Global settings
+repeats = 30
+device = 'cuda'
+dtype = torch.float16
+
+# Batch sizes and sequence lengths
+bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
+
+# Causal masking off/on
+causal_vals = [False, True]
+
+# Head dimensions 64 / 128
+headdim_vals = [64, 128]
+dim = 2048
+dropout_p = 0.0
+
+# Attention implementations to compare
+methods = ["Flash2", "Pytorch"] + (["xformers.f"] if xops is not None else [])
+
+results = {}
+
+# Loop over all configurations
+for causal in causal_vals:
+    for headdim in headdim_vals:
+        for batch_size, seqlen in bs_seqlen_vals:
+            config = (causal, headdim, batch_size, seqlen)
+            nheads = dim // headdim
+
+            # Initialize the packed QKV tensor
+            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
+
+            results[config] = {}
+
+            # FlashAttention-2
+            f, b, mem_used, mem_peak = benchmark_with_memory(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats)
+            results[config]["Flash2"] = {"fwd_time": f, "bwd_time": b, "mem_used_MB": mem_used, "mem_peak_MB": mem_peak}
+
+            # PyTorch attention
+            try:
+                qkv = qkv.detach().requires_grad_(True)
+                f, b, mem_used, mem_peak = benchmark_with_memory(
+                    attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats
+                )
+            except Exception:  # skip if OOM: standard attention materializes the full score matrix
+                f, b, mem_used, mem_peak = float('nan'), float('nan'), float('nan'), float('nan')
+
+            results[config]["Pytorch"] = {"fwd_time": f, "bwd_time": b, "mem_used_MB": mem_used, "mem_peak_MB": mem_peak}
+
+            # xFormers (memory-efficient attention with the flash ops)
+            if xops is not None:
+                q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
+                f, b, mem_used, mem_peak = benchmark_with_memory(
+                    xops.memory_efficient_attention, q, k, v,
+                    attn_bias=xops.LowerTriangularMask() if causal else None,
+                    op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
+                    repeats=repeats
+                )
+                results[config]["xformers.f"] = {"fwd_time": f, "bwd_time": b, "mem_used_MB": mem_used, "mem_peak_MB": mem_peak}
+
+            # Compute and print efficiency and memory numbers
+            print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
+            for method in methods:
+                if method in results[config]:
+                    entry = results[config][method]
+                    fwd_tflops = efficiency(flops(batch_size, seqlen, headdim, nheads, causal, "fwd"), entry["fwd_time"])
+                    bwd_tflops = efficiency(flops(batch_size, seqlen, headdim, nheads, causal, "bwd"), entry["bwd_time"])
+                    fwd_bwd_tflops = efficiency(flops(batch_size, seqlen, headdim, nheads, causal, "fwd_bwd"), entry["fwd_time"] + entry["bwd_time"])
+                    print(f"{method} fwd: {fwd_tflops:.2f} TFLOPs/s, bwd: {bwd_tflops:.2f} TFLOPs/s, fwd+bwd: {fwd_bwd_tflops:.2f} TFLOPs/s, mem: {entry['mem_used_MB']:.2f} MB (peak {entry['mem_peak_MB']:.2f} MB)")
+
+# Save results
+with open('flash2_attn_results.pkl', 'wb') as fp:
+    pickle.dump(results, fp, protocol=pickle.HIGHEST_PROTOCOL)
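+
+# A minimal sketch (commented out, not part of the benchmark run) of how the
+# pickled results could be reloaded and inspected in a later session. It
+# assumes the dict layout written above: keys are
+# (causal, headdim, batch_size, seqlen) tuples, and each value maps a method
+# name to its timing/memory dict. Recomputing TFLOPs/s additionally requires
+# the flops()/efficiency() helpers and `dim` from this script.
+#
+# with open('flash2_attn_results.pkl', 'rb') as fp:
+#     saved = pickle.load(fp)
+# for (causal, headdim, batch_size, seqlen), per_method in saved.items():
+#     for method, entry in per_method.items():
+#         fwd_tflops = efficiency(flops(batch_size, seqlen, headdim, nheads=dim // headdim,
+#                                       causal=causal, mode="fwd"), entry["fwd_time"])
+#         print(f"{method} @ seqlen={seqlen}: {fwd_tflops:.2f} TFLOPs/s, peak {entry['mem_peak_MB']:.2f} MB")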