diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 49bf07aa11..704b27df4b 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -8166,14 +8166,14 @@ def forward(
             if "padding" in attn_mask_type:
                 actual_seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
                 actual_seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
-            if (
-                _alibi_cache["_max_seqlen_q"] != max_seqlen_q
-                or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
-                or _alibi_cache["_bias_dtype"] != query_layer.dtype
-                or _alibi_cache["_bottom_right_diagonal"] != bottom_right_diagonal
-                or _alibi_cache["_actual_seqlens_q"] != actual_seqlens_q
-                or _alibi_cache["_actual_seqlens_kv"] != actual_seqlens_kv
-            ):
+            alibi_dict = {}
+            alibi_dict["_max_seqlen_q"] = max_seqlen_q
+            alibi_dict["_max_seqlen_kv"] = max_seqlen_kv
+            alibi_dict["_bias_dtype"] = query_layer.dtype
+            alibi_dict["_bottom_right_diagonal"] = bottom_right_diagonal
+            alibi_dict["_actual_seqlens_q"] = actual_seqlens_q
+            alibi_dict["_actual_seqlens_kv"] = actual_seqlens_kv
+            if any(y != _alibi_cache[x] for x, y in alibi_dict.items()):
                 _alibi_cache["_alibi_bias_require_update"] = True
 
         core_attention_bias_shape = None
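
For context, the refactor replaces a chained `or` comparison against `_alibi_cache` with a dict-driven staleness check. Below is a minimal standalone sketch of the same pattern, not the actual module: the cache key names are taken from the diff, while the `needs_alibi_update` helper and the plain-scalar values are illustrative assumptions. In the real code `_actual_seqlens_q`/`_actual_seqlens_kv` are tensors, where an element-wise `!=` would need something like `torch.equal` instead of a plain comparison.

```python
import torch

# Illustrative stand-in for the module-level cache in attention.py;
# the key names match the diff, everything else here is hypothetical.
_alibi_cache = {
    "_max_seqlen_q": 128,
    "_max_seqlen_kv": 128,
    "_bias_dtype": torch.float16,
    "_bottom_right_diagonal": True,
    "_alibi_bias_require_update": False,
}


def needs_alibi_update(max_seqlen_q, max_seqlen_kv, bias_dtype, bottom_right_diagonal):
    """Return True if any cached parameter differs from the current call."""
    current = {
        "_max_seqlen_q": max_seqlen_q,
        "_max_seqlen_kv": max_seqlen_kv,
        "_bias_dtype": bias_dtype,
        "_bottom_right_diagonal": bottom_right_diagonal,
    }
    # Same pattern as the diff: one generator expression instead of a chained `or`.
    return any(value != _alibi_cache[key] for key, value in current.items())


# Usage: a changed KV sequence length invalidates the cached bias.
print(needs_alibi_update(128, 256, torch.float16, True))   # True
print(needs_alibi_update(128, 128, torch.float16, True))   # False
```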