[Bug] attention_backend update throttle #1348

Open
Jianbing-D opened this issue Nov 27, 2024 · 0 comments

Hi,

Inside the DotProductAttention layer, get_attention_backend() is invoked to select the best-fit attention backend, i.e. which CUDA kernel implementation to use. See https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py#L8145-L8160

global _attention_backends, _use_flash_attn_3
if (
    _attention_backends["attention_params"] is None
    or attention_params != _attention_backends["attention_params"]
):
    _attention_backends["attention_params"] = attention_params
    _attention_backends["backend_selection_requires_update"] = True
if _attention_backends["backend_selection_requires_update"]:
    _use_flash_attn_3 = _flash_attn_3_is_installed
    (
        use_flash_attention,
        use_fused_attention,
        fused_attention_backend,
        use_unfused_attention,
        _,
    ) = get_attention_backend(attention_params)

Specifically, whether the attention backend needs to be reselected is indicated by a global flag, _attention_backends["backend_selection_requires_update"]. This flag is set to True whenever no parameters have been cached yet or attention_params != _attention_backends["attention_params"] evaluates to True.

However, FP8 training and inference can run into what I call an attention-backend update throttle: every iteration triggers an unnecessary backend selection that produces the same result as the previous one.

The root cause lies in class AttentionParams. https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py#L220-L298
attention_params != _attention_backends["attention_params"] compares all data fields of class AttentionParams. If any one of them differs, the comparison evaluates to True and get_attention_backend() is triggered.

Inside class AttentionParams, only fp8 and fp8_meta relate to FP8 workflows. Most importantly, get_attention_backend() reads only AttentionParams.fp8_meta["recipe"] when filtering for the best-fit kernel; the other entries of AttentionParams.fp8_meta are not used. https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py#L321-L977

Hence, if AttentionParams.fp8_meta["recipe"] stays the same across iterations while other entries of AttentionParams.fp8_meta change, attention_params != _attention_backends["attention_params"] evaluates to True on every iteration and get_attention_backend() is called each time, leading to significant overhead.

Solution

The solution is quite straightforward: when evaluating attention_params != _attention_backends["attention_params"], compare only the fields that get_attention_backend() actually uses and ignore the unused ones.

For example, add __eq__ to class AttentionParams: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py#L298

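from dataclasses import fields

class AttentionParams:
    # Existing dataclass fields are unchanged; fields() requires a dataclass.
    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        for field in fields(self):
            fname = field.name
            sf = getattr(self, fname)
            of = getattr(other, fname)
            if fname != "fp8_meta":
                # Ordinary fields must match exactly.
                if sf != of:
                    return False
            elif sf["recipe"] != of["recipe"]:
                # For fp8_meta, only the recipe affects backend selection.
                return False
        return True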
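
As a quick sanity check of this idea, the sketch below applies the same comparison to a hypothetical stand-in class (ParamsSketch, with made-up field values and a made-up "step" key), not to the real AttentionParams. One addition that is not part of the proposal above: it falls back to a plain comparison when fp8_meta is None on either side, on the assumption that fp8_meta can be unset when FP8 is disabled.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional


@dataclass
class ParamsSketch:
    """Hypothetical stand-in for AttentionParams with only three fields."""

    head_dim: int = 64
    fp8: bool = False
    fp8_meta: Optional[Dict[str, Any]] = None

    # dataclass keeps an __eq__ that is defined explicitly in the class body.
    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        for field in fields(self):
            sf = getattr(self, field.name)
            of = getattr(other, field.name)
            if field.name == "fp8_meta" and sf is not None and of is not None:
                # Only the recipe is relevant to backend selection.
                if sf.get("recipe") != of.get("recipe"):
                    return False
            elif sf != of:
                return False
        return True


a = ParamsSketch(fp8=True, fp8_meta={"recipe": "delayed", "step": 0})
b = ParamsSketch(fp8=True, fp8_meta={"recipe": "delayed", "step": 1})
c = ParamsSketch(fp8=True, fp8_meta={"recipe": "other", "step": 1})
d = ParamsSketch(fp8=True, fp8_meta=None)

print(a == b)  # True: only the ignored "step" entry differs
print(a != b)  # False: so the guard would not request a new selection
print(a == c)  # False: the recipe changed
print(a == d)  # False: one side has no fp8_meta at all
```

Since Python derives != from __eq__ when no __ne__ is defined, the existing guard would pick up the new behavior without further changes.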