diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 28c1b45ffa..8159f20e90 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -7952,7 +7952,10 @@ def forward( assert ( key_layer.shape[-2] == self.num_gqa_groups_per_partition and value_layer.shape[-2] == self.num_gqa_groups_per_partition - ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" + ), ( + "Keys and values must have num_gqa_group =" + f" {self.num_gqa_groups_per_partition} heads!" + ) assert qkv_format in [ "sbhd", "bshd",