diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index ef6a96fe4d..49bf07aa11 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -1266,7 +1266,10 @@ def get_alibi(
         _alibi_cache["_bottom_right_diagonal"] = bottom_right_diagonal
         bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
         _alibi_cache["_bias_dtype"] = bias_dtype
-        _alibi_cache["_actual_seqlens_q"], _alibi_cache["_actual_seqlens_kv"] = actual_seqlens_q, actual_seqlens_kv
+        _alibi_cache["_actual_seqlens_q"], _alibi_cache["_actual_seqlens_kv"] = (
+            actual_seqlens_q,
+            actual_seqlens_kv,
+        )
         _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
         _alibi_cache["_alibi_bias_require_update"] = False