Skip to content
Snippets Groups Projects
Commit 16a320df authored by Matthias Keck's avatar Matthias Keck
Browse files

Edit bechmark_with_mem.py

parent 56c87289
No related branches found
No related tags found
No related merge requests found
......@@ -114,7 +114,7 @@ for causal in causal_vals:
results[config]["Pytorch"] = {"fwd_time": f, "bwd_time": b, "mem_used_MB": mem_used, "mem_peak_MB": mem_peak}
# xFormers
# FlashAttention mit xFormers
if xops is not None:
q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
f, b, mem_used, mem_peak = benchmark_with_memory(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment