diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 62ffec2cd6..5ca525ce97 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -1075,7 +1075,7 @@ def get_swa_mask(
         attn_mask_type = "arbitrary"
         mask = mask_lower.logical_not()
         if attention_mask is not None:
-            mask = torch.logical_and(attention_mask, mask)
+            mask = torch.logical_or(attention_mask, mask)
     return attn_mask_type, mask
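
Why the fix is `logical_or`: in `get_swa_mask` the "arbitrary" mask follows the convention that `True` marks positions to mask out, as the surrounding context implies (`mask = mask_lower.logical_not()` inverts the allowed window). A query/key pair must then be masked if *either* the sliding-window mask or the user-supplied `attention_mask` excludes it, i.e. the union of the two masks. `logical_and` masked only the intersection, silently re-admitting padded tokens that fall inside the window (and windowed-out tokens that are not padding). Below is a minimal, self-contained sketch of that semantics, assuming the True-means-masked convention; the tensors and the symmetric band window are illustrative, not taken from the source:

```python
import torch

seq_len, window = 4, 2

# Sliding-window mask (symmetric band for brevity):
# True where a pair is ALLOWED, then inverted so True means MASKED,
# mirroring `mask = mask_lower.logical_not()` in the patched code.
idx = torch.arange(seq_len)
mask_lower = (idx.view(-1, 1) - idx.view(1, -1)).abs() < window
swa_mask = mask_lower.logical_not()

# Hypothetical padding mask: True where the key token is padding.
attention_mask = torch.tensor([False, False, False, True]).expand(seq_len, -1)

# Union of the two masks: masked if EITHER source excludes the pair.
combined_or = torch.logical_or(attention_mask, swa_mask)

# The buggy intersection: masked only if BOTH sources exclude the pair,
# so the padded last token is attended to wherever it lies in the window.
combined_and = torch.logical_and(attention_mask, swa_mask)

print(combined_or)   # last column fully True: padding always masked
print(combined_and)  # last column partly False: padding leaks through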