
Commit da2f240

ksivaman authored and cyanguwa committed
[PyTorch] Fixes and tests for FP8 + activation recompute (NVIDIA#487)
* initial test fix
* Drop eval for selective checkpointing tests
* Remove redundant recompute for FA
* CI fix; Decouple fused attention and numerics tests

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
1 parent: b682a04 · commit: da2f240

File tree

5 files changed, +154 -90 lines changed

tests/pytorch/test_fused_attn.py

Lines changed: 33 additions & 3 deletions

@@ -25,8 +25,6 @@
     QKVLayout,
     fused_attn_bwd,
     fused_attn_fwd,
-    fused_attn_bwd_qkvpacked,
-    fused_attn_fwd_qkvpacked,
 )
 import transformer_engine.pytorch.fp8 as fp8
 from transformer_engine.pytorch.module.base import (
@@ -38,20 +36,38 @@
     init_method_normal,
     scaled_init_method_normal,
 )
+from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker
 import transformer_engine_extensions as tex
 
-from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states
+
+# Only run FP8 tests on H100.
 fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
 _flash_attn_version = packaging.version.Version(version("flash-attn"))
 _flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
 
+
+seed = 1234
+torch.manual_seed(seed)
+torch.cuda.manual_seed(seed)
+# Record initial RNG state from script run.
+_cpu_rng_state = torch.get_rng_state()
+_cuda_rng_state = torch.cuda.get_rng_state()
+
+
 def _get_cudnn_version():
     cudnn_version_encoded = ext.get_cudnn_version()
     cudnn_major = cudnn_version_encoded // 1000
     cudnn_minor = (cudnn_version_encoded - cudnn_major * 1000) // 100
     cudnn_patch = cudnn_version_encoded - 1000 * cudnn_major - 100 * cudnn_minor
     return [cudnn_major, cudnn_minor, cudnn_patch]
 
+
+def reset_rng_states() -> None:
+    """revert back to initial RNG state."""
+    torch.set_rng_state(_cpu_rng_state)
+    _set_cuda_rng_state(_cuda_rng_state)
+
+
 _cudnn_version = _get_cudnn_version()
 
 
@@ -212,6 +228,13 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type)
     else:
         bias = None
 
+    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
+    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
+
+    def get_dummy_cuda_rng_tracker():
+        """Get cuda rng tracker."""
+        return _DUMMY_CUDA_RNG_STATE_TRACKER
+
     block = (
         DotProductAttention(
             config.num_attention_heads,
@@ -774,6 +797,13 @@ def _run_dpa_fp8_ref(dtype, bs, config, backend):
     cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
     op_grad = torch.load('op_grad.pt').cuda().view(bs, config.seq_len, -1).transpose(0,1)
 
+    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
+    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
+
+    def get_dummy_cuda_rng_tracker():
+        """Get cuda rng tracker."""
+        return _DUMMY_CUDA_RNG_STATE_TRACKER
+
     block = (
         DotProductAttention(
             config.num_attention_heads,
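
With the import from test_numerics gone, each test helper in test_fused_attn.py now builds its own CUDA RNG tracker, which is what decouples the fused-attention tests from the numerics tests. Below is a minimal standalone sketch of that pattern, not code from the commit: it assumes the tracker is handed to the attention module through a get_rng_state_tracker callable as in the full test file, and the fork()/bernoulli lines at the end are purely illustrative, mirroring the Megatron-style tracker API.

import torch
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker

seed = 1234
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)  # register a named CUDA RNG state

def get_dummy_cuda_rng_tracker():
    """Return the test-local tracker so dropout inside the attention block
    draws from a reproducible RNG state when activations are recomputed."""
    return _DUMMY_CUDA_RNG_STATE_TRACKER

# Dropout executed under the tracked state is repeatable across a recompute:
with get_dummy_cuda_rng_tracker().fork("model-parallel-rng"):
    keep_mask = torch.bernoulli(torch.full((4, 4), 0.9, device="cuda"))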

tests/pytorch/test_numerics.py

Lines changed: 86 additions & 66 deletions

@@ -12,6 +12,7 @@
 import torch.nn as nn
 from torch.nn import Parameter
 
+from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
 from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,
@@ -25,6 +26,10 @@
 from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker
 
 
+# Only run FP8 tests on H100.
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+
+
 seed = 1234
 torch.manual_seed(seed)
 torch.cuda.manual_seed(seed)
@@ -90,20 +95,11 @@ def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float)
 
 
 def reset_rng_states() -> None:
-    # revert back to initial RNG state.
+    """revert back to initial RNG state."""
     torch.set_rng_state(_cpu_rng_state)
     _set_cuda_rng_state(_cuda_rng_state)
 
 
-_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
-_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
-
-
-def get_dummy_cuda_rng_tracker():
-    """Get cuda rng tracker."""
-    return _DUMMY_CUDA_RNG_STATE_TRACKER
-
-
 class TorchScaledMaskedSoftmax(nn.Module):
     def __init__(self) -> None:
         super().__init__()
@@ -343,41 +339,21 @@ def forward(
     return x
 
 
-def _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False):
+def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False):
     reset_rng_states()
-
-    te_inp_hidden_states = torch.randn(
-        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
-    ).cuda()
-    te_inp_hidden_states.retain_grad()
-    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
-
-    te_out = block(
-        te_inp_hidden_states,
-        attention_mask=te_inp_attn_mask,
-        checkpoint_core_attention=recompute,
-    )
-    loss = te_out.sum()
-    loss.backward()
-    torch.cuda.synchronize()
-
-    outputs = [te_out, te_inp_hidden_states.grad]
-    for p in block.parameters():
-        if p.requires_grad:
-            outputs.append(p.grad)
-    return outputs
-
-
-@pytest.mark.parametrize("dtype", param_types)
-@pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs.keys())
-def test_gpt_selective_activation_recompute(dtype, bs, model):
-    config = model_configs[model]
+    FP8GlobalStateManager.reset()
 
     sigma = 0.023
     init_method = init_method_normal(sigma)
     output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
 
+    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
+    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
+
+    def get_dummy_cuda_rng_tracker():
+        """Get cuda rng tracker."""
+        return _DUMMY_CUDA_RNG_STATE_TRACKER
+
     block = (
         TransformerLayer(
             config.hidden_size,
@@ -395,38 +371,19 @@ def test_gpt_selective_activation_recompute(dtype, bs, model):
            params_dtype=dtype,
        )
        .cuda()
-       .eval()
     )
 
-    outputs = _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False)
-    outputs_recompute = _test_e2e_selective_recompute(block, bs, dtype, config, recompute=True)
-    assert_all_equal(outputs, outputs_recompute)
-
-
-def _test_e2e_full_recompute(block, bs, dtype, config, recompute=False):
-    reset_rng_states()
-
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
     te_inp_hidden_states.retain_grad()
     te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
 
-    if recompute:
-        te_out = te_checkpoint(
-            block,
-            False,  # distribute_saved_activations
-            get_dummy_cuda_rng_tracker,
-            None,  # tp_group
-            te_inp_hidden_states,
-            attention_mask=te_inp_attn_mask,
-            checkpoint_core_attention=False,
-        )
-    else:
+    with fp8_autocast(enabled=fp8):
         te_out = block(
             te_inp_hidden_states,
             attention_mask=te_inp_attn_mask,
-            checkpoint_core_attention=False,
+            checkpoint_core_attention=recompute,
         )
     loss = te_out.sum()
     loss.backward()
@@ -442,13 +399,33 @@
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_gpt_full_activation_recompute(dtype, bs, model):
+@pytest.mark.parametrize("fp8", all_boolean)
+def test_gpt_selective_activation_recompute(dtype, bs, model, fp8):
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+
     config = model_configs[model]
 
+    outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False)
+    outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=True)
+    assert_all_equal(outputs, outputs_recompute)
+
+
+def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False):
+    reset_rng_states()
+    FP8GlobalStateManager.reset()
+
     sigma = 0.023
     init_method = init_method_normal(sigma)
     output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
 
+    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
+    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
+
+    def get_dummy_cuda_rng_tracker():
+        """Get cuda rng tracker."""
+        return _DUMMY_CUDA_RNG_STATE_TRACKER
+
     block = (
         TransformerLayer(
             config.hidden_size,
@@ -466,11 +443,54 @@
            params_dtype=dtype,
        )
        .cuda()
-       .eval()
     )
 
-    outputs = _test_e2e_full_recompute(block, bs, dtype, config, recompute=False)
-    outputs_recompute = _test_e2e_full_recompute(block, bs, dtype, config, recompute=True)
+    te_inp_hidden_states = torch.randn(
+        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
+    ).cuda()
+    te_inp_hidden_states.retain_grad()
+    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
+
+    with fp8_autocast(enabled=fp8):
+        if recompute:
+            te_out = te_checkpoint(
+                block,
+                False,  # distribute_saved_activations
+                get_dummy_cuda_rng_tracker,
+                None,  # tp_group
+                te_inp_hidden_states,
+                attention_mask=te_inp_attn_mask,
+                checkpoint_core_attention=False,
+            )
+        else:
+            te_out = block(
+                te_inp_hidden_states,
+                attention_mask=te_inp_attn_mask,
+                checkpoint_core_attention=False,
+            )
+    loss = te_out.sum()
+    loss.backward()
+    torch.cuda.synchronize()
+
+    outputs = [te_out, te_inp_hidden_states.grad]
+    for p in block.parameters():
+        if p.requires_grad:
+            outputs.append(p.grad)
+    return outputs
+
+
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("fp8", all_boolean)
+def test_gpt_full_activation_recompute(dtype, bs, model, fp8):
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+
+    config = model_configs[model]
+
+    outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False)
+    outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=True)
     assert_all_equal(outputs, outputs_recompute)
 
 
@@ -565,8 +585,8 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
 def test_gpt_checkpointing(dtype, bs, model):
     config = model_configs[model]
     outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
-    outputs_recompute = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
-    assert_all_equal(outputs, outputs_recompute)
+    outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
+    assert_all_equal(outputs, outputs_checkpoint)
 
 
 def _test_e2e_gpt_accuracy(block, bs, dtype, config):
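
Taken together, these changes make each helper build the TransformerLayer itself, reset the FP8 global state per run, and compare recompute against no-recompute with FP8 both enabled and disabled. The condensed pytest sketch below illustrates that pattern; the layer sizes, the lambda-wrapped RNG tracker, and torch.testing.assert_close in place of the suite's bitwise assert_all_equal are illustrative assumptions, not code from the commit.

import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

@pytest.mark.parametrize("fp8", [False, True])
def test_selective_recompute_matches(fp8):
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    def run(recompute):
        FP8GlobalStateManager.reset()      # fresh FP8 autocast/amax state per run
        torch.manual_seed(1234)
        torch.cuda.manual_seed(1234)
        tracker = CudaRNGStatesTracker()
        tracker.add("model-parallel-rng", 1234)
        layer = te.TransformerLayer(
            128, 512, 4,                   # hidden, ffn_hidden, heads (illustrative sizes)
            hidden_dropout=0.0,
            attention_dropout=0.0,
            get_rng_state_tracker=lambda: tracker,
        ).cuda()
        inp = torch.randn(32, 2, 128, device="cuda", requires_grad=True)
        with fp8_autocast(enabled=fp8):
            out = layer(inp, checkpoint_core_attention=recompute)
        out.sum().backward()
        return out, inp.grad

    out_ref, grad_ref = run(recompute=False)
    out_ckpt, grad_ckpt = run(recompute=True)
    torch.testing.assert_close(out_ref, out_ckpt)
    torch.testing.assert_close(grad_ref, grad_ckpt)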

transformer_engine/pytorch/attention.py

Lines changed: 0 additions & 13 deletions

@@ -2164,19 +2164,6 @@ def forward(
            )
 
        if use_flash_attention:
-            if checkpoint_core_attention:
-                return self._checkpointed_attention_forward(self.flash_attention,
-                                                            query_layer,
-                                                            key_layer,
-                                                            value_layer,
-                                                            attention_mask=attention_mask,
-                                                            qkv_layout=qkv_layout,
-                                                            cu_seqlens_q=cu_seqlens_q,
-                                                            cu_seqlens_kv=cu_seqlens_kv,
-                                                            attn_mask_type=attn_mask_type,
-                                                            cp_group=self.cp_group,
-                                                            cp_global_ranks=self.cp_global_ranks,
-                                                            cp_stream=self.cp_stream)
            return self.flash_attention(query_layer,
                                        key_layer,
                                        value_layer,
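
The deleted branch wrapped the flash-attention call in _checkpointed_attention_forward, but flash-style kernels never materialize the full seq-by-seq softmax matrix, so checkpointing them only re-runs the kernel in backward without freeing any sizeable activation. The sketch below illustrates that reasoning with plain PyTorch rather than Transformer Engine code; torch.nn.functional.scaled_dot_product_attention stands in for the flash path and the shapes are arbitrary.

import torch
from torch.utils.checkpoint import checkpoint

q = torch.randn(2, 8, 512, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

def attn(q, k, v):
    # Flash/memory-efficient kernels keep only O(seq) softmax statistics for
    # backward, never the full attention matrix.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

out_direct = attn(q, k, v)                                  # the path the code now always takes
out_ckpt = checkpoint(attn, q, k, v, use_reentrant=False)   # the removed, redundant wrapping
torch.testing.assert_close(out_direct, out_ckpt)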

transformer_engine/pytorch/fp8.py

Lines changed: 23 additions & 0 deletions

@@ -75,6 +75,29 @@ class FP8GlobalStateManager:
     dp_amax_reduce_forward_idx = 0
     dp_amax_reduce_backward_idx = 0
 
+    @classmethod
+    def reset(cls) -> None:
+        """Reset the global state"""
+        cls.FP8_ENABLED = False
+        cls.FP8_CALIBRATION = False
+        cls.FP8_RECIPE = None
+        cls.FP8_DISTRIBUTED_GROUP = None
+        cls.IS_FIRST_FP8_MODULE = False
+        cls.FP8_AUTOCAST_COUNTER = 0
+        cls.FP8_CURRENT_CONTEXT_ID = 0
+        cls.FP8_AUTOCAST_DEPTH = 0
+        cls.global_fp8_buffer = {}
+        cls.fp8_tensors_recompute_buffer = []
+        cls.amax_forward_global_reduce_func = None
+        cls.buffer_delete_key_fwd = None
+        cls.buffer_delete_key_bwd = None
+        cls.amax_reduce_handle_fwd = None
+        cls.fp8_available = None
+        cls.reason_for_no_fp8 = ""
+        cls.dp_amax_reduce_interval = None
+        cls.dp_amax_reduce_forward_idx = 0
+        cls.dp_amax_reduce_backward_idx = 0
+
     @classmethod
     def is_fp8_available(cls) -> Tuple[bool, str]:
         """Return if fp8 support is available"""
