From 64a506ca01b67a4691ef2aa26b7a7ddaa3492a24 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Tue, 22 Oct 2024 20:44:07 -0700 Subject: [PATCH] check torch version inside functions Signed-off-by: Xin Yao --- transformer_engine/pytorch/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 935838ad3a..90a0edc2c0 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -313,5 +313,4 @@ def torch_get_autocast_gpu_dtype() -> torch.dtype: TORCH_MINOR = int(torch.__version__.split(".")[1]) if TORCH_MAJOR == 2 and TORCH_MINOR >= 4: return torch.get_autocast_dtype("cuda") - else: - return torch.get_autocast_gpu_dtype() + return torch.get_autocast_gpu_dtype()