Skip to content

Commit

Permalink
check torch version inside functions
Browse files Browse the repository at this point in the history
Signed-off-by: Xin Yao <[email protected]>
  • Loading branch information
yaox12 committed Oct 23, 2024
1 parent e70dfb1 commit 39a9fe6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 34 deletions.
34 changes: 11 additions & 23 deletions transformer_engine/pytorch/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,17 +245,16 @@ def in_fp8_activation_recompute_phase() -> bool:
return _FP8_ACTIVATION_RECOMPUTE_PHASE


TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:

def _get_active_autocast_contexts():
"""
Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast
state at the time of this function's execution.
"""
autocast_cached = torch.is_autocast_cache_enabled()
def _get_active_autocast_contexts():
"""
Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state
at the time of this function's execution.
"""
autocast_cached = torch.is_autocast_cache_enabled()

TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:
gpu_autocast_enabled = torch.is_autocast_enabled("cuda")
gpu_autocast_dtype = torch.get_autocast_dtype("cuda")
gpu_autocast_ctx = torch.amp.autocast(
Expand All @@ -273,18 +272,7 @@ def _get_active_autocast_contexts():
dtype=cpu_autocast_dtype,
cache_enabled=autocast_cached,
)

return gpu_autocast_ctx, cpu_autocast_ctx

else:

def _get_active_autocast_contexts():
"""
Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast
state at the time of this function's execution.
"""
autocast_cached = torch.is_autocast_cache_enabled()

else:
gpu_autocast_enabled = torch.is_autocast_enabled()
gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
gpu_autocast_ctx = torch.cuda.amp.autocast(
Expand All @@ -297,7 +285,7 @@ def _get_active_autocast_contexts():
cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
)

return gpu_autocast_ctx, cpu_autocast_ctx
return gpu_autocast_ctx, cpu_autocast_ctx


class _CheckpointFunction(torch.autograd.Function):
Expand Down
17 changes: 6 additions & 11 deletions transformer_engine/pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,16 +307,11 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool:
return device1 == device2


TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:

def torch_get_autocast_gpu_dtype() -> torch.dtype:
    """Get the PyTorch autocast GPU (CUDA) dtype.

    Dispatches on the installed PyTorch version: 2.4+ replaced the
    per-device getter ``torch.get_autocast_gpu_dtype()`` with the unified
    ``torch.get_autocast_dtype(device_type)`` API.

    NOTE(review): the version check is deliberately performed at call time
    rather than at import time (per commit "check torch version inside
    functions") — keep it inside the function body.

    Returns:
        torch.dtype: the dtype currently configured for CUDA autocast
        (``torch.float16`` by default).
    """
    # Parse "major.minor" once; tolerates suffixes like "2.4.0+cu121"
    # because only the first two dot-separated fields are used.
    torch_major, torch_minor = (int(v) for v in torch.__version__.split(".")[:2])
    if torch_major == 2 and torch_minor >= 4:
        return torch.get_autocast_dtype("cuda")  # unified API, PyTorch >= 2.4
    return torch.get_autocast_gpu_dtype()  # legacy API, PyTorch < 2.4

0 comments on commit 39a9fe6

Please sign in to comment.