From 8ff9e059ac584e6ff6d411dfc7ff9d549c5fddbe Mon Sep 17 00:00:00 2001
From: Sudhakar Singh
Date: Tue, 24 Oct 2023 13:57:21 -0700
Subject: [PATCH] modify test for fp8_model_params+recompute

Signed-off-by: Sudhakar Singh
---
 tests/pytorch/test_numerics.py | 145 +++++++++++++++++----------------
 1 file changed, 76 insertions(+), 69 deletions(-)

diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index 02fb63e71f..1683dbc10d 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -7,12 +7,13 @@
 from typing import List, Optional
 import pytest
 import copy
+from contextlib import nullcontext
 
 import torch
 import torch.nn as nn
 from torch.nn import Parameter
 
-from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
+from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_init
 from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,
@@ -339,7 +340,7 @@ def forward(
         return x
 
 
-def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False):
+def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
     reset_rng_states()
     FP8GlobalStateManager.reset()
 
@@ -354,40 +355,42 @@ def get_dummy_cuda_rng_tracker():
         """Get cuda rng tracker."""
         return _DUMMY_CUDA_RNG_STATE_TRACKER
 
-    block = (
-        TransformerLayer(
-            config.hidden_size,
-            4 * config.hidden_size,
-            config.num_attention_heads,
-            layernorm_epsilon=config.eps,
-            init_method=init_method,
-            output_layer_init_method=output_layer_init_method,
-            hidden_dropout=0.1,
-            attention_dropout=0.1,
-            kv_channels=config.embed,
-            apply_residual_connection_post_layernorm=False,
-            output_layernorm=False,
-            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
-            params_dtype=dtype,
+    with fp8_init(enabled=fp8 and fp8_model_params):
+        block = (
+            TransformerLayer(
+                config.hidden_size,
+                4 * config.hidden_size,
+                config.num_attention_heads,
+                layernorm_epsilon=config.eps,
+                init_method=init_method,
+                output_layer_init_method=output_layer_init_method,
+                hidden_dropout=0.1,
+                attention_dropout=0.1,
+                kv_channels=config.embed,
+                apply_residual_connection_post_layernorm=False,
+                output_layernorm=False,
+                get_rng_state_tracker=get_dummy_cuda_rng_tracker,
+                params_dtype=dtype,
+                fuse_qkv_params=fp8 and fp8_model_params,
+            )
+            .cuda()
         )
-        .cuda()
-    )
 
-    te_inp_hidden_states = torch.randn(
-        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
-    ).cuda()
-    te_inp_hidden_states.retain_grad()
-    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
+        te_inp_hidden_states = torch.randn(
+            config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
+        ).cuda()
+        te_inp_hidden_states.retain_grad()
+        te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
 
-    with fp8_autocast(enabled=fp8):
-        te_out = block(
-            te_inp_hidden_states,
-            attention_mask=te_inp_attn_mask,
-            checkpoint_core_attention=recompute,
-        )
-    loss = te_out.sum()
-    loss.backward()
-    torch.cuda.synchronize()
+        with fp8_autocast(enabled=fp8):
+            te_out = block(
+                te_inp_hidden_states,
+                attention_mask=te_inp_attn_mask,
+                checkpoint_core_attention=recompute,
+            )
+        loss = te_out.sum()
+        loss.backward()
+        torch.cuda.synchronize()
 
     outputs = [te_out, te_inp_hidden_states.grad]
     for p in block.parameters():
@@ -400,18 +403,19 @@
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
 @pytest.mark.parametrize("fp8", all_boolean)
-def test_gpt_selective_activation_recompute(dtype, bs, model, fp8):
+@pytest.mark.parametrize("fp8_model_params", all_boolean)
+def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_params):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
 
     config = model_configs[model]
 
-    outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False)
-    outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=True)
+    outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False)
+    outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True)
     assert_all_equal(outputs, outputs_recompute)
 
 
-def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False):
+def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
     reset_rng_states()
     FP8GlobalStateManager.reset()
 
@@ -426,7 +430,8 @@ def get_dummy_cuda_rng_tracker():
         """Get cuda rng tracker."""
         return _DUMMY_CUDA_RNG_STATE_TRACKER
 
-    block = (
+    with fp8_init(enabled=fp8 and fp8_model_params):
+        block = (
         TransformerLayer(
             config.hidden_size,
             4 * config.hidden_size,
@@ -441,36 +446,37 @@ def get_dummy_cuda_rng_tracker():
             output_layernorm=False,
             get_rng_state_tracker=get_dummy_cuda_rng_tracker,
             params_dtype=dtype,
+            fuse_qkv_params=fp8 and fp8_model_params,
         )
         .cuda()
-    )
-
-    te_inp_hidden_states = torch.randn(
-        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
-    ).cuda()
-    te_inp_hidden_states.retain_grad()
-    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
-
-    with fp8_autocast(enabled=fp8):
-        if recompute:
-            te_out = te_checkpoint(
-                block,
-                False,  # distribute_saved_activations
-                get_dummy_cuda_rng_tracker,
-                None,  # tp_group
-                te_inp_hidden_states,
-                attention_mask=te_inp_attn_mask,
-                checkpoint_core_attention=False,
-            )
-        else:
-            te_out = block(
-                te_inp_hidden_states,
-                attention_mask=te_inp_attn_mask,
-                checkpoint_core_attention=False,
-            )
-    loss = te_out.sum()
-    loss.backward()
-    torch.cuda.synchronize()
+        )
+
+        te_inp_hidden_states = torch.randn(
+            config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
+        ).cuda()
+        te_inp_hidden_states.retain_grad()
+        te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
+
+        with fp8_autocast(enabled=fp8):
+            if recompute:
+                te_out = te_checkpoint(
+                    block,
+                    False,  # distribute_saved_activations
+                    get_dummy_cuda_rng_tracker,
+                    None,  # tp_group
+                    te_inp_hidden_states,
+                    attention_mask=te_inp_attn_mask,
+                    checkpoint_core_attention=False,
+                )
+            else:
+                te_out = block(
+                    te_inp_hidden_states,
+                    attention_mask=te_inp_attn_mask,
+                    checkpoint_core_attention=False,
+                )
+        loss = te_out.sum()
+        loss.backward()
+        torch.cuda.synchronize()
 
     outputs = [te_out, te_inp_hidden_states.grad]
     for p in block.parameters():
@@ -483,14 +489,15 @@
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
 @pytest.mark.parametrize("fp8", all_boolean)
-def test_gpt_full_activation_recompute(dtype, bs, model, fp8):
+@pytest.mark.parametrize("fp8_model_params", all_boolean)
+def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
 
     config = model_configs[model]
 
-    outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False)
-    outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=True)
+    outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False)
+    outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True)
     assert_all_equal(outputs, outputs_recompute)
 
 
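For context, both helpers now follow the same pattern: the TransformerLayer is built (and its forward/backward pass run) inside the fp8_init context imported at the top of the patch, with fuse_qkv_params enabled whenever FP8 model parameters are requested, and activation recompute is exercised on top of that. Below is a minimal sketch of that pattern outside the test harness; it assumes the fp8_init context manager from this change, a CUDA device with FP8 support, and placeholder layer sizes and dtype.

import torch
from transformer_engine.pytorch import TransformerLayer
from transformer_engine.pytorch.fp8 import fp8_autocast, fp8_init

fp8_model_params = True

# Construct the layer under fp8_init so that, when enabled, its parameters are
# kept in FP8; the patched tests pair this with fuse_qkv_params, and the sizes
# below are placeholders rather than values from the test configs.
with fp8_init(enabled=fp8_model_params):
    block = TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        params_dtype=torch.bfloat16,
        fuse_qkv_params=fp8_model_params,
    ).cuda()

# Input in (seq_len, batch, hidden) layout, matching the tests.
inp = torch.randn(
    128, 2, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True
)

# The forward/backward pass still runs under fp8_autocast; selective activation
# recompute is toggled the same way the tests do, via checkpoint_core_attention.
with fp8_autocast(enabled=True):
    out = block(inp, checkpoint_core_attention=True)
out.sum().backward()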