import torch.nn as nn
from torch.nn import Parameter

+ from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker


+ # Only run FP8 tests on H100.
+ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+
+
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
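Note on the new guard: `FP8GlobalStateManager.is_fp8_available()` returns a `(bool, str)` pair, which lets every FP8-parametrized test below skip cleanly on hardware without FP8 support instead of failing. A minimal sketch of the gating pattern this commit introduces (the test name is hypothetical):

    fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

    @pytest.mark.parametrize("fp8", all_boolean)
    def test_example(fp8):
        if fp8 and not fp8_available:
            pytest.skip(reason_for_no_fp8)  # reason string supplied by Transformer Engine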
@@ -90,20 +95,11 @@ def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float)
def reset_rng_states() -> None:
-     # revert back to initial RNG state.
+     """revert back to initial RNG state."""
    torch.set_rng_state(_cpu_rng_state)
    _set_cuda_rng_state(_cuda_rng_state)


- _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
- _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
-
-
- def get_dummy_cuda_rng_tracker():
-     """Get cuda rng tracker."""
-     return _DUMMY_CUDA_RNG_STATE_TRACKER
-
-
class TorchScaledMaskedSoftmax(nn.Module):
    def __init__(self) -> None:
        super().__init__()
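`reset_rng_states()` rewinds to the module-level snapshots `_cpu_rng_state` and `_cuda_rng_state`. Their capture site is outside this diff; presumably they are taken once at import time, right after seeding, along these lines:

    # Assumed capture site (not shown in this diff): snapshot the freshly
    # seeded generators so every test can rewind to the same starting point.
    _cpu_rng_state = torch.get_rng_state()
    _cuda_rng_state = torch.cuda.get_rng_state()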
@@ -343,41 +339,21 @@ def forward(
        return x


- def _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False):
+ def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False):
    reset_rng_states()
-
-     te_inp_hidden_states = torch.randn(
-         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
-     ).cuda()
-     te_inp_hidden_states.retain_grad()
-     te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
-
-     te_out = block(
-         te_inp_hidden_states,
-         attention_mask=te_inp_attn_mask,
-         checkpoint_core_attention=recompute,
-     )
-     loss = te_out.sum()
-     loss.backward()
-     torch.cuda.synchronize()
-
-     outputs = [te_out, te_inp_hidden_states.grad]
-     for p in block.parameters():
-         if p.requires_grad:
-             outputs.append(p.grad)
-     return outputs
-
-
- @pytest.mark.parametrize("dtype", param_types)
- @pytest.mark.parametrize("bs", batch_sizes)
- @pytest.mark.parametrize("model", model_configs.keys())
- def test_gpt_selective_activation_recompute(dtype, bs, model):
-     config = model_configs[model]
+     FP8GlobalStateManager.reset()

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

+     _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
+     _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
+
+     def get_dummy_cuda_rng_tracker():
+         """Get cuda rng tracker."""
+         return _DUMMY_CUDA_RNG_STATE_TRACKER
+
    block = (
        TransformerLayer(
            config.hidden_size,
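Two pieces of setup move from module scope into the helper: `FP8GlobalStateManager.reset()` returns Transformer Engine's global FP8 bookkeeping to a clean slate, and the dummy RNG tracker is rebuilt on every call, so the baseline and recompute runs start from identical quantization and dropout state. A sketch of how a tracked RNG state is consumed (the `fork` usage follows the Megatron-style tracker API and is an assumption here):

    tracker = CudaRNGStatesTracker()
    tracker.add("model-parallel-rng", seed)
    with tracker.fork("model-parallel-rng"):
        # Draws inside the fork come from the tracked state, so a
        # checkpointed recompute that restores it replays the same values.
        mask = torch.rand(4, device="cuda")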
@@ -395,38 +371,19 @@ def test_gpt_selective_activation_recompute(dtype, bs, model):
            params_dtype=dtype,
        )
        .cuda()
-         .eval()
    )

-     outputs = _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False)
-     outputs_recompute = _test_e2e_selective_recompute(block, bs, dtype, config, recompute=True)
-     assert_all_equal(outputs, outputs_recompute)
-
-
- def _test_e2e_full_recompute(block, bs, dtype, config, recompute=False):
-     reset_rng_states()
-
    te_inp_hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)

-     if recompute:
-         te_out = te_checkpoint(
-             block,
-             False,  # distribute_saved_activations
-             get_dummy_cuda_rng_tracker,
-             None,  # tp_group
-             te_inp_hidden_states,
-             attention_mask=te_inp_attn_mask,
-             checkpoint_core_attention=False,
-         )
-     else:
+     with fp8_autocast(enabled=fp8):
        te_out = block(
            te_inp_hidden_states,
            attention_mask=te_inp_attn_mask,
-             checkpoint_core_attention=False,
+             checkpoint_core_attention=recompute,
        )
    loss = te_out.sum()
    loss.backward()
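With `fp8_autocast(enabled=fp8)` wrapping a single code path, `checkpoint_core_attention=recompute` becomes the only difference between the two runs, which is what lets the test demand bit-exact equality rather than tolerance-based closeness. A sketch of the asserted invariant, assuming `assert_all_equal` (defined earlier in this file) wraps `torch.equal`:

    for t1, t2 in zip(outputs, outputs_recompute):
        assert torch.equal(t1, t2)  # recompute must be bit-exact, not merely allclose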
@@ -442,13 +399,33 @@ def _test_e2e_full_recompute(block, bs, dtype, config, recompute=False):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
- def test_gpt_full_activation_recompute(dtype, bs, model):
+ @pytest.mark.parametrize("fp8", all_boolean)
+ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8):
+     if fp8 and not fp8_available:
+         pytest.skip(reason_for_no_fp8)
+
    config = model_configs[model]

+     outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False)
+     outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=True)
+     assert_all_equal(outputs, outputs_recompute)
+
+
+ def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False):
+     reset_rng_states()
+     FP8GlobalStateManager.reset()
+
    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

+     _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
+     _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
+
+     def get_dummy_cuda_rng_tracker():
+         """Get cuda rng tracker."""
+         return _DUMMY_CUDA_RNG_STATE_TRACKER
+
    block = (
        TransformerLayer(
            config.hidden_size,
@@ -466,11 +443,54 @@ def test_gpt_full_activation_recompute(dtype, bs, model):
            params_dtype=dtype,
        )
        .cuda()
-         .eval()
    )

-     outputs = _test_e2e_full_recompute(block, bs, dtype, config, recompute=False)
-     outputs_recompute = _test_e2e_full_recompute(block, bs, dtype, config, recompute=True)
+     te_inp_hidden_states = torch.randn(
+         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
+     ).cuda()
+     te_inp_hidden_states.retain_grad()
+     te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
+
+     with fp8_autocast(enabled=fp8):
+         if recompute:
+             te_out = te_checkpoint(
+                 block,
+                 False,  # distribute_saved_activations
+                 get_dummy_cuda_rng_tracker,
+                 None,  # tp_group
+                 te_inp_hidden_states,
+                 attention_mask=te_inp_attn_mask,
+                 checkpoint_core_attention=False,
+             )
+         else:
+             te_out = block(
+                 te_inp_hidden_states,
+                 attention_mask=te_inp_attn_mask,
+                 checkpoint_core_attention=False,
+             )
+     loss = te_out.sum()
+     loss.backward()
+     torch.cuda.synchronize()
+
+     outputs = [te_out, te_inp_hidden_states.grad]
+     for p in block.parameters():
+         if p.requires_grad:
+             outputs.append(p.grad)
+     return outputs
+
+
+ @pytest.mark.parametrize("dtype", param_types)
+ @pytest.mark.parametrize("bs", batch_sizes)
+ @pytest.mark.parametrize("model", model_configs.keys())
+ @pytest.mark.parametrize("fp8", all_boolean)
+ def test_gpt_full_activation_recompute(dtype, bs, model, fp8):
+     if fp8 and not fp8_available:
+         pytest.skip(reason_for_no_fp8)
+
+     config = model_configs[model]
+
+     outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False)
+     outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=True)
    assert_all_equal(outputs, outputs_recompute)
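The full-recompute path passes `block`, `distribute_saved_activations`, the RNG-tracker getter, and `tp_group` positionally to `te_checkpoint` before the forward arguments, as the hunk above shows. Outside these tests, `fp8_autocast` is typically paired with an explicit recipe; a minimal sketch under the assumption that a delayed-scaling recipe with illustrative settings is acceptable:

    from transformer_engine.common.recipe import DelayedScaling

    # margin/interval values are illustrative, not taken from this commit
    with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(margin=0, interval=1)):
        out = block(inp, attention_mask=mask)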
@@ -565,8 +585,8 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
def test_gpt_checkpointing(dtype, bs, model):
    config = model_configs[model]
    outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
-     outputs_recompute = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
-     assert_all_equal(outputs, outputs_recompute)
+     outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
+     assert_all_equal(outputs, outputs_checkpoint)


def _test_e2e_gpt_accuracy(block, bs, dtype, config):