Matthias Keck / DMML-Replikation · Commits

Commit bf0363e7
Authored 2 weeks ago by Matthias Keck
Added Attention-Benchmarking code
Parent: 74b29421
No related branches, tags, or merge requests contain this commit.

1 changed file: Attention-Benchmarking/bechmark_with_mem.py (new file, mode 100644), 140 additions, 0 deletions

# Import standard and deep learning libraries
import pickle
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat  # for easy tensor reshaping
from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_fwd_bwd
from flash_attn import flash_attn_qkvpacked_func  # FlashAttention-2 function

try:
    import xformers.ops as xops
except ImportError:
    xops = None
# Estimate the FLOP count: number of attention operations, depending on the mode
def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
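
# Worked example for the formula above (illustration only, not used by the benchmark):
# batch=32, seqlen=512, headdim=64, nheads=2048 // 64 = 32, non-causal forward pass:
# 4 * 32 * 512**2 * 32 * 64 = 68_719_476_736 FLOP, i.e. roughly 68.7 GFLOP per call.
assert flops(32, 512, 64, 32, causal=False, mode="fwd") == 68_719_476_736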

# Compute FLOP/s, i.e. how efficiently the GPU is working (in TFLOP/s)
def efficiency(flop, time):
    return (flop / time / 10**12) if not math.isnan(time) else 0.0


# Return the currently allocated and the peak GPU memory (in MB)
def log_memory():
    return torch.cuda.memory_allocated() / (1024**2), torch.cuda.max_memory_allocated() / (1024**2)

# Run the benchmark (forward + backward) and additionally measure memory usage
def benchmark_with_memory(func, *args, **kwargs):
    torch.cuda.reset_peak_memory_stats()
    mem_before, _ = log_memory()
    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
    mem_after, mem_peak = log_memory()
    return time_f[1].mean, time_b[1].mean, mem_after - mem_before, mem_peak

# Attention implemented in plain PyTorch (reference implementation)
def attention_pytorch(qkv, dropout_p=0.0, causal=True):
    batch_size, seqlen, _, nheads, d = qkv.shape
    q, k, v = qkv.unbind(dim=2)  # extract Q, K, V
    # Reshape for batched matrix multiplication
    q = rearrange(q, 'b t h d -> (b h) t d')
    k = rearrange(k, 'b s h d -> (b h) d s')
    softmax_scale = 1.0 / math.sqrt(d)
    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
    # Q @ K^T (with scaling)
    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
                       '(b h) t s -> b h t s', h=nheads)
    if causal:
        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
        scores = scores + causal_mask.to(dtype=scores.dtype)
    # Softmax + dropout
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    # Compute the output
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    return output.to(dtype=qkv.dtype)
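
# Optional numerical sanity check for attention_pytorch (illustrative only; assumes
# PyTorch >= 2.0 for F.scaled_dot_product_attention, runs on CPU in float32 and is
# not required by the benchmark below).
_qkv = torch.randn(2, 128, 3, 4, 64)
_ref = attention_pytorch(_qkv, dropout_p=0.0, causal=True)
_q, _k, _v = (x.transpose(1, 2) for x in _qkv.unbind(dim=2))  # (b, h, s, d)
_sdpa = F.scaled_dot_product_attention(_q, _k, _v, is_causal=True).transpose(1, 2)
print("attention_pytorch vs. SDPA, max abs diff:", (_ref - _sdpa).abs().max().item())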

# Global settings
repeats = 30
device = 'cuda'
dtype = torch.float16

# Batch sizes and sequence lengths
bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
# Causal masking off/on
causal_vals = [False, True]
# Head dimensions 64 / 128
headdim_vals = [64, 128]
dim = 2048
dropout_p = 0.0

# Attention implementations to benchmark
methods = ["Flash2", "Pytorch"] + (["xformers.f"] if xops is not None else [])

results = {}

# Loop over all configurations
for causal in causal_vals:
    for headdim in headdim_vals:
        for batch_size, seqlen in bs_seqlen_vals:
            config = (causal, headdim, batch_size, seqlen)
            nheads = dim // headdim
            # Initialize the QKV tensor
            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim,
                              device=device, dtype=dtype, requires_grad=True)
            results[config] = {}

            # Flash2
            f, b, mem_used, mem_peak = benchmark_with_memory(
                flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats
            )
            results[config]["Flash2"] = {
                "fwd_time": f,
                "bwd_time": b,
                "mem_used_MB": mem_used,
                "mem_peak_MB": mem_peak,
            }

            # PyTorch attention
            try:
                qkv = qkv.detach().requires_grad_(True)
                f, b, mem_used, mem_peak = benchmark_with_memory(
                    attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats
                )
            except Exception:
                f, b, mem_used, mem_peak = float('nan'), float('nan'), float('nan'), float('nan')
            results[config]["Pytorch"] = {
                "fwd_time": f,
                "bwd_time": b,
                "mem_used_MB": mem_used,
                "mem_peak_MB": mem_peak,
            }

            # xFormers
            if xops is not None:
                q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim,
                                       device=device, dtype=dtype, requires_grad=True)
                           for _ in range(3)]
                f, b, mem_used, mem_peak = benchmark_with_memory(
                    xops.memory_efficient_attention, q, k, v,
                    attn_bias=xops.LowerTriangularMask() if causal else None,
                    op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
                )
                results[config]["xformers.f"] = {
                    "fwd_time": f,
                    "bwd_time": b,
                    "mem_used_MB": mem_used,
                    "mem_peak_MB": mem_peak,
                }

            # Compute and print the results
            print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
            for method in methods:
                if method in results[config]:
                    entry = results[config][method]
                    fwd_tflops = efficiency(flops(batch_size, seqlen, headdim, nheads, causal, "fwd"), entry["fwd_time"])
                    bwd_tflops = efficiency(flops(batch_size, seqlen, headdim, nheads, causal, "bwd"), entry["bwd_time"])
                    fwd_bwd_tflops = efficiency(flops(batch_size, seqlen, headdim, nheads, causal, "fwd_bwd"), entry["fwd_time"] + entry["bwd_time"])
                    print(f"{method} fwd: {fwd_tflops:.2f} TFLOPs/s, bwd: {bwd_tflops:.2f} TFLOPs/s, fwd+bwd: {fwd_bwd_tflops:.2f} TFLOPs/s, mem: {entry['mem_used_MB']:.2f} MB (peak {entry['mem_peak_MB']:.2f} MB)")

# Save the results
with open('flash2_attn_results.pkl', 'wb') as fp:
    pickle.dump(results, fp, protocol=pickle.HIGHEST_PROTOCOL)
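
The saved pickle maps each configuration tuple (causal, headdim, batch_size, seqlen) to a per-method dict of timings and memory figures. A minimal sketch, assuming the flash2_attn_results.pkl written by the script above, of how those results could be loaded back for analysis:

import pickle

with open('flash2_attn_results.pkl', 'rb') as fp:
    results = pickle.load(fp)

# Each entry holds fwd_time, bwd_time, mem_used_MB and mem_peak_MB per method.
for (causal, headdim, batch_size, seqlen), per_method in results.items():
    for method, entry in per_method.items():
        print(f"causal={causal} headdim={headdim} bs={batch_size} seqlen={seqlen} "
              f"{method}: fwd={entry['fwd_time']:.2e}s bwd={entry['bwd_time']:.2e}s "
              f"peak={entry['mem_peak_MB']:.1f}MB")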