From 744a96f7dee37222b0bc9abd576196488ac23c3c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Dec 2024 00:41:21 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/linear.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index cfdf0fd472..5234fa59a6 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -927,15 +927,14 @@ def __init__( assert ub_name is not None, f"Comm+GEMM overlap layer '{ub_name}' is not initialized." self.ub_name = ub_name - assert not (self.ub_overlap_rs_fprop and self.ub_overlap_ag_fprop), ( - "Cannot enable AG+GEMM and GEMM+RS overlaps at the same time." - ) - assert not (self.ub_overlap_rs_dgrad and self.ub_bulk_dgrad), ( - "Cannot enable DGRAD+RS and bulk DGRAD overlaps at the same time." - ) assert not ( - self.ub_overlap_ag_dgrad - and (self.ub_overlap_rs_dgrad or self.ub_bulk_dgrad) + self.ub_overlap_rs_fprop and self.ub_overlap_ag_fprop + ), "Cannot enable AG+GEMM and GEMM+RS overlaps at the same time." + assert not ( + self.ub_overlap_rs_dgrad and self.ub_bulk_dgrad + ), "Cannot enable DGRAD+RS and bulk DGRAD overlaps at the same time." + assert not ( + self.ub_overlap_ag_dgrad and (self.ub_overlap_rs_dgrad or self.ub_bulk_dgrad) ), "Cannot enable AG+DGRAD and DGRAD+RS or bulk DGRAD overlaps at the same time." self.get_rng_state_tracker = get_rng_state_tracker