From 8c4d83633b17231e4af4828bffbbb348619b0b14 Mon Sep 17 00:00:00 2001
From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Date: Mon, 16 Dec 2024 05:11:29 -0800
Subject: [PATCH] fix lint

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
---
 transformer_engine/pytorch/attention.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 49bf07aa11..704b27df4b 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -8166,14 +8166,14 @@ def forward(
             if "padding" in attn_mask_type:
                 actual_seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                 actual_seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
-            if (
-                _alibi_cache["_max_seqlen_q"] != max_seqlen_q
-                or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
-                or _alibi_cache["_bias_dtype"] != query_layer.dtype
-                or _alibi_cache["_bottom_right_diagonal"] != bottom_right_diagonal
-                or _alibi_cache["_actual_seqlens_q"] != actual_seqlens_q
-                or _alibi_cache["_actual_seqlens_kv"] != actual_seqlens_kv
-            ):
+            alibi_dict = {}
+            alibi_dict["_max_seqlen_q"] = max_seqlen_q
+            alibi_dict["_max_seqlen_kv"] = max_seqlen_kv
+            alibi_dict["_bias_dtype"] = query_layer.dtype
+            alibi_dict["_bottom_right_diagonal"] = bottom_right_diagonal
+            alibi_dict["_actual_seqlens_q"] = actual_seqlens_q
+            alibi_dict["_actual_seqlens_kv"] = actual_seqlens_kv
+            if any(y != _alibi_cache[x] for x, y in alibi_dict.items()):
                 _alibi_cache["_alibi_bias_require_update"] = True
 
         core_attention_bias_shape = None