@@ -245,26 +245,53 @@ def in_fp8_activation_recompute_phase() -> bool:
     return _FP8_ACTIVATION_RECOMPUTE_PHASE
 
 
-def _get_active_autocast_contexts():
-    """
-    Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state
-    at the time of this function's execution.
-    """
-    autocast_cached = torch.is_autocast_cache_enabled()
+TORCH_MAJOR = int(torch.__version__.split(".")[0])
+TORCH_MINOR = int(torch.__version__.split(".")[1])
+if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:
 
-    gpu_autocast_enabled = torch.is_autocast_enabled()
-    gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
-    gpu_autocast_ctx = torch.cuda.amp.autocast(
-        gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
-    )
+    def _get_active_autocast_contexts():
+        """
+        Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state
+        at the time of this function's execution.
+        """
+        autocast_cached = torch.is_autocast_cache_enabled()
 
-    cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
-    cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
-    cpu_autocast_ctx = torch.cpu.amp.autocast(
-        cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
-    )
+        gpu_autocast_enabled = torch.is_autocast_enabled("cuda")
+        gpu_autocast_dtype = torch.get_autocast_dtype("cuda")
+        gpu_autocast_ctx = torch.amp.autocast(
+            "cuda", gpu_autocast_dtype, gpu_autocast_enabled, autocast_cached
+        )
+
+        cpu_autocast_enabled = torch.is_autocast_enabled("cpu")
+        cpu_autocast_dtype = torch.get_autocast_dtype("cpu")
+        cpu_autocast_ctx = torch.amp.autocast(
+            "cpu", cpu_autocast_dtype, cpu_autocast_enabled, autocast_cached
+        )
+
+        return gpu_autocast_ctx, cpu_autocast_ctx
+
+else:
+
+    def _get_active_autocast_contexts():
+        """
+        Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state
+        at the time of this function's execution.
+        """
+        autocast_cached = torch.is_autocast_cache_enabled()
+
+        gpu_autocast_enabled = torch.is_autocast_enabled()
+        gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
+        gpu_autocast_ctx = torch.cuda.amp.autocast(
+            gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
+        )
+
+        cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
+        cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
+        cpu_autocast_ctx = torch.cpu.amp.autocast(
+            cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
+        )
 
-    return gpu_autocast_ctx, cpu_autocast_ctx
+        return gpu_autocast_ctx, cpu_autocast_ctx
 
 
 class _CheckpointFunction(torch.autograd.Function):
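
Note on the version gate: torch 2.4 deprecated the device-specific torch.cuda.amp.autocast / torch.cpu.amp.autocast context managers and the torch.get_autocast_gpu_dtype() / torch.get_autocast_cpu_dtype() queries in favor of device-keyed equivalents. The two generations also order their arguments differently (torch.cuda.amp.autocast(enabled, dtype, cache_enabled) vs. torch.amp.autocast(device_type, dtype, enabled, cache_enabled)), hence the positional order in the new branch. Below is a minimal capture-and-replay sketch of what the helper does, written against plain torch on the CPU path so it runs without a GPU; it requires torch >= 2.4, and the matmul, shapes, and the saved_ctx name are illustrative, not part of this change.

import torch

# Inside an autocast region, snapshot the active state the same way the
# patched helper does on torch >= 2.4 ...
with torch.amp.autocast("cpu", dtype=torch.bfloat16):
    enabled = torch.is_autocast_enabled("cpu")   # True inside this region
    dtype = torch.get_autocast_dtype("cpu")      # torch.bfloat16 here
    cached = torch.is_autocast_cache_enabled()
    saved_ctx = torch.amp.autocast("cpu", dtype, enabled, cached)

# ... then replay it later, e.g. around activation recomputation, so the
# recomputed forward sees the same autocast behaviour as the original pass.
with saved_ctx:
    y = torch.randn(4, 4) @ torch.randn(4, 4)
    assert y.dtype == torch.bfloat16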