Skip to content

Commit

Permalink
fix kwargs for torch.amp.autocast
Browse files Browse the repository at this point in the history
Signed-off-by: Xin Yao <[email protected]>
  • Loading branch information
yaox12 committed Oct 21, 2024
1 parent 0987833 commit 35d7a31
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions transformer_engine/pytorch/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,13 +259,19 @@ def _get_active_autocast_contexts():
gpu_autocast_enabled = torch.is_autocast_enabled("cuda")
gpu_autocast_dtype = torch.get_autocast_dtype("cuda")
gpu_autocast_ctx = torch.amp.autocast(
"cuda", gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
"cuda",
enabled=gpu_autocast_enabled,
dtype=gpu_autocast_dtype,
cache_enabled=autocast_cached,
)

cpu_autocast_enabled = torch.is_autocast_enabled("cpu")
cpu_autocast_dtype = torch.get_autocast_dtype("cpu")
cpu_autocast_ctx = torch.amp.autocast(
"cpu", cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
"cpu",
enabled=cpu_autocast_enabled,
dtype=cpu_autocast_dtype,
cache_enabled=autocast_cached,
)

return gpu_autocast_ctx, cpu_autocast_ctx
Expand Down

0 comments on commit 35d7a31

Please sign in to comment.