Skip to content

Commit 1420ea6

Browse files
committed
add docstring
1 parent 3fe0c0f commit 1420ea6

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

transformer_engine/pytorch/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,9 +312,11 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool:
312312
if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:
313313

314314
def torch_get_autocast_gpu_dtype() -> torch.dtype:
315+
"""Get PyTorch autocast GPU dtype."""
315316
return torch.get_autocast_dtype("cuda")
316317

317318
else:
318319

319320
def torch_get_autocast_gpu_dtype() -> torch.dtype:
321+
"""Get PyTorch autocast GPU dtype."""
320322
return torch.get_autocast_gpu_dtype()

0 commit comments

Comments
 (0)