Commit 95624117 authored by Matthias Keck's avatar Matthias Keck

Delete with_xformers_19609.log

parent 0f1da049
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dbef610>
fn_amp(*inputs, **kwinputs)
432.29 us
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dc68210>
f(*inputs, y=y, grad=grad)
1.57 ms
1 measurement, 10 runs , 8 threads
### causal=False, headdim=64, batch_size=32, seqlen=512 ###
Flash2 fwd: 145.43 TFLOPs/s, bwd: 112.28 TFLOPs/s, fwd + bwd: 120.10 TFLOPs/s
Flash fwd: 58.00 TFLOPs/s, bwd: 50.81 TFLOPs/s, fwd + bwd: 52.67 TFLOPs/s
Pytorch fwd: 29.52 TFLOPs/s, bwd: 38.70 TFLOPs/s, fwd + bwd: 35.54 TFLOPs/s
xformers.f fwd: 158.97 TFLOPs/s, bwd: 109.57 TFLOPs/s, fwd + bwd: 120.24 TFLOPs/s
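The TFLOPs/s figures appear consistent with the analytic FLOP count used by flash-attention-style benchmarks: 4 · batch · seqlen² · nheads · headdim matmul FLOPs for the forward pass (halved when causal=True), with the backward pass costed at 2.5× the forward. Assuming a hidden dimension of 2048 (so nheads = 2048 / headdim; this is an assumption, not stated in the log), the 432.29 us forward measurement above reproduces the reported xformers.f forward figure of 158.97 TFLOPs/s. A minimal sketch under those assumptions:

```python
def attention_flops(batch, seqlen, nheads, headdim, causal=False, mode="fwd"):
    """Analytic matmul FLOPs for attention (assumed convention, not from the log)."""
    # Forward: QK^T plus attn @ V, each 2 * batch * seqlen^2 * nheads * headdim.
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    # Backward is conventionally costed at 2.5x forward.
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)

def tflops_per_sec(flops, time_s):
    return flops / time_s / 1e12

# causal=False, headdim=64, batch_size=32, seqlen=512; assumed dim=2048 -> nheads=32
fwd = attention_flops(32, 512, 2048 // 64, 64, causal=False, mode="fwd")
print(round(tflops_per_sec(fwd, 432.29e-6), 2))  # ~158.97, matching xformers.f fwd
```

The backward figure (109.57 TFLOPs/s) does not reproduce exactly from the rounded 1.57 ms shown, but lands within rounding error of it, which supports the 2.5× backward convention.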
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94db7e7d0>
fn_amp(*inputs, **kwinputs)
785.54 us
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dc62890>
f(*inputs, y=y, grad=grad)
2.59 ms
1 measurement, 10 runs , 8 threads
### causal=False, headdim=64, batch_size=16, seqlen=1024 ###
Flash2 fwd: 177.28 TFLOPs/s, bwd: 138.96 TFLOPs/s, fwd + bwd: 148.10 TFLOPs/s
Flash fwd: 126.04 TFLOPs/s, bwd: 106.44 TFLOPs/s, fwd + bwd: 111.39 TFLOPs/s
Pytorch fwd: 34.24 TFLOPs/s, bwd: 42.90 TFLOPs/s, fwd + bwd: 40.01 TFLOPs/s
xformers.f fwd: 174.96 TFLOPs/s, bwd: 132.50 TFLOPs/s, fwd + bwd: 142.38 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa9547bd650>
fn_amp(*inputs, **kwinputs)
1.50 ms
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dbed790>
f(*inputs, y=y, grad=grad)
4.54 ms
1 measurement, 10 runs , 8 threads
### causal=False, headdim=64, batch_size=8, seqlen=2048 ###
Flash2 fwd: 185.77 TFLOPs/s, bwd: 162.20 TFLOPs/s, fwd + bwd: 168.30 TFLOPs/s
Flash fwd: 252.19 TFLOPs/s, bwd: 213.94 TFLOPs/s, fwd + bwd: 223.63 TFLOPs/s
Pytorch fwd: 22.71 TFLOPs/s, bwd: 47.41 TFLOPs/s, fwd + bwd: 36.17 TFLOPs/s
xformers.f fwd: 183.00 TFLOPs/s, bwd: 151.52 TFLOPs/s, fwd + bwd: 159.35 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a8fdec50>
fn_amp(*inputs, **kwinputs)
2.95 ms
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dc35310>
f(*inputs, y=y, grad=grad)
8.49 ms
1 measurement, 10 runs , 8 threads
### causal=False, headdim=64, batch_size=4, seqlen=4096 ###
Flash2 fwd: 189.75 TFLOPs/s, bwd: 171.92 TFLOPs/s, fwd + bwd: 176.66 TFLOPs/s
Flash fwd: 504.11 TFLOPs/s, bwd: 429.30 TFLOPs/s, fwd + bwd: 448.31 TFLOPs/s
Pytorch fwd: 31.25 TFLOPs/s, bwd: 49.06 TFLOPs/s, fwd + bwd: 42.19 TFLOPs/s
xformers.f fwd: 186.08 TFLOPs/s, bwd: 161.93 TFLOPs/s, fwd + bwd: 168.17 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dc27c50>
fn_amp(*inputs, **kwinputs)
5.80 ms
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dc36f10>
f(*inputs, y=y, grad=grad)
16.26 ms
1 measurement, 10 runs , 8 threads
### causal=False, headdim=64, batch_size=2, seqlen=8192 ###
Flash2 fwd: 190.67 TFLOPs/s, bwd: 177.57 TFLOPs/s, fwd + bwd: 181.13 TFLOPs/s
Flash fwd: 1008.03 TFLOPs/s, bwd: 856.80 TFLOPs/s, fwd + bwd: 895.17 TFLOPs/s
Pytorch fwd: 38.75 TFLOPs/s, bwd: 50.81 TFLOPs/s, fwd + bwd: 46.66 TFLOPs/s
xformers.f fwd: 189.72 TFLOPs/s, bwd: 169.00 TFLOPs/s, fwd + bwd: 174.44 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dc41090>
fn_amp(*inputs, **kwinputs)
11.51 ms
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94c911790>
f(*inputs, y=y, grad=grad)
32.04 ms
1 measurement, 10 runs , 8 threads
### causal=False, headdim=64, batch_size=1, seqlen=16384 ###
Flash2 fwd: 191.62 TFLOPs/s, bwd: 180.79 TFLOPs/s, fwd + bwd: 183.76 TFLOPs/s
Flash fwd: 2061.06 TFLOPs/s, bwd: 1732.97 TFLOPs/s, fwd + bwd: 1815.55 TFLOPs/s
Pytorch fwd: 42.55 TFLOPs/s, bwd: 51.06 TFLOPs/s, fwd + bwd: 48.30 TFLOPs/s
xformers.f fwd: 191.11 TFLOPs/s, bwd: 171.57 TFLOPs/s, fwd + bwd: 176.73 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94c937e50>
fn_amp(*inputs, **kwinputs)
382.27 us
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dc07ed0>
f(*inputs, y=y, grad=grad)
1.39 ms
1 measurement, 10 runs , 8 threads
### causal=False, headdim=128, batch_size=32, seqlen=512 ###
Flash2 fwd: 193.45 TFLOPs/s, bwd: 130.15 TFLOPs/s, fwd + bwd: 143.57 TFLOPs/s
Flash fwd: 66.98 TFLOPs/s, bwd: 55.29 TFLOPs/s, fwd + bwd: 58.19 TFLOPs/s
Pytorch fwd: 42.10 TFLOPs/s, bwd: 59.42 TFLOPs/s, fwd + bwd: 53.17 TFLOPs/s
xformers.f fwd: 179.77 TFLOPs/s, bwd: 124.03 TFLOPs/s, fwd + bwd: 136.08 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dbf8590>
fn_amp(*inputs, **kwinputs)
698.35 us
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa9833bd8d0>
f(*inputs, y=y, grad=grad)
2.27 ms
1 measurement, 10 runs , 8 threads
### causal=False, headdim=128, batch_size=16, seqlen=1024 ###
Flash2 fwd: 212.75 TFLOPs/s, bwd: 158.20 TFLOPs/s, fwd + bwd: 170.71 TFLOPs/s
Flash fwd: 133.57 TFLOPs/s, bwd: 110.98 TFLOPs/s, fwd + bwd: 116.62 TFLOPs/s
Pytorch fwd: 56.03 TFLOPs/s, bwd: 73.64 TFLOPs/s, fwd + bwd: 67.57 TFLOPs/s
xformers.f fwd: 196.80 TFLOPs/s, bwd: 151.04 TFLOPs/s, fwd + bwd: 161.79 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94c92dcd0>
fn_amp(*inputs, **kwinputs)
1.34 ms
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94c9348d0>
f(*inputs, y=y, grad=grad)
4.08 ms
1 measurement, 10 runs , 8 threads
### causal=False, headdim=128, batch_size=8, seqlen=2048 ###
Flash2 fwd: 221.37 TFLOPs/s, bwd: 176.91 TFLOPs/s, fwd + bwd: 187.68 TFLOPs/s
Flash fwd: 266.88 TFLOPs/s, bwd: 222.63 TFLOPs/s, fwd + bwd: 233.70 TFLOPs/s
Pytorch fwd: 41.79 TFLOPs/s, bwd: 86.10 TFLOPs/s, fwd + bwd: 66.08 TFLOPs/s
xformers.f fwd: 205.74 TFLOPs/s, bwd: 168.26 TFLOPs/s, fwd + bwd: 177.50 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94c922790>
fn_amp(*inputs, **kwinputs)
2.62 ms
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dbaf490>
f(*inputs, y=y, grad=grad)
7.71 ms
1 measurement, 10 runs , 8 threads
### causal=False, headdim=128, batch_size=4, seqlen=4096 ###
Flash2 fwd: 224.62 TFLOPs/s, bwd: 189.36 TFLOPs/s, fwd + bwd: 198.25 TFLOPs/s
Flash fwd: 535.08 TFLOPs/s, bwd: 447.31 TFLOPs/s, fwd + bwd: 469.30 TFLOPs/s
Pytorch fwd: 56.00 TFLOPs/s, bwd: 88.50 TFLOPs/s, fwd + bwd: 75.91 TFLOPs/s
xformers.f fwd: 210.06 TFLOPs/s, bwd: 178.18 TFLOPs/s, fwd + bwd: 186.26 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dbaca90>
fn_amp(*inputs, **kwinputs)
5.22 ms
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94c935b10>
f(*inputs, y=y, grad=grad)
14.94 ms
1 measurement, 10 runs , 8 threads
### causal=False, headdim=128, batch_size=2, seqlen=8192 ###
Flash2 fwd: 226.81 TFLOPs/s, bwd: 196.35 TFLOPs/s, fwd + bwd: 204.18 TFLOPs/s
Flash fwd: 1071.20 TFLOPs/s, bwd: 891.67 TFLOPs/s, fwd + bwd: 936.52 TFLOPs/s
Pytorch fwd: 71.34 TFLOPs/s, bwd: 97.32 TFLOPs/s, fwd + bwd: 88.15 TFLOPs/s
xformers.f fwd: 210.46 TFLOPs/s, bwd: 183.95 TFLOPs/s, fwd + bwd: 190.82 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94c92df10>
fn_amp(*inputs, **kwinputs)
10.38 ms
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dc41f50>
f(*inputs, y=y, grad=grad)
29.27 ms
1 measurement, 10 runs , 8 threads
### causal=False, headdim=128, batch_size=1, seqlen=16384 ###
Flash2 fwd: 226.09 TFLOPs/s, bwd: 199.75 TFLOPs/s, fwd + bwd: 206.63 TFLOPs/s
Flash fwd: 2178.96 TFLOPs/s, bwd: 1787.50 TFLOPs/s, fwd + bwd: 1884.22 TFLOPs/s
Pytorch fwd: 79.16 TFLOPs/s, bwd: 98.27 TFLOPs/s, fwd + bwd: 91.93 TFLOPs/s
xformers.f fwd: 211.87 TFLOPs/s, bwd: 187.81 TFLOPs/s, fwd + bwd: 194.11 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dbdd590>
fn_amp(*inputs, **kwinputs)
340.20 us
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dc36690>
f(*inputs, y=y, grad=grad)
1.17 ms
1 measurement, 10 runs , 8 threads
### causal=True, headdim=64, batch_size=32, seqlen=512 ###
Flash2 fwd: 110.01 TFLOPs/s, bwd: 80.00 TFLOPs/s, fwd + bwd: 86.76 TFLOPs/s
Flash fwd: 32.13 TFLOPs/s, bwd: 26.32 TFLOPs/s, fwd + bwd: 27.75 TFLOPs/s
Pytorch fwd: 9.61 TFLOPs/s, bwd: 19.32 TFLOPs/s, fwd + bwd: 14.99 TFLOPs/s
xformers.f fwd: 101.00 TFLOPs/s, bwd: 73.69 TFLOPs/s, fwd + bwd: 79.86 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94c91eb10>
fn_amp(*inputs, **kwinputs)
517.94 us
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94c920d90>
f(*inputs, y=y, grad=grad)
1.69 ms
1 measurement, 10 runs , 8 threads
### causal=True, headdim=64, batch_size=16, seqlen=1024 ###
Flash2 fwd: 139.77 TFLOPs/s, bwd: 111.59 TFLOPs/s, fwd + bwd: 118.41 TFLOPs/s
Flash fwd: 64.12 TFLOPs/s, bwd: 52.82 TFLOPs/s, fwd + bwd: 55.62 TFLOPs/s
Pytorch fwd: 10.39 TFLOPs/s, bwd: 21.39 TFLOPs/s, fwd + bwd: 16.42 TFLOPs/s
xformers.f fwd: 132.68 TFLOPs/s, bwd: 101.60 TFLOPs/s, fwd + bwd: 108.89 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94c91b0d0>
fn_amp(*inputs, **kwinputs)
886.85 us
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94db7cb50>
f(*inputs, y=y, grad=grad)
2.72 ms
1 measurement, 10 runs , 8 threads
### causal=True, headdim=64, batch_size=8, seqlen=2048 ###
Flash2 fwd: 161.50 TFLOPs/s, bwd: 138.51 TFLOPs/s, fwd + bwd: 144.38 TFLOPs/s
Flash fwd: 128.09 TFLOPs/s, bwd: 106.54 TFLOPs/s, fwd + bwd: 111.92 TFLOPs/s
Pytorch fwd: 6.95 TFLOPs/s, bwd: 23.70 TFLOPs/s, fwd + bwd: 14.03 TFLOPs/s
xformers.f fwd: 154.97 TFLOPs/s, bwd: 126.22 TFLOPs/s, fwd + bwd: 133.29 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94c919dd0>
fn_amp(*inputs, **kwinputs)
1.65 ms
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dbdea10>
f(*inputs, y=y, grad=grad)
4.78 ms
1 measurement, 10 runs , 8 threads
### causal=True, headdim=64, batch_size=4, seqlen=4096 ###
Flash2 fwd: 175.56 TFLOPs/s, bwd: 156.57 TFLOPs/s, fwd + bwd: 161.56 TFLOPs/s
Flash fwd: 255.87 TFLOPs/s, bwd: 214.87 TFLOPs/s, fwd + bwd: 225.17 TFLOPs/s
Pytorch fwd: 8.92 TFLOPs/s, bwd: 24.52 TFLOPs/s, fwd + bwd: 16.35 TFLOPs/s
xformers.f fwd: 166.64 TFLOPs/s, bwd: 143.66 TFLOPs/s, fwd + bwd: 149.55 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dbadcd0>
fn_amp(*inputs, **kwinputs)
3.18 ms
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dc34790>
f(*inputs, y=y, grad=grad)
8.88 ms
1 measurement, 10 runs , 8 threads
### causal=True, headdim=64, batch_size=2, seqlen=8192 ###
Flash2 fwd: 182.20 TFLOPs/s, bwd: 169.09 TFLOPs/s, fwd + bwd: 172.64 TFLOPs/s
Flash fwd: 511.64 TFLOPs/s, bwd: 428.58 TFLOPs/s, fwd + bwd: 449.43 TFLOPs/s
Pytorch fwd: 10.39 TFLOPs/s, bwd: 25.40 TFLOPs/s, fwd + bwd: 17.98 TFLOPs/s
xformers.f fwd: 172.73 TFLOPs/s, bwd: 154.82 TFLOPs/s, fwd + bwd: 159.55 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dbdee90>
fn_amp(*inputs, **kwinputs)
6.25 ms
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94db7dcd0>
f(*inputs, y=y, grad=grad)
16.98 ms
1 measurement, 10 runs , 8 threads
### causal=True, headdim=64, batch_size=1, seqlen=16384 ###
Flash2 fwd: 186.64 TFLOPs/s, bwd: 177.38 TFLOPs/s, fwd + bwd: 179.93 TFLOPs/s
Flash fwd: 1047.78 TFLOPs/s, bwd: 868.53 TFLOPs/s, fwd + bwd: 913.16 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
xformers.f fwd: 175.80 TFLOPs/s, bwd: 161.91 TFLOPs/s, fwd + bwd: 165.65 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dbdf850>
fn_amp(*inputs, **kwinputs)
302.01 us
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94605fed0>
f(*inputs, y=y, grad=grad)
1.05 ms
1 measurement, 10 runs , 8 threads
### causal=True, headdim=128, batch_size=32, seqlen=512 ###
Flash2 fwd: 123.78 TFLOPs/s, bwd: 87.24 TFLOPs/s, fwd + bwd: 95.28 TFLOPs/s
Flash fwd: 33.65 TFLOPs/s, bwd: 27.77 TFLOPs/s, fwd + bwd: 29.23 TFLOPs/s
Pytorch fwd: 15.13 TFLOPs/s, bwd: 29.73 TFLOPs/s, fwd + bwd: 23.31 TFLOPs/s
xformers.f fwd: 113.77 TFLOPs/s, bwd: 81.54 TFLOPs/s, fwd + bwd: 88.72 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94605c0d0>
fn_amp(*inputs, **kwinputs)
460.27 us
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94c910350>
f(*inputs, y=y, grad=grad)
1.54 ms
1 measurement, 10 runs , 8 threads
### causal=True, headdim=128, batch_size=16, seqlen=1024 ###
Flash2 fwd: 156.64 TFLOPs/s, bwd: 120.41 TFLOPs/s, fwd + bwd: 128.93 TFLOPs/s
Flash fwd: 67.30 TFLOPs/s, bwd: 55.36 TFLOPs/s, fwd + bwd: 58.32 TFLOPs/s
Pytorch fwd: 18.17 TFLOPs/s, bwd: 36.83 TFLOPs/s, fwd + bwd: 28.48 TFLOPs/s
xformers.f fwd: 149.30 TFLOPs/s, bwd: 111.68 TFLOPs/s, fwd + bwd: 120.34 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94c92ded0>
fn_amp(*inputs, **kwinputs)
794.37 us
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94c919650>
f(*inputs, y=y, grad=grad)
2.51 ms
1 measurement, 10 runs , 8 threads
### causal=True, headdim=128, batch_size=8, seqlen=2048 ###
Flash2 fwd: 180.55 TFLOPs/s, bwd: 148.23 TFLOPs/s, fwd + bwd: 156.22 TFLOPs/s
Flash fwd: 134.72 TFLOPs/s, bwd: 111.35 TFLOPs/s, fwd + bwd: 117.15 TFLOPs/s
Pytorch fwd: 13.16 TFLOPs/s, bwd: 43.07 TFLOPs/s, fwd + bwd: 26.12 TFLOPs/s
xformers.f fwd: 173.02 TFLOPs/s, bwd: 137.07 TFLOPs/s, fwd + bwd: 145.72 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dc46a90>
fn_amp(*inputs, **kwinputs)
1.46 ms
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94c9221d0>
f(*inputs, y=y, grad=grad)
4.43 ms
1 measurement, 10 runs , 8 threads
### causal=True, headdim=128, batch_size=4, seqlen=4096 ###
Flash2 fwd: 194.13 TFLOPs/s, bwd: 167.47 TFLOPs/s, fwd + bwd: 174.31 TFLOPs/s
Flash fwd: 269.37 TFLOPs/s, bwd: 222.47 TFLOPs/s, fwd + bwd: 234.12 TFLOPs/s
Pytorch fwd: 16.65 TFLOPs/s, bwd: 44.24 TFLOPs/s, fwd + bwd: 30.02 TFLOPs/s
xformers.f fwd: 188.35 TFLOPs/s, bwd: 155.22 TFLOPs/s, fwd + bwd: 163.44 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94dc52d90>
fn_amp(*inputs, **kwinputs)
2.79 ms
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94c91e1d0>
f(*inputs, y=y, grad=grad)
8.21 ms
1 measurement, 10 runs , 8 threads
### causal=True, headdim=128, batch_size=2, seqlen=8192 ###
Flash2 fwd: 201.61 TFLOPs/s, bwd: 180.81 TFLOPs/s, fwd + bwd: 186.30 TFLOPs/s
Flash fwd: 536.48 TFLOPs/s, bwd: 446.26 TFLOPs/s, fwd + bwd: 468.79 TFLOPs/s
Pytorch fwd: 19.63 TFLOPs/s, bwd: 48.71 TFLOPs/s, fwd + bwd: 34.22 TFLOPs/s
xformers.f fwd: 196.80 TFLOPs/s, bwd: 167.47 TFLOPs/s, fwd + bwd: 174.92 TFLOPs/s
- Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94c934d10>
fn_amp(*inputs, **kwinputs)
5.52 ms
1 measurement, 10 runs , 8 threads
- Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa94c9120d0>
f(*inputs, y=y, grad=grad)
15.59 ms
1 measurement, 10 runs , 8 threads
### causal=True, headdim=128, batch_size=1, seqlen=16384 ###
Flash2 fwd: 204.29 TFLOPs/s, bwd: 191.33 TFLOPs/s, fwd + bwd: 194.86 TFLOPs/s
Flash fwd: 1104.60 TFLOPs/s, bwd: 890.77 TFLOPs/s, fwd + bwd: 942.92 TFLOPs/s
Pytorch fwd: 20.57 TFLOPs/s, bwd: 49.27 TFLOPs/s, fwd + bwd: 35.22 TFLOPs/s
xformers.f fwd: 199.26 TFLOPs/s, bwd: 176.33 TFLOPs/s, fwd + bwd: 182.33 TFLOPs/s
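For comparing implementations across configurations, it may help to parse these log sections into records. A rough sketch; the regexes assume exactly the `### ... ###` header and `name fwd: ... TFLOPs/s` row formats seen above:

```python
import re

HEADER = re.compile(
    r"### causal=(\w+), headdim=(\d+), batch_size=(\d+), seqlen=(\d+) ###")
ROW = re.compile(
    r"^([\w.]+)\s+fwd: ([\d.]+) TFLOPs/s, bwd: ([\d.]+) TFLOPs/s, "
    r"fwd \+ bwd: ([\d.]+) TFLOPs/s")

def parse_log(text):
    """Collect one record per (config, implementation) pair."""
    records, config = [], None
    for line in text.splitlines():
        if (m := HEADER.match(line)):
            config = dict(causal=m[1] == "True", headdim=int(m[2]),
                          batch_size=int(m[3]), seqlen=int(m[4]))
        elif config and (m := ROW.match(line)):
            records.append({**config, "impl": m[1], "fwd": float(m[2]),
                            "bwd": float(m[3]), "fwd_bwd": float(m[4])})
    return records

sample = """### causal=False, headdim=64, batch_size=32, seqlen=512 ###
Flash2 fwd: 145.43 TFLOPs/s, bwd: 112.28 TFLOPs/s, fwd + bwd: 120.10 TFLOPs/s"""
print(parse_log(sample))
```

The intermediate Measurement blocks and timing lines simply fail both regexes and are skipped, so the full log can be fed in unmodified.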