Hi,

Inside the `DotProductAttention` layer, `get_attention_backend()` is invoked to select the best-fit attention implementation (CUDA kernels). See https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py#L8145-L8160

Specifically, whether a new attention backend needs to be selected is indicated by a global flag, `_attention_backends["backend_selection_requires_update"]`. This flag is set from the result of the comparison `attention_params != _attention_backends["attention_params"]`.
Nevertheless, FP8 training / inference may run into repeated, redundant attention-backend updates: every iteration triggers an unnecessary backend selection that returns the same result as the previous one.

The root cause is located in `class AttentionParams`. https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py#L220-L298

`attention_params != _attention_backends["attention_params"]` compares all data fields of `class AttentionParams`; if any of them differs, the comparison evaluates to `True`, triggering `get_attention_backend()`.
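The following sketch illustrates why the comparison fires, assuming `AttentionParams` relies on the `__eq__` generated by `@dataclass`. The classes are hypothetical, simplified stand-ins rather than the real ones, and `step_state` is a made-up key representing whatever per-iteration state `fp8_meta` carries. The generated `__eq__` compares the entire `fp8_meta` dict, so a change to any entry makes the objects unequal even though the recipe is identical.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class Recipe:
    # Hypothetical stand-in for an FP8 recipe object.
    margin: int = 0


@dataclass
class ParamsSketch:
    # Hypothetical subset of AttentionParams: only the FP8-related fields,
    # relying on the __eq__ that @dataclass generates.
    fp8: bool = False
    fp8_meta: Dict[str, Any] = field(default_factory=dict)


recipe = Recipe()
# "step_state" is a made-up key for state that changes every iteration.
previous = ParamsSketch(fp8=True, fp8_meta={"recipe": recipe, "step_state": 1})
current = ParamsSketch(fp8=True, fp8_meta={"recipe": recipe, "step_state": 2})

print(previous.fp8_meta["recipe"] == current.fp8_meta["recipe"])  # True
print(previous != current)  # True: the whole fp8_meta dict is compared
```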
Inside `class AttentionParams`, only `fp8` and `fp8_meta` are designated for FP8 workflows. Most importantly, only `AttentionParams.fp8_meta["recipe"]` is used by the filters in `get_attention_backend()` that select the best-fit kernel, whilst the other fields of `AttentionParams.fp8_meta` are not. https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py#L321-L977

Hence, if `AttentionParams.fp8_meta["recipe"]` stays the same across iterations while other fields of `AttentionParams.fp8_meta` change, `attention_params != _attention_backends["attention_params"]` always evaluates to `True` and triggers `get_attention_backend()`, leading to significant overhead.

Solution
The solution is quite straightforward: when evaluating `attention_params != _attention_backends["attention_params"]`, compare only the fields used in `get_attention_backend()` and disregard the unused ones.

For example, add `__eq__` to `class AttentionParams`: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py#L298