Commit 3fe0c0f

fix autocast related deprecation warning
Signed-off-by: Xin Yao <[email protected]>
1 parent 3ea7dd3

File tree: 8 files changed (+82 -33 lines)
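
For context on the warning being fixed: starting with PyTorch 2.4, the CUDA-specific autocast entry points (`torch.cuda.amp.autocast`, `torch.get_autocast_gpu_dtype`, ...) emit deprecation warnings in favor of the device-agnostic `torch.amp` API, which is what this commit switches to behind version checks. A minimal sketch of the old and new spellings (illustrative only, assumes a CUDA device; not part of the diff):

    import torch

    x = torch.randn(8, 8, device="cuda")

    # Pre-2.4 spelling: emits a FutureWarning on PyTorch >= 2.4
    with torch.cuda.amp.autocast(enabled=True):
        y = x @ x

    # Device-agnostic replacement used throughout this commit
    with torch.amp.autocast(device_type="cuda", enabled=True):
        y = x @ x
        amp_dtype = torch.get_autocast_dtype("cuda")  # replaces torch.get_autocast_gpu_dtype()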

tests/pytorch/test_fused_optimizer.py

Lines changed: 7 additions & 6 deletions

@@ -14,6 +14,7 @@
 from transformer_engine.pytorch import fp8_model_init
 from transformer_engine.pytorch.utils import is_bf16_compatible
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.jit import gpu_autocast_ctx

 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

@@ -333,7 +334,7 @@ def test_grad_scaler(self):
         gt_ = gt.clone()

         # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model(x)
             loss = ((gt - y) ** 2).mean()

@@ -342,7 +343,7 @@ def test_grad_scaler(self):
         scaler.update()

         # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model_(x)
             loss_ = ((gt_ - y) ** 2).mean()

@@ -384,7 +385,7 @@ def test_grad_scaler_capturable(self):
         gt_ = gt.clone()

         # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model(x)
             loss = ((gt - y) ** 2).mean()

@@ -393,7 +394,7 @@ def test_grad_scaler_capturable(self):
         scaler.update()

         # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model_(x)
             loss_ = ((gt_ - y) ** 2).mean()

@@ -442,7 +443,7 @@ def test_grad_scaler_capturable_master(self):
         gt_ = gt.clone()

         # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model(x)
             loss = ((gt - y) ** 2).mean()

@@ -451,7 +452,7 @@ def test_grad_scaler_capturable_master(self):
         scaler.update()

         # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model_(x)
             loss_ = ((gt_ - y) ** 2).mean()
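
The tests above wrap both the reference and DUT forward passes in the new `gpu_autocast_ctx` together with a gradient scaler. A minimal sketch of that training-step pattern with a placeholder model and data, using the device-agnostic GradScaler spelling (the scaler itself is untouched by this commit; names here are illustrative):

    import torch

    model = torch.nn.Linear(16, 16).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=1e-3)
    scaler = torch.amp.GradScaler("cuda")  # torch.cuda.amp.GradScaler on older PyTorch

    x = torch.randn(4, 16, device="cuda")
    gt = torch.randn(4, 16, device="cuda")

    with torch.amp.autocast(device_type="cuda", enabled=True):
        y = model(x)
        loss = ((gt - y) ** 2).mean()

    scaler.scale(loss).backward()  # scale the loss before backward
    scaler.step(opt)               # unscale gradients and step the optimizer
    scaler.update()                # adjust the scale factor for the next iteration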

transformer_engine/pytorch/distributed.py

Lines changed: 44 additions & 17 deletions

@@ -245,26 +245,53 @@ def in_fp8_activation_recompute_phase() -> bool:
     return _FP8_ACTIVATION_RECOMPUTE_PHASE


-def _get_active_autocast_contexts():
-    """
-    Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state
-    at the time of this function's execution.
-    """
-    autocast_cached = torch.is_autocast_cache_enabled()
+TORCH_MAJOR = int(torch.__version__.split(".")[0])
+TORCH_MINOR = int(torch.__version__.split(".")[1])
+if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:

-    gpu_autocast_enabled = torch.is_autocast_enabled()
-    gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
-    gpu_autocast_ctx = torch.cuda.amp.autocast(
-        gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
-    )
+    def _get_active_autocast_contexts():
+        """
+        Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state
+        at the time of this function's execution.
+        """
+        autocast_cached = torch.is_autocast_cache_enabled()

-    cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
-    cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
-    cpu_autocast_ctx = torch.cpu.amp.autocast(
-        cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
-    )
+        gpu_autocast_enabled = torch.is_autocast_enabled("cuda")
+        gpu_autocast_dtype = torch.get_autocast_dtype("cuda")
+        gpu_autocast_ctx = torch.amp.autocast(
+            "cuda", gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
+        )
+
+        cpu_autocast_enabled = torch.is_autocast_enabled("cpu")
+        cpu_autocast_dtype = torch.get_autocast_dtype("cpu")
+        cpu_autocast_ctx = torch.amp.autocast(
+            "cpu", cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
+        )
+
+        return gpu_autocast_ctx, cpu_autocast_ctx
+
+else:
+
+    def _get_active_autocast_contexts():
+        """
+        Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state
+        at the time of this function's execution.
+        """
+        autocast_cached = torch.is_autocast_cache_enabled()
+
+        gpu_autocast_enabled = torch.is_autocast_enabled()
+        gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
+        gpu_autocast_ctx = torch.cuda.amp.autocast(
+            gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
+        )
+
+        cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
+        cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
+        cpu_autocast_ctx = torch.cpu.amp.autocast(
+            cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
+        )

-    return gpu_autocast_ctx, cpu_autocast_ctx
+        return gpu_autocast_ctx, cpu_autocast_ctx


 class _CheckpointFunction(torch.autograd.Function):
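
For context, the contexts returned by `_get_active_autocast_contexts` are captured when a checkpoint is recorded and re-entered when the forward pass is recomputed, so recomputation sees the same autocast state as the original run. A minimal sketch of that usage (the surrounding checkpoint machinery is omitted; `run_function` and `args` are placeholder names):

    # Capture the ambient autocast state at checkpoint time ...
    gpu_autocast_ctx, cpu_autocast_ctx = _get_active_autocast_contexts()

    # ... and re-enter equivalent contexts when replaying the forward pass,
    # so activations are regenerated under the same autocast settings.
    with gpu_autocast_ctx, cpu_autocast_ctx:
        outputs = run_function(*args)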

transformer_engine/pytorch/jit.py

Lines changed: 10 additions & 4 deletions

@@ -5,6 +5,7 @@
 """NVFuser functions and JIT utilities"""
 import os
 from typing import Callable, Optional, Tuple
+from functools import partial

 import torch

@@ -33,6 +34,11 @@
 # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
 no_torch_dynamo = lambda recursive=True: torch._dynamo.disable

+if torch.__version__ >= "2.4":
+    gpu_autocast_ctx = partial(torch.amp.autocast, device_type="cuda")
+else:
+    gpu_autocast_ctx = torch.cuda.amp.autocast
+

 def set_jit_fusion_options() -> None:
     """Set PyTorch JIT layer fusion options."""

@@ -110,7 +116,7 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:

 def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
     """Disable native AMP for bias_gelu_fused_"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
         if bias is not None and bias.numel() != 0:
             return bias_gelu_fused_(inp, bias)
         return gelu_fused_(inp)

@@ -120,7 +126,7 @@ def bgrad_dgelu_fused(
     grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
 ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
     """Disable native AMP for `bgrad_dgelu_fused_`"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
         if bias is not None and bias.numel() != 0:
             return bgrad_dgelu_fused_(grad_output, inp, bias)
         return None, dgelu_fused_(grad_output, inp)

@@ -161,7 +167,7 @@ def bias_dropout_add_fused_train(
 ) -> torch.Tensor:
     """Disable native AMP and enable grad for BDA"""
     with torch.enable_grad():
-        with torch.cuda.amp.autocast(enabled=False):
+        with gpu_autocast_ctx(enabled=False):
             return bias_dropout_add_fused_train_(x, bias, residual, prob)

@@ -177,7 +183,7 @@ def bias_dropout_add_fused_inference(
     x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
 ) -> torch.Tensor:
     """Disable native AMP for BDA"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
         return bias_dropout_add_fused_inference_(x, bias, residual, prob)
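
The `functools.partial` in the hunk above is what keeps the two PyTorch branches call-compatible: on 2.4+ the required `device_type` argument is bound once, so existing call sites keep passing only `enabled=`. A small standalone sketch of the resulting calling convention (assumes a CUDA build of PyTorch; not copied from the file):

    from functools import partial
    import torch

    gpu_autocast_ctx = (
        partial(torch.amp.autocast, device_type="cuda")
        if torch.__version__ >= "2.4"
        else torch.cuda.amp.autocast
    )

    # Fused kernels run in a fixed precision, so AMP is disabled locally;
    # the same one-argument call works on both branches.
    with gpu_autocast_ctx(enabled=False):
        pass  # e.g. invoke a TorchScript-fused bias + GeLU kernel here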

transformer_engine/pytorch/module/base.py

Lines changed: 2 additions & 1 deletion

@@ -38,6 +38,7 @@
 )
 from ..constants import dist_group_type
 from ..float8_tensor import Float8Tensor
+from ..utils import torch_get_autocast_gpu_dtype

 __all__ = ["initialize_ub", "destroy_ub"]

@@ -619,7 +620,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None:
         """Get activation data type for AMP."""
         # Native AMP (`torch.autocast`) gets highest priority
         if torch.is_autocast_enabled():
-            self.activation_dtype = torch.get_autocast_gpu_dtype()
+            self.activation_dtype = torch_get_autocast_gpu_dtype()
             return

         # All checks after this have already been performed once, thus skip

transformer_engine/pytorch/module/layernorm.py

Lines changed: 2 additions & 2 deletions

@@ -16,7 +16,7 @@
     layernorm_fwd_inf,
 )
 from ..jit import no_torch_dynamo
-from ..utils import cast_if_needed
+from ..utils import cast_if_needed, torch_get_autocast_gpu_dtype

 __all__ = ["LayerNorm"]

@@ -193,7 +193,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor:
         # Note: This will soon be deprecated with
         # https://github.com/NVIDIA/TransformerEngine/pull/1033
         if torch.is_autocast_enabled():
-            self.activation_dtype = torch.get_autocast_gpu_dtype()
+            self.activation_dtype = torch_get_autocast_gpu_dtype()
         elif self.activation_dtype != inp.dtype:
             dtype = inp.dtype
             for name, param in self.named_parameters():

transformer_engine/pytorch/module/rmsnorm.py

Lines changed: 2 additions & 2 deletions

@@ -13,7 +13,7 @@

 from .. import cpp_extensions as tex
 from ..jit import no_torch_dynamo
-from ..utils import cast_if_needed
+from ..utils import cast_if_needed, torch_get_autocast_gpu_dtype


 __all__ = ["RMSNorm"]

@@ -190,7 +190,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor:
         # Note: This will soon be deprecated with
         # https://github.com/NVIDIA/TransformerEngine/pull/1033
         if torch.is_autocast_enabled():
-            self.activation_dtype = torch.get_autocast_gpu_dtype()
+            self.activation_dtype = torch_get_autocast_gpu_dtype()
         elif self.activation_dtype != inp.dtype:
             dtype = inp.dtype
             for name, param in self.named_parameters():

transformer_engine/pytorch/transformer.py

Lines changed: 2 additions & 1 deletion

@@ -26,6 +26,7 @@
 from transformer_engine.pytorch.utils import (
     cast_if_needed,
     get_default_init_method,
+    torch_get_autocast_gpu_dtype,
 )
 from transformer_engine.pytorch.constants import (
     AttnMaskTypes,

@@ -677,7 +678,7 @@ def forward(

         # For AMP
         if torch.is_autocast_enabled():
-            hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype())
+            hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())

         # Self attention.
         self_attention_outputs = self.self_attention(

transformer_engine/pytorch/utils.py

Lines changed: 13 additions & 0 deletions

@@ -305,3 +305,16 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool:
         index2 = torch.cuda.current_device()
         return index1 == index2
     return device1 == device2
+
+
+TORCH_MAJOR = int(torch.__version__.split(".")[0])
+TORCH_MINOR = int(torch.__version__.split(".")[1])
+if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:
+
+    def torch_get_autocast_gpu_dtype() -> torch.dtype:
+        return torch.get_autocast_dtype("cuda")
+
+else:
+
+    def torch_get_autocast_gpu_dtype() -> torch.dtype:
+        return torch.get_autocast_gpu_dtype()
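
A quick illustration of the new `torch_get_autocast_gpu_dtype` shim in use; the surrounding autocast block and printed dtype are illustrative, and the branch behavior noted in the comment follows the two definitions above:

    import torch
    from transformer_engine.pytorch.utils import torch_get_autocast_gpu_dtype

    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        if torch.is_autocast_enabled():
            # Resolves to torch.get_autocast_dtype("cuda") on PyTorch >= 2.4
            # and to torch.get_autocast_gpu_dtype() on older versions.
            print(torch_get_autocast_gpu_dtype())  # torch.bfloat16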
