Commit bf0363e7 authored by Matthias Keck

Added attention benchmarking code

parent 74b29421
# Import standard and deep learning libraries
import pickle
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat  # for convenient tensor reshaping
from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_fwd_bwd
from flash_attn import flash_attn_qkvpacked_func  # FlashAttention-2 function
try:
    import xformers.ops as xops
except ImportError:
    xops = None
# Estimate the FLOP count: number of attention operations depending on the mode
def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
# Compute TFLOP/s, i.e. how efficiently the GPU is utilized
def efficiency(flop, time):
    return (flop / time / 10**12) if not math.isnan(time) else 0.0
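# Worked example (illustrative addition, not part of the original script): the
# smallest configuration below (batch=32, seqlen=512, headdim=64, nheads=32)
# without causal masking comes to 4 * 32 * 512**2 * 32 * 64 = 68_719_476_736
# forward FLOPs; a forward pass taking 1 ms would therefore run at ~68.7 TFLOP/s.
assert flops(32, 512, 64, 32, causal=False, mode="fwd") == 68_719_476_736
assert abs(efficiency(68_719_476_736, 1e-3) - 68.72) < 0.01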
# Return the currently allocated and peak GPU memory (in MB)
def log_memory():
    return torch.cuda.memory_allocated() / (1024 ** 2), torch.cuda.max_memory_allocated() / (1024 ** 2)
# Run the benchmark (forward + backward) and additionally measure memory usage
def benchmark_with_memory(func, *args, **kwargs):
    torch.cuda.reset_peak_memory_stats()
    mem_before, _ = log_memory()
    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
    mem_after, mem_peak = log_memory()
    return time_f[1].mean, time_b[1].mean, mem_after - mem_before, mem_peak
# Reference attention implementation in plain PyTorch
def attention_pytorch(qkv, dropout_p=0.0, causal=True):
    batch_size, seqlen, _, nheads, d = qkv.shape
    q, k, v = qkv.unbind(dim=2)  # extract Q, K, V
    # Reshape for batched matrix multiplication
    q = rearrange(q, 'b t h d -> (b h) t d')
    k = rearrange(k, 'b s h d -> (b h) d s')
    softmax_scale = 1.0 / math.sqrt(d)
    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
    # Q @ K^T (scaled)
    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
                       '(b h) t s -> b h t s', h=nheads)
    if causal:
        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
        scores = scores + causal_mask.to(dtype=scores.dtype)
    # Softmax + dropout
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    # Compute the output
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    return output.to(dtype=qkv.dtype)
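# Optional sanity check (a minimal sketch added for illustration, not part of the
# original benchmark): the plain-PyTorch reference should agree with
# FlashAttention-2 up to fp16 rounding. The shapes and tolerances below are
# assumptions chosen so the check runs quickly.
_qkv_check = torch.randn(2, 128, 3, 8, 64, device='cuda', dtype=torch.float16)
_out_ref = attention_pytorch(_qkv_check, dropout_p=0.0, causal=True)
_out_flash = flash_attn_qkvpacked_func(_qkv_check, 0.0, causal=True)
assert torch.allclose(_out_ref, _out_flash, atol=1e-2, rtol=1e-2)
del _qkv_check, _out_ref, _out_flash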
# Global settings
repeats = 30
device = 'cuda'
dtype = torch.float16
# Batch sizes and sequence lengths
bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
# Causal masking off/on
causal_vals = [False, True]
# Head dimensions 64 / 128
headdim_vals = [64, 128]
dim = 2048
dropout_p = 0.0
# Attention implementations
methods = ["Flash2", "Pytorch"] + (["xformers.f"] if xops is not None else [])
results = {}
# Loop over all configurations
for causal in causal_vals:
    for headdim in headdim_vals:
        for batch_size, seqlen in bs_seqlen_vals:
            config = (causal, headdim, batch_size, seqlen)
            nheads = dim // headdim
            # Initialize the QKV tensor
            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
            results[config] = {}
            # Flash2
            f, b, mem_used, mem_peak = benchmark_with_memory(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats)
            results[config]["Flash2"] = {"fwd_time": f, "bwd_time": b, "mem_used_MB": mem_used, "mem_peak_MB": mem_peak}
            # PyTorch attention (falls back to NaN if the run fails, e.g. out of memory)
            try:
                qkv = qkv.detach().requires_grad_(True)
                f, b, mem_used, mem_peak = benchmark_with_memory(
                    attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats
                )
            except Exception:
                f, b, mem_used, mem_peak = float('nan'), float('nan'), float('nan'), float('nan')
            results[config]["Pytorch"] = {"fwd_time": f, "bwd_time": b, "mem_used_MB": mem_used, "mem_peak_MB": mem_peak}
            # xFormers
            if xops is not None:
                q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
                f, b, mem_used, mem_peak = benchmark_with_memory(
                    xops.memory_efficient_attention, q, k, v,
                    attn_bias=xops.LowerTriangularMask() if causal else None,
                    op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp)
                )
                results[config]["xformers.f"] = {"fwd_time": f, "bwd_time": b, "mem_used_MB": mem_used, "mem_peak_MB": mem_peak}
            # Compute and print the results
            print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
            for method in methods:
                if method in results[config]:
                    entry = results[config][method]
                    fwd_tflops = efficiency(flops(batch_size, seqlen, headdim, nheads, causal, "fwd"), entry["fwd_time"])
                    bwd_tflops = efficiency(flops(batch_size, seqlen, headdim, nheads, causal, "bwd"), entry["bwd_time"])
                    fwd_bwd_tflops = efficiency(flops(batch_size, seqlen, headdim, nheads, causal, "fwd_bwd"), entry["fwd_time"] + entry["bwd_time"])
                    print(f"{method} fwd: {fwd_tflops:.2f} TFLOPs/s, bwd: {bwd_tflops:.2f} TFLOPs/s, fwd+bwd: {fwd_bwd_tflops:.2f} TFLOPs/s, mem: {entry['mem_used_MB']:.2f} MB (peak {entry['mem_peak_MB']:.2f} MB)")
# Save the results
with open('flash2_attn_results.pkl', 'wb') as fp:
    pickle.dump(results, fp, protocol=pickle.HIGHEST_PROTOCOL)
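# Minimal sketch of how the saved results could be inspected later (illustrative
# addition, not part of the original benchmark): the dictionary keys are the
# (causal, headdim, batch_size, seqlen) configuration tuples built above.
with open('flash2_attn_results.pkl', 'rb') as fp:
    loaded_results = pickle.load(fp)
for (causal_, headdim_, batch_size_, seqlen_), per_method in loaded_results.items():
    for method_name, entry in per_method.items():
        print(f"{method_name:10s} causal={causal_} headdim={headdim_} "
              f"batch={batch_size_} seqlen={seqlen_}: {entry}")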