diff --git a/Attention-Benchmarking/bechmark_with_mem.py b/Attention-Benchmarking/bechmark_with_mem.py index ba654e1a654c889397e63ef7a7d6660880ac3f28..f32c5a0475517cfaffc9b57bcb17ec4085cefd00 100644 --- a/Attention-Benchmarking/bechmark_with_mem.py +++ b/Attention-Benchmarking/bechmark_with_mem.py @@ -114,7 +114,7 @@ for causal in causal_vals: results[config]["Pytorch"] = {"fwd_time": f, "bwd_time": b, "mem_used_MB": mem_used, "mem_peak_MB": mem_peak} - # xFormers + # FlashAttention with xFormers if xops is not None: q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True) for _ in range(3)] f, b, mem_used, mem_peak = benchmark_with_memory(