Commit 4cc8ed24 authored by Matthias Keck

Delete benchmark.py

parent 0c3f83b8
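
# Benchmarks FlashAttention-2 (flash_attn_qkvpacked_func), FlashAttention via
# flash_attn_func, and a pure-PyTorch attention implementation, reporting
# forward, backward, and forward+backward throughput in TFLOPs/s.
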
import pickle
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn.utils.benchmark import benchmark_fwd_bwd
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func


def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)


def efficiency(flop, time):
    return (flop / time / 10**12) if not math.isnan(time) else 0.0
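
# Worked example (illustrative, not part of the original script): for batch=1,
# seqlen=16384, nheads=16, headdim=128 and causal=False, flops(..., mode="fwd")
# is 4 * 1 * 16384**2 * 16 * 128 ≈ 2.20e12. Assuming benchmark_fwd_bwd reports
# times in seconds, a 20 ms forward pass gives efficiency(2.20e12, 0.020) ≈ 110
# TFLOPs/s, which matches the numbers printed in the loop below.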


def attention_pytorch(qkv, dropout_p=0.0, causal=True):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
    """
    batch_size, seqlen, _, nheads, d = qkv.shape
    q, k, v = qkv.unbind(dim=2)
    q = rearrange(q, 'b t h d -> (b h) t d')
    k = rearrange(k, 'b s h d -> (b h) d s')
    softmax_scale = 1.0 / math.sqrt(d)
    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
                       '(b h) t s -> b h t s', h=nheads)
    if causal:
        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
        scores = scores + causal_mask.to(dtype=scores.dtype)
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    return output.to(dtype=qkv.dtype)
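
# Optional sanity check (illustrative sketch, not part of the original
# benchmark): compares attention_pytorch against F.scaled_dot_product_attention
# on a small input. Assumes PyTorch >= 2.0 and uses loose tolerances to allow
# for fp16 accumulation differences. Not called anywhere by default.
def _check_attention_pytorch(batch_size=2, seqlen=128, nheads=4, d=64,
                             device='cuda', dtype=torch.float16, causal=True):
    qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype)
    ref = attention_pytorch(qkv, dropout_p=0.0, causal=causal)            # (B, N, H, D)
    q, k, v = (x.transpose(1, 2) for x in qkv.unbind(dim=2))              # (B, H, N, D)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=causal).transpose(1, 2)
    return torch.allclose(ref, out, atol=2e-3, rtol=2e-3)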


def time_fwd_bwd(func, *args, **kwargs):
    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
    return time_f[1].mean, time_b[1].mean


repeats = 30
device = 'cuda'
dtype = torch.float16
bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
causal_vals = [False, True]
headdim_vals = [64, 128]
dim = 2048
dropout_p = 0.0
methods = ["Flash2", "Flash", "Pytorch"]
time_f = {}
time_b = {}
time_f_b = {}
speed_f = {}
speed_b = {}
speed_f_b = {}
for causal in causal_vals:
    for headdim in headdim_vals:
        for batch_size, seqlen in bs_seqlen_vals:
            config = (causal, headdim, batch_size, seqlen)
            nheads = dim // headdim
            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                              requires_grad=True)

            # FlashAttention 2
            f, b = time_fwd_bwd(
                flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
            )
            time_f[config, "Flash2"] = f
            time_b[config, "Flash2"] = b

            # ✅ FlashAttention (flash_attn_func)
            try:
                # Split the packed tensor into q, k, v. flash_attn_func expects
                # (batch_size, seqlen, nheads, headdim), which is exactly the
                # [B, N, H, D] layout that unbind(dim=2) produces, so no permute
                # is needed here.
                q, k, v = qkv.unbind(dim=2)  # [B, N, H, D] each

                def flash_attn_func_wrapper(qkv_input, dropout_p=0.0, causal=False):
                    # qkv_input is only a placeholder; the real tensors come from the closure.
                    return flash_attn_func(q, k, v, dropout_p, causal=causal)

                f, b = time_fwd_bwd(
                    flash_attn_func_wrapper,
                    qkv,  # dummy input, real data comes from closure
                    dropout_p,
                    causal=causal,
                    repeats=repeats,
                    verbose=False
                )
            except Exception as e:
                print(f"❌ Flash failed: {e}")
                f, b = float('nan'), float('nan')
            time_f[config, "Flash"] = f
            time_b[config, "Flash"] = b

            # PyTorch
            try:
                qkv = qkv.detach().requires_grad_(True)
                f, b = time_fwd_bwd(
                    attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
                )
            except Exception as e:
                print(f"❌ Pytorch failed: {e}")
                f, b = float('nan'), float('nan')
            time_f[config, "Pytorch"] = f
            time_b[config, "Pytorch"] = b
print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
for method in methods:
time_f_b[config, method] = time_f[config, method] + time_b[config, method]
speed_f[config, method] = efficiency(
flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
time_f[config, method]
)
speed_b[config, method] = efficiency(
flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
time_b[config, method]
)
speed_f_b[config, method] = efficiency(
flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
time_f_b[config, method]
)
print(
f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
)

# Optionally save results
with open('flash2_vs_pytorch_attn_time.pkl', 'wb') as fp:
    pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
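
# Illustrative read-back example (not in the original script): the pickle holds
# three dicts keyed by ((causal, headdim, batch_size, seqlen), method) with
# TFLOPs/s values.
# with open('flash2_vs_pytorch_attn_time.pkl', 'rb') as fp:
#     speed_f, speed_b, speed_f_b = pickle.load(fp)
# print(speed_f_b[(False, 64, 32, 512), "Flash2"])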