Skip to content

Commit

Permalink
[PyTorch] Non-reentrant mode for activation recompute (#670)
Browse files Browse the repository at this point in the history
* added non-reentrant mode support to TE checkpoint

Signed-off-by: Alp Dener <[email protected]>

* updated get_cuda_rng_tracker kwarg to get_rng_state_tracker to remain consistent with other TE API

Signed-off-by: Alp Dener <[email protected]>

* docstring cleanup

Signed-off-by: Alp Dener <[email protected]>

* added mechanism to disable bias_gelu_nvfusion in LayerNormMLP when checkpointing in non-reentrant mode

Signed-off-by: Alp Dener <[email protected]>

* refactored checkpoint and recompute hook names to match PyTorch implementation

Signed-off-by: Alp Dener <[email protected]>

* Fixed incorrect reference before assignment

Signed-off-by: Alp Dener <[email protected]>

* fixed argument error in calling native PyTorch checkpoint

Signed-off-by: Alp Dener <[email protected]>

* fixed linting errors for missing docstrings

Signed-off-by: Alp Dener <[email protected]>

* Fix lint

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* bias GELU fusion consistency between checkpoint test and reference comparison

Signed-off-by: Alp Dener <[email protected]>

---------

Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
  • Loading branch information
denera and ksivaman authored Feb 24, 2024
1 parent 9b2fed5 commit 82bc797
Show file tree
Hide file tree
Showing 4 changed files with 358 additions and 65 deletions.
67 changes: 49 additions & 18 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import math
import os
import sys
from typing import List, Optional
import pytest
import copy
Expand Down Expand Up @@ -72,22 +73,27 @@ def get_causal_attn_mask(sq: int) -> torch.Tensor:
return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()


def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor]) -> bool:
def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool:
"""Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs."
for t1, t2 in zip(l1, l2):
assert torch.equal(t1, t2), "Output mismatch."
failed = False
failed_tensors = ""
for i, (t1, t2) in enumerate(zip(l1, l2)):
if not torch.equal(t1, t2):
failed = True
failed_tensors += f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n"
assert not failed, "Output mismatches in:\n" + failed_tensors


def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float) -> bool:
"""Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs."
for t1, t2 in zip(l1, l2):
for i, (t1, t2) in enumerate(zip(l1, l2)):
result = torch.allclose(t1, t2, atol=atol)
if not result:
diff = torch.abs(t1 - t2).flatten()
m = torch.argmax(diff)
msg = (f"Outputs not close enough."
msg = (f"Outputs not close enough in tensor at idx={i}. "
f"Location of the maximum difference: {m.item()} "
f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} "
f"(diff {diff[m].item()})."
Expand Down Expand Up @@ -457,7 +463,12 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par
assert_all_equal(outputs, outputs_recompute)


def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
def _test_e2e_full_recompute(
bs, dtype, config, fp8,
fp8_model_params=False,
recompute=False,
use_reentrant=True
):
reset_rng_states()
FP8GlobalStateManager.reset()

Expand Down Expand Up @@ -494,21 +505,23 @@ def get_dummy_cuda_rng_tracker():
)

te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=use_reentrant
).cuda()
te_inp_hidden_states.retain_grad()
if use_reentrant:
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)

with fp8_autocast(enabled=fp8):
if recompute:
te_out = te_checkpoint(
block,
False, # distribute_saved_activations
get_dummy_cuda_rng_tracker,
None, # tp_group
te_inp_hidden_states,
attention_mask=te_inp_attn_mask,
checkpoint_core_attention=False,
distribute_saved_activations=False,
tp_group=None,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
use_reentrant=use_reentrant,
)
else:
te_out = block(
Expand All @@ -520,27 +533,45 @@ def get_dummy_cuda_rng_tracker():
loss.backward()
torch.cuda.synchronize()

outputs = [te_out, te_inp_hidden_states.grad]
for p in block.parameters():
outputs = [te_out]
names = ["output"]
if use_reentrant:
outputs.append(te_inp_hidden_states.grad)
names.append("input")
for name, p in block.named_parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
names.append(name)

return outputs, names


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params):
@pytest.mark.parametrize("use_reentrant", all_boolean)
def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params, use_reentrant):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)

config = model_configs[model]

outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False)
outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True)
assert_all_equal(outputs, outputs_recompute)
if not use_reentrant:
# Non-reentrant checkpoint becomes non-deterministic with bias+GELU fusion
os.environ["NVTE_BIAS_GELU_NVFUSION"] = "0"

outputs, names = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params,
recompute=False, use_reentrant=use_reentrant)
outputs_recompute, _ = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params,
recompute=True, use_reentrant=use_reentrant)

if not use_reentrant:
# Reset bias+GELU fusion flag to avoid contaminating other tests
del os.environ["NVTE_BIAS_GELU_NVFUSION"]

assert_all_equal(outputs, outputs_recompute, names=names)


def _test_e2e_checkpointing_get_model(config, dtype):
Expand Down
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -2472,9 +2472,9 @@ def custom_forward(*input_args, **input_kwargs):

hidden_states = checkpoint(
custom_forward,
False,
self.get_rng_state_tracker,
self.tp_group,
distribute_saved_activations=False,
get_rng_state_tracker=self.get_rng_state_tracker,
tp_group=self.tp_group,
*forward_args,
**forward_kwargs,
)
Expand Down
Loading

0 comments on commit 82bc797

Please sign in to comment.