diff --git a/tests/paddle/parallel_tests/linear_pp.py b/tests/paddle/parallel_tests/linear_pp.py
index 994e15ba7d..4326e43a68 100644
--- a/tests/paddle/parallel_tests/linear_pp.py
+++ b/tests/paddle/parallel_tests/linear_pp.py
@@ -107,6 +107,7 @@ def set_attr(self):
     def test_pipeline_train(self):
         """Test pipeline parallel training"""
         set_random_seed(1024)
+        np.random.seed(1024)
 
         weight1_np = np.random.normal(size=[self.in_features, self.hidden_features])
         weight2_np = np.random.normal(size=[self.hidden_features, self.in_features])
diff --git a/tests/paddle/test_layers.py b/tests/paddle/test_layers.py
index bb93458230..16bb88d7f5 100644
--- a/tests/paddle/test_layers.py
+++ b/tests/paddle/test_layers.py
@@ -1085,7 +1085,7 @@ def calc_transformer_output_and_grad(layer, encoder_input, mask, encoder_output,
         assert_allclose(layer_te.self_attention.qkv.bias.grad,
                         layer_pd.self_attention.qkv.bias.grad,
                         rtol=0.01,
-                        atol=0.5)
+                        atol=0.6)
     else:
         assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad,
                         layer_pd.self_attention.layernorm_qkv.bias.grad,
diff --git a/tests/paddle/utils.py b/tests/paddle/utils.py
index 5960cccd3d..cf3449ea66 100644
--- a/tests/paddle/utils.py
+++ b/tests/paddle/utils.py
@@ -3,9 +3,12 @@
 # See LICENSE for license information.
 """Utils for testing"""
 
+import random
 import numpy as np
 import paddle
+from paddle.distributed import fleet
+from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 
 import transformer_engine    # pylint: disable=unused-import
 
@@ -49,6 +52,43 @@ def is_devices_enough(required):
 
 def set_random_seed(seed):
     """Set random seed for reproducability."""
-    np.random.seed(seed)
-    paddle.seed(seed)
-    paddle.distributed.fleet.meta_parallel.model_parallel_random_seed(seed)
+
+    hcg = fleet.get_hybrid_communicate_group()
+    if paddle.distributed.get_world_size() > 1:
+        # obtain rank message of hybrid parallel
+
+        mp_rank = hcg.get_model_parallel_rank()
+        mp_size = hcg.get_model_parallel_world_size()
+
+        pp_rank = hcg.get_stage_id()
+        pp_size = hcg.get_pipe_parallel_world_size()
+
+        dp_rank = hcg.get_data_parallel_rank()
+        dp_size = hcg.get_data_parallel_world_size()
+
+        sharding_rank = hcg.get_sharding_parallel_rank()
+    else:
+        mp_rank, mp_size = 0, 1
+        pp_rank, pp_size = 0, 1
+        dp_rank, dp_size = 0, 1
+        sharding_rank, _ = 0, 1
+
+    random.seed(seed + 100 * pp_rank)
+    np.random.seed(seed + 100 * pp_rank)
+
+    seed_offset = seed + 1024 + paddle.distributed.get_world_size()
+    global_seed = (seed_offset + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
+                   sharding_rank * (mp_size * pp_size * dp_size))
+
+    seed_offset += paddle.distributed.get_world_size()
+    local_seed = (seed_offset + mp_rank + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
+                  sharding_rank * (mp_size * pp_size * dp_size))
+
+    tracker = get_rng_state_tracker()
+    # tracker.reset()
+    if "global_seed" not in tracker.states_:
+        tracker.add("global_seed", global_seed)
+    if "local_seed" not in tracker.states_:
+        tracker.add("local_seed", local_seed)
+
+    paddle.seed(global_seed)
diff --git a/transformer_engine/paddle/distributed.py b/transformer_engine/paddle/distributed.py
index 5bf51c9274..bacc2e27dc 100644
--- a/transformer_engine/paddle/distributed.py
+++ b/transformer_engine/paddle/distributed.py
@@ -39,13 +39,13 @@ def get_tp_group_and_world_size(tp_group: Union[dist_group_type, None],
 
 
 @contextmanager
-def track_rng_state(enable: bool) -> None:
+def track_rng_state(enable: bool, **kwargs) -> None:
     """
     Applies get_rng_state_tracker().rng_state() to the context.
     If not enabled, it does nothing.
     """
     if enable:
-        with get_rng_state_tracker().rng_state():
+        with get_rng_state_tracker().rng_state(**kwargs):
             yield
     else:
         yield
diff --git a/transformer_engine/paddle/layer/attention.py b/transformer_engine/paddle/layer/attention.py
index 565321baad..19a22a28b8 100644
--- a/transformer_engine/paddle/layer/attention.py
+++ b/transformer_engine/paddle/layer/attention.py
@@ -401,6 +401,7 @@ def __init__(
         zero_centered_gamma: bool = False,
         set_parallel_mode: bool = False,
         tp_group: Optional[dist_group_type] = None,
+        rng_state_name: str = 'local_seed',
         backend: str = 'transformer_engine',
     ) -> None:
         super().__init__()
@@ -422,6 +423,7 @@ def __init__(
         self.num_attention_heads = num_attention_heads
         norm_factor = math.sqrt(self.hidden_size_per_attention_head)
         self.set_parallel_mode = set_parallel_mode
+        self.rng_state_name = rng_state_name
         self.backend = backend
 
         self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size)
@@ -555,7 +557,7 @@ def forward(
                 0, 0, 3, self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head
             ])
-            with track_rng_state(enable=self.tensor_parallel):
+            with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
                 context_layer = self.core_attention(
                     query_layer=mixed_qkv_layer,
                     key_value_layer=None,
@@ -584,7 +586,7 @@ def forward(
             query_layer = query_layer.reshape(shape=[
                 0, 0, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
             ])
-            with track_rng_state(enable=self.tensor_parallel):
+            with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
                 context_layer = self.core_attention(
                     query_layer=query_layer,
                     key_value_layer=mixed_kv_layer,
diff --git a/transformer_engine/paddle/layer/transformer.py b/transformer_engine/paddle/layer/transformer.py
index a95b9fcfe1..b34a119dbb 100644
--- a/transformer_engine/paddle/layer/transformer.py
+++ b/transformer_engine/paddle/layer/transformer.py
@@ -9,6 +9,7 @@
 
 from . import LayerNormMLP, LayerNorm, MultiHeadAttention
 from ..constants import AttnMaskTypes, LayerTypes, dist_group_type
+from ..distributed import get_tp_group_and_world_size, track_rng_state
 
 
 class TransformerLayer(paddle.nn.Layer):
@@ -90,6 +91,8 @@ def __init__(self,
                  activation: str = 'gelu',
                  set_parallel_mode: bool = False,
                  tp_group: Optional[dist_group_type] = None,
+                 attention_dropout_rng_state_name: str = 'local_seed',
+                 hidden_dropout_rng_state_name: str = 'global_seed',
                  backend: str = 'transformer_engine') -> None:
         super().__init__()
 
@@ -99,7 +102,10 @@ def __init__(self,
         self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
         self.self_attn_mask_type = self_attn_mask_type
         self.set_parallel_mode = set_parallel_mode
-        self.tp_group = tp_group
+        self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
+                                                                   enable_tp=set_parallel_mode)
+        self.tensor_parallel = self.tp_size > 1
+        self.hidden_dropout_rng_state_name = hidden_dropout_rng_state_name
 
         assert (self_attn_mask_type in
                 AttnMaskTypes), f"self_attn_mask_type {self_attn_mask_type} not supported"
@@ -119,6 +125,7 @@ def __init__(self,
             "zero_centered_gamma": zero_centered_gamma,
             "set_parallel_mode": set_parallel_mode,
             "tp_group": tp_group,
+            "rng_state_name": attention_dropout_rng_state_name,
             "backend": backend,
         }
 
@@ -224,11 +231,12 @@ def forward(
             residual = hidden_states
 
         # dropoout add.
-        out = paddle.nn.functional.dropout(
-            attention_output,
-            p=self.hidden_dropout,
-            training=True,
-        )
+        with track_rng_state(enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name):
+            out = paddle.nn.functional.dropout(
+                attention_output,
+                p=self.hidden_dropout,
+                training=True,
+            )
         bda_output = residual + out
 
         # Cross attention.
@@ -247,11 +255,13 @@ def forward(
                 attention_output = inter_attention_outputs
                 residual = bda_output
 
-            out = paddle.nn.functional.dropout(
-                attention_output,
-                p=self.hidden_dropout,
-                training=True,
-            )
+            with track_rng_state(enable=self.tensor_parallel,
+                                 name=self.hidden_dropout_rng_state_name):
+                out = paddle.nn.functional.dropout(
+                    attention_output,
+                    p=self.hidden_dropout,
+                    training=True,
+                )
             bda_output = residual + out
 
         # MLP.
@@ -263,7 +273,8 @@ def forward(
             residual = bda_output
 
         # dropoout add.
-        out = paddle.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=True)
+        with track_rng_state(enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name):
+            out = paddle.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=True)
         output = residual + out
 
         # For BERT like architectures.