[Paddle] Add control of RNG state (#410)
* Add control of attention dropout and hidden dropout RNG state

Signed-off-by: Tian Zheng (Engrg-Hardware 1) <[email protected]>

* Fix CI error

Signed-off-by: Tian Zheng (Engrg-Hardware 1) <[email protected]>

---------

Signed-off-by: Tian Zheng (Engrg-Hardware 1) <[email protected]>
Tom-Zheng authored Sep 1, 2023
1 parent 3a63b13 commit 805b987
Showing 6 changed files with 74 additions and 20 deletions.
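
Before the per-file diffs, a minimal usage sketch of what this commit exposes: two new keyword arguments on TransformerLayer that pick which named RNG-tracker state wraps attention dropout and hidden dropout. It assumes the usual TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, ...) constructor, an initialized fleet hybrid-parallel group, and a CUDA build; only the two *_rng_state_name keywords, their defaults, and the 'global_seed'/'local_seed' state names come from this diff, everything else (sizes, seeds, import path) is illustrative.

    import paddle
    from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
    import transformer_engine.paddle as te  # import path assumed

    # Register the named RNG states the layers will draw from
    # (tests/paddle/utils.py below does this in set_random_seed).
    tracker = get_rng_state_tracker()
    tracker.add("global_seed", 2050)                                  # same value on every TP rank
    tracker.add("local_seed", 2052 + paddle.distributed.get_rank())   # differs per TP rank

    layer = te.TransformerLayer(
        hidden_size=1024,             # illustrative sizes
        ffn_hidden_size=4096,
        num_attention_heads=16,
        hidden_dropout=0.1,
        attention_dropout=0.1,        # illustrative; not shown in this diff
        set_parallel_mode=True,
        attention_dropout_rng_state_name='local_seed',   # new in this commit
        hidden_dropout_rng_state_name='global_seed',     # new in this commit
    )

Per the diff, the named states are only consulted when the layer actually runs tensor-parallel (track_rng_state(enable=self.tensor_parallel, ...)); in a single-GPU run the default generator is used exactly as before.
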
1 change: 1 addition & 0 deletions tests/paddle/parallel_tests/linear_pp.py
@@ -107,6 +107,7 @@ def set_attr(self):
def test_pipeline_train(self):
"""Test pipeline parallel training"""
set_random_seed(1024)
+np.random.seed(1024)

weight1_np = np.random.normal(size=[self.in_features, self.hidden_features])
weight2_np = np.random.normal(size=[self.hidden_features, self.in_features])
2 changes: 1 addition & 1 deletion tests/paddle/test_layers.py
@@ -1085,7 +1085,7 @@ def calc_transformer_output_and_grad(layer, encoder_input, mask, encoder_output,
assert_allclose(layer_te.self_attention.qkv.bias.grad,
layer_pd.self_attention.qkv.bias.grad,
rtol=0.01,
-atol=0.5)
+atol=0.6)
else:
assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad,
layer_pd.self_attention.layernorm_qkv.bias.grad,
46 changes: 43 additions & 3 deletions tests/paddle/utils.py
@@ -3,9 +3,12 @@
# See LICENSE for license information.
"""Utils for testing"""

+import random
import numpy as np

import paddle
+from paddle.distributed import fleet
+from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

import transformer_engine # pylint: disable=unused-import

@@ -49,6 +52,43 @@ def is_devices_enough(required):

def set_random_seed(seed):
"""Set random seed for reproducability."""
np.random.seed(seed)
paddle.seed(seed)
paddle.distributed.fleet.meta_parallel.model_parallel_random_seed(seed)

hcg = fleet.get_hybrid_communicate_group()
if paddle.distributed.get_world_size() > 1:
# obtain rank message of hybrid parallel

mp_rank = hcg.get_model_parallel_rank()
mp_size = hcg.get_model_parallel_world_size()

pp_rank = hcg.get_stage_id()
pp_size = hcg.get_pipe_parallel_world_size()

dp_rank = hcg.get_data_parallel_rank()
dp_size = hcg.get_data_parallel_world_size()

sharding_rank = hcg.get_sharding_parallel_rank()
else:
mp_rank, mp_size = 0, 1
pp_rank, pp_size = 0, 1
dp_rank, dp_size = 0, 1
sharding_rank, _ = 0, 1

random.seed(seed + 100 * pp_rank)
np.random.seed(seed + 100 * pp_rank)

seed_offset = seed + 1024 + paddle.distributed.get_world_size()
global_seed = (seed_offset + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
sharding_rank * (mp_size * pp_size * dp_size))

seed_offset += paddle.distributed.get_world_size()
local_seed = (seed_offset + mp_rank + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
sharding_rank * (mp_size * pp_size * dp_size))

tracker = get_rng_state_tracker()
# tracker.reset()
if "global_seed" not in tracker.states_:
tracker.add("global_seed", global_seed)
if "local_seed" not in tracker.states_:
tracker.add("local_seed", local_seed)

paddle.seed(global_seed)
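
To make the seed arithmetic above concrete, a small standalone recomputation (values assumed: seed=1024 and a 2-way tensor-parallel job with no pipeline, data, or sharding parallelism, so world_size=2). global_seed does not depend on mp_rank, so it comes out identical across the tensor-parallel group, while local_seed differs per tensor-parallel rank:

    def seeds(seed, mp_rank, mp_size, pp_rank=0, pp_size=1, dp_rank=0, dp_size=1,
              sharding_rank=0, world_size=2):
        """Mirrors the arithmetic added to set_random_seed above."""
        offset = seed + 1024 + world_size
        global_seed = (offset + pp_rank * mp_size + dp_rank * (mp_size * pp_size) +
                       sharding_rank * (mp_size * pp_size * dp_size))
        offset += world_size
        local_seed = (offset + mp_rank + pp_rank * mp_size + dp_rank * (mp_size * pp_size) +
                      sharding_rank * (mp_size * pp_size * dp_size))
        return global_seed, local_seed

    print(seeds(1024, mp_rank=0, mp_size=2))  # (2050, 2052) on TP rank 0
    print(seeds(1024, mp_rank=1, mp_size=2))  # (2050, 2053) on TP rank 1

This split is presumably why attention dropout defaults to 'local_seed' (a different mask on each rank's partition of the heads) while hidden dropout defaults to 'global_seed' (matching masks on the replicated hidden states).
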
4 changes: 2 additions & 2 deletions transformer_engine/paddle/distributed.py
@@ -39,13 +39,13 @@ def get_tp_group_and_world_size(tp_group: Union[dist_group_type, None],


@contextmanager
-def track_rng_state(enable: bool) -> None:
+def track_rng_state(enable: bool, **kwargs) -> None:
"""
Applies get_rng_state_tracker().rng_state() to the context.
If not enabled, it does nothing.
"""
if enable:
-with get_rng_state_tracker().rng_state():
+with get_rng_state_tracker().rng_state(**kwargs):
yield
else:
yield
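
A minimal sketch of what the forwarded keyword does in practice: Paddle's RNGStatesTracker.rng_state(name) swaps in the generator registered under that name for the duration of the block, so dropout inside it draws from that state. The state must be registered first, and a CUDA build is assumed since the tracker manipulates the CUDA RNG state; the seed and shapes are illustrative.

    import paddle
    from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
    from transformer_engine.paddle.distributed import track_rng_state

    get_rng_state_tracker().add("local_seed", 2052)   # illustrative seed; must be registered first

    x = paddle.ones([4, 8])
    with track_rng_state(enable=True, name="local_seed"):
        # Drawn from the "local_seed" generator rather than the default one.
        y = paddle.nn.functional.dropout(x, p=0.1, training=True)

    with track_rng_state(enable=False):
        # enable=False is a pass-through: the default generator is used.
        z = paddle.nn.functional.dropout(x, p=0.1, training=True)
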
6 changes: 4 additions & 2 deletions transformer_engine/paddle/layer/attention.py
@@ -401,6 +401,7 @@ def __init__(
zero_centered_gamma: bool = False,
set_parallel_mode: bool = False,
tp_group: Optional[dist_group_type] = None,
+rng_state_name: str = 'local_seed',
backend: str = 'transformer_engine',
) -> None:
super().__init__()
@@ -422,6 +423,7 @@ def __init__(
self.num_attention_heads = num_attention_heads
norm_factor = math.sqrt(self.hidden_size_per_attention_head)
self.set_parallel_mode = set_parallel_mode
+self.rng_state_name = rng_state_name
self.backend = backend

self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size)
@@ -555,7 +557,7 @@ def forward(
0, 0, 3, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
])

-with track_rng_state(enable=self.tensor_parallel):
+with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
context_layer = self.core_attention(
query_layer=mixed_qkv_layer,
key_value_layer=None,
@@ -584,7 +586,7 @@ def forward(
query_layer = query_layer.reshape(shape=[
0, 0, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
])
-with track_rng_state(enable=self.tensor_parallel):
+with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
context_layer = self.core_attention(
query_layer=query_layer,
key_value_layer=mixed_kv_layer,
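
For the standalone attention module the same knob is exposed as rng_state_name. A hedged constructor sketch, assuming the usual MultiHeadAttention(hidden_size, num_attention_heads, ...) signature; only rng_state_name and its 'local_seed' default come from this diff, the other arguments are illustrative.

    import transformer_engine.paddle as te  # import path assumed

    attn = te.MultiHeadAttention(
        hidden_size=1024,              # illustrative
        num_attention_heads=16,        # illustrative
        attention_dropout=0.1,         # illustrative
        set_parallel_mode=True,
        rng_state_name='local_seed',   # new in this commit: tracker state wrapping core attention
    )
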
35 changes: 23 additions & 12 deletions transformer_engine/paddle/layer/transformer.py
@@ -9,6 +9,7 @@

from . import LayerNormMLP, LayerNorm, MultiHeadAttention
from ..constants import AttnMaskTypes, LayerTypes, dist_group_type
+from ..distributed import get_tp_group_and_world_size, track_rng_state


class TransformerLayer(paddle.nn.Layer):
@@ -90,6 +91,8 @@ def __init__(self,
activation: str = 'gelu',
set_parallel_mode: bool = False,
tp_group: Optional[dist_group_type] = None,
+attention_dropout_rng_state_name: str = 'local_seed',
+hidden_dropout_rng_state_name: str = 'global_seed',
backend: str = 'transformer_engine') -> None:
super().__init__()

@@ -99,7 +102,10 @@ def __init__(self,
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.self_attn_mask_type = self_attn_mask_type
self.set_parallel_mode = set_parallel_mode
-self.tp_group = tp_group
+self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
+                                                          enable_tp=set_parallel_mode)
+self.tensor_parallel = self.tp_size > 1
+self.hidden_dropout_rng_state_name = hidden_dropout_rng_state_name

assert (self_attn_mask_type
in AttnMaskTypes), f"self_attn_mask_type {self_attn_mask_type} not supported"
@@ -119,6 +125,7 @@ def __init__(self,
"zero_centered_gamma": zero_centered_gamma,
"set_parallel_mode": set_parallel_mode,
"tp_group": tp_group,
"rng_state_name": attention_dropout_rng_state_name,
"backend": backend,
}

@@ -224,11 +231,12 @@ def forward(
residual = hidden_states

# dropoout add.
-out = paddle.nn.functional.dropout(
-    attention_output,
-    p=self.hidden_dropout,
-    training=True,
-)
+with track_rng_state(enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name):
+    out = paddle.nn.functional.dropout(
+        attention_output,
+        p=self.hidden_dropout,
+        training=True,
+    )
bda_output = residual + out

# Cross attention.
@@ -247,11 +255,13 @@
attention_output = inter_attention_outputs
residual = bda_output

-out = paddle.nn.functional.dropout(
-    attention_output,
-    p=self.hidden_dropout,
-    training=True,
-)
+with track_rng_state(enable=self.tensor_parallel,
+                     name=self.hidden_dropout_rng_state_name):
+    out = paddle.nn.functional.dropout(
+        attention_output,
+        p=self.hidden_dropout,
+        training=True,
+    )
bda_output = residual + out

# MLP.
@@ -263,7 +273,8 @@ def forward(
residual = bda_output

# dropoout add.
-out = paddle.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=True)
+with track_rng_state(enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name):
+    out = paddle.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=True)
output = residual + out

# For BERT like architectures.
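
Written out flat, the new hidden-dropout path in TransformerLayer.forward is equivalent to the sketch below (tensor shape and dropout probability are illustrative; the tracker state is assumed to have been registered, e.g. by set_random_seed above, and a CUDA build is assumed). Because 'global_seed' is seeded identically on every tensor-parallel rank, each rank draws the same hidden-dropout mask here, which keeps the masks consistent across the TP group where the hidden states are replicated.

    import paddle
    from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
    from transformer_engine.paddle.distributed import track_rng_state

    get_rng_state_tracker().add("global_seed", 2050)   # same value on every TP rank

    attention_output = paddle.ones([8, 128, 1024])     # illustrative shape
    residual = attention_output

    # Equivalent of the wrapped "dropout add" above, with tensor parallelism enabled:
    with track_rng_state(enable=True, name="global_seed"):
        out = paddle.nn.functional.dropout(attention_output, p=0.1, training=True)
    bda_output = residual + out
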
