modify test for fp8_model_params+recompute
Signed-off-by: Sudhakar Singh <[email protected]>
sudhakarsingh27 committed Oct 24, 2023
1 parent dfcbcf1 commit 8ff9e05
Showing 1 changed file with 76 additions and 69 deletions.
tests/pytorch/test_numerics.py: 145 changes (76 additions, 69 deletions)
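The change wires the new fp8_model_params option into both recompute tests: the TransformerLayer is now built inside the fp8_init context so that, when fp8 and fp8_model_params are both set, its weights are created directly as FP8 parameters; fuse_qkv_params is switched on for that case; and the forward/backward pass still runs under fp8_autocast with recompute toggled on or off. For orientation, here is a minimal sketch of the code path these tests cover. It uses fp8_init and the keyword arguments exactly as they appear in this diff, but the sizes, dtype, and flag values are illustrative stand-ins rather than values from model_configs, it needs an FP8-capable GPU, and the fp8_init name may differ in other revisions of Transformer Engine.

# Sketch of the pattern exercised by the updated tests (illustrative values,
# not the configs used in tests/pytorch/test_numerics.py).
import torch
from transformer_engine.pytorch import TransformerLayer
from transformer_engine.pytorch.fp8 import fp8_autocast, fp8_init

fp8, fp8_model_params, recompute = True, True, True
hidden_size, ffn_hidden_size, num_heads = 256, 1024, 4
seq_len, batch_size = 32, 2

# Weights are created as FP8 tensors only when both flags are set,
# mirroring `enabled=fp8 and fp8_model_params` in the tests.
with fp8_init(enabled=fp8 and fp8_model_params):
    block = TransformerLayer(
        hidden_size,
        ffn_hidden_size,
        num_heads,
        params_dtype=torch.bfloat16,
        fuse_qkv_params=fp8 and fp8_model_params,
    ).cuda()

inp = torch.randn(
    seq_len, batch_size, hidden_size,
    dtype=torch.bfloat16, device="cuda", requires_grad=True,
)

# FP8 execution with selective recompute of the core attention block.
with fp8_autocast(enabled=fp8):
    out = block(inp, checkpoint_core_attention=recompute)
out.sum().backward()
torch.cuda.synchronize()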
@@ -7,12 +7,13 @@
 from typing import List, Optional
 import pytest
 import copy
+from contextlib import nullcontext

 import torch
 import torch.nn as nn
 from torch.nn import Parameter

-from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
+from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_init
 from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,
@@ -339,7 +340,7 @@ def forward(
         return x


-def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False):
+def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
     reset_rng_states()
     FP8GlobalStateManager.reset()

@@ -354,40 +355,42 @@ def get_dummy_cuda_rng_tracker():
         """Get cuda rng tracker."""
         return _DUMMY_CUDA_RNG_STATE_TRACKER

-    block = (
-        TransformerLayer(
-            config.hidden_size,
-            4 * config.hidden_size,
-            config.num_attention_heads,
-            layernorm_epsilon=config.eps,
-            init_method=init_method,
-            output_layer_init_method=output_layer_init_method,
-            hidden_dropout=0.1,
-            attention_dropout=0.1,
-            kv_channels=config.embed,
-            apply_residual_connection_post_layernorm=False,
-            output_layernorm=False,
-            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
-            params_dtype=dtype,
-        )
-        .cuda()
-    )
+    with fp8_init(enabled=fp8 and fp8_model_params):
+        block = (
+            TransformerLayer(
+                config.hidden_size,
+                4 * config.hidden_size,
+                config.num_attention_heads,
+                layernorm_epsilon=config.eps,
+                init_method=init_method,
+                output_layer_init_method=output_layer_init_method,
+                hidden_dropout=0.1,
+                attention_dropout=0.1,
+                kv_channels=config.embed,
+                apply_residual_connection_post_layernorm=False,
+                output_layernorm=False,
+                get_rng_state_tracker=get_dummy_cuda_rng_tracker,
+                params_dtype=dtype,
+                fuse_qkv_params=fp8 and fp8_model_params,
+            )
+            .cuda()
+        )

-    te_inp_hidden_states = torch.randn(
-        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
-    ).cuda()
-    te_inp_hidden_states.retain_grad()
-    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
+        te_inp_hidden_states = torch.randn(
+            config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
+        ).cuda()
+        te_inp_hidden_states.retain_grad()
+        te_inp_attn_mask = get_causal_attn_mask(config.seq_len)

-    with fp8_autocast(enabled=fp8):
-        te_out = block(
-            te_inp_hidden_states,
-            attention_mask=te_inp_attn_mask,
-            checkpoint_core_attention=recompute,
-        )
-    loss = te_out.sum()
-    loss.backward()
-    torch.cuda.synchronize()
+        with fp8_autocast(enabled=fp8):
+            te_out = block(
+                te_inp_hidden_states,
+                attention_mask=te_inp_attn_mask,
+                checkpoint_core_attention=recompute,
+            )
+        loss = te_out.sum()
+        loss.backward()
+        torch.cuda.synchronize()

     outputs = [te_out, te_inp_hidden_states.grad]
     for p in block.parameters():
@@ -400,18 +403,19 @@ def get_dummy_cuda_rng_tracker():
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
 @pytest.mark.parametrize("fp8", all_boolean)
-def test_gpt_selective_activation_recompute(dtype, bs, model, fp8):
+@pytest.mark.parametrize("fp8_model_params", all_boolean)
+def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_params):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)

     config = model_configs[model]

-    outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False)
-    outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=True)
+    outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False)
+    outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True)
     assert_all_equal(outputs, outputs_recompute)


-def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False):
+def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
     reset_rng_states()
     FP8GlobalStateManager.reset()

@@ -426,7 +430,8 @@ def get_dummy_cuda_rng_tracker():
         """Get cuda rng tracker."""
         return _DUMMY_CUDA_RNG_STATE_TRACKER

-    block = (
+    with fp8_init(enabled=fp8 and fp8_model_params):
+        block = (
         TransformerLayer(
             config.hidden_size,
             4 * config.hidden_size,
@@ -441,36 +446,37 @@
             output_layernorm=False,
             get_rng_state_tracker=get_dummy_cuda_rng_tracker,
             params_dtype=dtype,
+            fuse_qkv_params=fp8 and fp8_model_params,
         )
         .cuda()
     )
-
-    te_inp_hidden_states = torch.randn(
-        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
-    ).cuda()
-    te_inp_hidden_states.retain_grad()
-    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
-
-    with fp8_autocast(enabled=fp8):
-        if recompute:
-            te_out = te_checkpoint(
-                block,
-                False,  # distribute_saved_activations
-                get_dummy_cuda_rng_tracker,
-                None,  # tp_group
-                te_inp_hidden_states,
-                attention_mask=te_inp_attn_mask,
-                checkpoint_core_attention=False,
-            )
-        else:
-            te_out = block(
-                te_inp_hidden_states,
-                attention_mask=te_inp_attn_mask,
-                checkpoint_core_attention=False,
-            )
-    loss = te_out.sum()
-    loss.backward()
-    torch.cuda.synchronize()
+        )
+        te_inp_hidden_states = torch.randn(
+            config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
+        ).cuda()
+        te_inp_hidden_states.retain_grad()
+        te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
+
+        with fp8_autocast(enabled=fp8):
+            if recompute:
+                te_out = te_checkpoint(
+                    block,
+                    False,  # distribute_saved_activations
+                    get_dummy_cuda_rng_tracker,
+                    None,  # tp_group
+                    te_inp_hidden_states,
+                    attention_mask=te_inp_attn_mask,
+                    checkpoint_core_attention=False,
+                )
+            else:
+                te_out = block(
+                    te_inp_hidden_states,
+                    attention_mask=te_inp_attn_mask,
+                    checkpoint_core_attention=False,
+                )
+        loss = te_out.sum()
+        loss.backward()
+        torch.cuda.synchronize()

     outputs = [te_out, te_inp_hidden_states.grad]
     for p in block.parameters():
@@ -483,14 +489,15 @@ def get_dummy_cuda_rng_tracker():
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
 @pytest.mark.parametrize("fp8", all_boolean)
-def test_gpt_full_activation_recompute(dtype, bs, model, fp8):
+@pytest.mark.parametrize("fp8_model_params", all_boolean)
+def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)

     config = model_configs[model]

-    outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False)
-    outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=True)
+    outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False)
+    outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True)
     assert_all_equal(outputs, outputs_recompute)


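With the added parametrize decorators, the recompute comparisons above now also run with weights stored as FP8 (fp8_model_params=True) on top of the existing dtype, batch size, model, and fp8 axes. A hypothetical way to run just these two test groups locally, assuming the repository's usual pytest workflow:

# Hypothetical runner for the tests touched by this commit.
import sys
import pytest

if __name__ == "__main__":
    sys.exit(pytest.main([
        "-v",
        "tests/pytorch/test_numerics.py",
        "-k", "selective_activation_recompute or full_activation_recompute",
    ]))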
