diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 5f8357a01b..07001be742 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -7446,7 +7446,7 @@ def __init__(
         ), "The number of attention heads must be divisible by the number of GQA groups!"

         self.rng_states_tracker = None
-        if sequence_parallel or get_rng_state_tracker is None:
+        if get_rng_state_tracker is None:
             attention_dropout_ctx = nullcontext
         else:
             self.rng_states_tracker = get_rng_state_tracker()