diff --git a/examples/paddle/mnist/test_single_gpu_mnist.py b/examples/paddle/mnist/test_single_gpu_mnist.py index dabeb55656..cffd036d95 100644 --- a/examples/paddle/mnist/test_single_gpu_mnist.py +++ b/examples/paddle/mnist/test_single_gpu_mnist.py @@ -57,11 +57,13 @@ def forward(self, x): def train(args, model, train_loader, optimizer, epoch, use_fp8): """Training function.""" model.train() + losses = [] for batch_id, (data, labels) in enumerate(train_loader): with paddle.amp.auto_cast(dtype='bfloat16', level='O2'): # pylint: disable=not-context-manager with te.fp8_autocast(enabled=use_fp8): outputs = model(data) loss = F.cross_entropy(outputs, labels) + losses.append(loss.item()) loss.backward() optimizer.step() @@ -74,7 +76,9 @@ def train(args, model, train_loader, optimizer, epoch, use_fp8): f"Loss: {loss.item():.6f}") if args.dry_run: return loss.item() - return loss.item() + avg_loss = sum(losses) / len(losses) + print(f"Train Epoch: {epoch}, Average Loss: {avg_loss}") + return avg_loss def evaluate(model, test_loader, epoch, use_fp8): @@ -226,7 +230,7 @@ def setUpClass(cls): @staticmethod def verify(actual): """Check If loss and accuracy match target""" - desired_traing_loss = 0.5 + desired_traing_loss = 0.1 desired_test_accuracy = 0.98 assert actual[0] < desired_traing_loss assert actual[1] > desired_test_accuracy diff --git a/tests/paddle/dist_launcher.py b/tests/paddle/dist_launcher.py new file mode 100644 index 0000000000..e59b686435 --- /dev/null +++ b/tests/paddle/dist_launcher.py @@ -0,0 +1,140 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Helper functions to launch distributed tests""" + +import copy +import os +from pathlib import Path +import subprocess +import time +import unittest + +from paddle import fluid +from paddle.distributed.utils.launch_utils import ( + TrainerProc, + find_free_ports, + get_cluster, + watch_local_trainers, +) + +__all__ = ['TestDistributed'] + + +def get_cluster_from_args(selected_gpus): + """Get node information from selected GPUs""" + cluster_node_ips = '127.0.0.1' + node_ip = '127.0.0.1' + + node_ips = [x.strip() for x in cluster_node_ips.split(',')] + + node_ips.index(node_ip) + + free_ports = None + + free_ports = find_free_ports(len(selected_gpus)) + if free_ports is not None: + free_ports = list(free_ports) + + trainer_endpoints = [] + for ip in node_ips: + trainer_endpoints.append([f"{ip}:{port}" for port in free_ports]) + return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus) + + +def get_gpus(selected_gpus): + """Get selected GPU string""" + selected_gpus = [x.strip() for x in selected_gpus.split(',')] + return selected_gpus + + +def start_local_trainers( + cluster, + pod, + training_script, + training_script_args, + allocator_strategy="auto_growth", +): + """Launch trainers""" + current_env = copy.copy(os.environ.copy()) + # paddle broadcast ncclUniqueId use socket, and + # proxy maybe make trainers unreachable, so delete them. + # if we set them to "", grpc will log error message "bad uri" + # so just delete them. 
+ current_env.pop("http_proxy", None) + current_env.pop("https_proxy", None) + + procs = [] + for t in pod.trainers: + proc_env = { + "FLAGS_selected_gpus": ",".join([str(g) for g in t.gpus]), + "PADDLE_TRAINER_ID": f"{t.rank}", + "PADDLE_CURRENT_ENDPOINT": f"{t.endpoint}", + "PADDLE_TRAINERS_NUM": f"{cluster.trainers_nranks()}", + "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()), + "PYTHONPATH": str(Path(__file__).resolve().parent), + } + + proc_env["FLAGS_allocator_strategy"] = allocator_strategy + if allocator_strategy == "auto_growth": + proc_env["FLAGS_fraction_of_gpu_memory_to_use"] = "0.1" + + current_env.update(proc_env) + + print(f"trainer proc env:{current_env}") + + if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': + cmd = "python -m coverage run --branch -p " + training_script + else: + cmd = "python -u " + training_script + + print(f"start trainer proc:{cmd} env:{proc_env}") + + fn = None + + proc = subprocess.Popen(cmd.split(" ") + training_script_args, env=current_env) # pylint: disable=consider-using-with + + tp = TrainerProc() + tp.proc = proc + tp.rank = t.rank + tp.log_fn = fn + tp.cmd = cmd + + procs.append(tp) + + return procs + + +class TestDistributed(unittest.TestCase): + """Base class for distributed test""" + + @staticmethod + def run_2gpu( + target_file_name, + allocator_strategy="auto_growth", + ): + """Run target file in subprocesses""" + if (not fluid.core.is_compiled_with_cuda() or fluid.core.get_cuda_device_count() == 0): + return + + selected_gpus = get_gpus('0,1') + cluster = None + pod = None + + cluster, pod = get_cluster_from_args(selected_gpus) + + procs = start_local_trainers( + cluster, + pod, + allocator_strategy=allocator_strategy, + training_script=target_file_name, + training_script_args=[], + ) + + while True: + alive = watch_local_trainers(procs, cluster.trainers_endpoints()) + + if not alive: + print(f"Local procs complete, POD info:{pod}") + break + time.sleep(3) diff --git a/tests/paddle/parallel_tests/amax_reduction.py b/tests/paddle/parallel_tests/amax_reduction.py new file mode 100644 index 0000000000..931af07657 --- /dev/null +++ b/tests/paddle/parallel_tests/amax_reduction.py @@ -0,0 +1,87 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""Unittest for Linear layer in tensor parallel""" + +import unittest + +import paddle +from paddle.distributed import fleet + +from utils import assert_allclose, set_random_seed +import transformer_engine.paddle as te + + +def assert_allclose_across_ranks(tensor, group=None): + """Assert tensor is identical in all ranks""" + gathered_list = [] + paddle.distributed.all_gather(gathered_list, tensor, group=group) + assert len(gathered_list) > 1 + for gathered_tensor in gathered_list: + assert_allclose(tensor, gathered_tensor) + + +class TestAmaxReduction(unittest.TestCase): + """Tests Amax reduction""" + + def setUp(self): + self.data_parallel_size = 2 + self.init_dist_env() + self.global_dtype = 'bfloat16' + paddle.set_default_dtype(self.global_dtype) + + def init_dist_env(self): + """Init Paddle Fleet environment""" + strategy = fleet.DistributedStrategy() + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": 1, + "pp_degree": 1, + } + fleet.init(is_collective=True, strategy=strategy) + + def test_amax_reduction(self): + """Tests column parallel linear""" + set_random_seed(1024) + layer1 = te.Linear(16, 16) + layer2 = te.Linear(16, 16) + model = paddle.nn.Sequential(layer1, layer2) + model = fleet.distributed_model(model) + + rank_id = paddle.distributed.get_rank() + set_random_seed(rank_id) + + optimizer = paddle.optimizer.SGD(learning_rate=10.0, parameters=model.parameters()) + optimizer = fleet.distributed_optimizer(optimizer) + + def train_one_step(layer, inp, optimizer): + inp = paddle.to_tensor(inp) + inp.stop_gradient = False + out = layer(inp) + loss = out.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss + + for _ in range(5): + inp = paddle.uniform([16, 16], self.global_dtype) + with te.fp8_autocast(enabled=True): + train_one_step(model, inp, optimizer) + + assert_allclose_across_ranks(layer1.fp8_meta["scaling_fwd"].amax_history[-1]) + assert_allclose_across_ranks(layer1.fp8_meta["scaling_fwd"].scale) + assert_allclose_across_ranks(layer1.fp8_meta["scaling_fwd"].scale_inv) + assert_allclose_across_ranks(layer2.fp8_meta["scaling_fwd"].amax_history[-1]) + assert_allclose_across_ranks(layer2.fp8_meta["scaling_fwd"].scale) + assert_allclose_across_ranks(layer2.fp8_meta["scaling_fwd"].scale_inv) + assert_allclose_across_ranks(layer1.fp8_meta["scaling_bwd"].amax_history[-1]) + assert_allclose_across_ranks(layer1.fp8_meta["scaling_bwd"].scale) + assert_allclose_across_ranks(layer1.fp8_meta["scaling_bwd"].scale_inv) + assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].amax_history[-1]) + assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].scale) + assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].scale_inv) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/paddle/parallel_tests/group_sharding.py b/tests/paddle/parallel_tests/group_sharding.py new file mode 100644 index 0000000000..b8e4fd885d --- /dev/null +++ b/tests/paddle/parallel_tests/group_sharding.py @@ -0,0 +1,187 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""Unittest for group sharding""" + +import unittest + +import paddle +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import ( + DygraphShardingOptimizer,) + +from utils import assert_allclose, set_random_seed +import transformer_engine.paddle as te + + +class TestGroupSharding(unittest.TestCase): + """Tests group sharding""" + + def setUp(self): + self.set_attr() + self.init_dist_env() + paddle.set_default_dtype(self.global_dtype) + + def set_attr(self): + """Set test configs""" + self.sharding_degree = 2 + self.global_dtype = 'float32' + self.rtol = 1e-5 + self.atol = 1e-5 + self.batch_size = 16 + self.in_channels = 16 + self.out_channels = 32 + self.fp8 = False + + def init_dist_env(self): + """Init Paddle Fleet environment""" + strategy = fleet.DistributedStrategy() + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": 1, + "pp_degree": 1, + "sharding_degree": self.sharding_degree, + } + self.strategy = strategy + fleet.init(is_collective=True, strategy=strategy) + + def _get_model_and_optimizer(self, model, stage): + if stage == 1: + optimizer = DygraphShardingOptimizer( + hcg=fleet.get_hybrid_communicate_group(), + user_defined_strategy=self.strategy, + params=model.parameters(), + inner_optimizer_class=paddle.optimizer.AdamW, + learning_rate=0.01, + ) + model = fleet.distributed_model(model) + optimizer = fleet.distributed_optimizer(optimizer) + elif stage in [2, 3]: + optimizer = paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters()) + group = fleet.get_hybrid_communicate_group().get_sharding_parallel_group() + + class ShardingLevel: # pylint: disable=too-few-public-methods, + """Paddle sharding options""" + kStage1 = 'os' + kStage2 = 'os_g' + kStage3 = 'p_g_os' + + level = ShardingLevel.kStage3 if stage == 3 else ShardingLevel.kStage2 + model, optimizer, _ = paddle.distributed.sharding.group_sharded_parallel( + model=model, + optimizer=optimizer, + level=level, + group=group, + segment_size=256, + ) + else: + raise ValueError(f"Stage {stage} not supported") + return model, optimizer + + def test_group_sharding_stage1(self): + """Tests group sharding training""" + set_random_seed(1024) + model_te = te.Linear(self.in_channels, self.out_channels) + model_pd = paddle.nn.Linear(self.in_channels, self.out_channels) + model_pd.weight.copy_(model_te.weight.T, True) + model_pd.bias.copy_(model_te.bias, True) + + model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=1) + model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=1) + + rank_id = paddle.distributed.get_rank() + paddle.seed(rank_id) + + def train_one_step(model, inp, optimizer): + out = model(inp) + loss = out.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss + + for _ in range(5): + inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype) + with te.fp8_autocast(enabled=False): + loss_te = train_one_step(model_te, inp, optimizer_te) + loss_pd = train_one_step(model_pd, inp, optimizer_pd) + assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) + + assert len(optimizer_te.state_dict()) == 4, \ + "Expect each rank to hold 4 optimizer state entries." 
+ + def test_group_sharding_stage2(self): + """Tests group sharding training""" + set_random_seed(1024) + model_te = te.Linear(self.in_channels, self.out_channels) + model_pd = paddle.nn.Linear(self.in_channels, self.out_channels) + model_pd.weight.copy_(model_te.weight.T, True) + model_pd.bias.copy_(model_te.bias, True) + + model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=2) + model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=2) + + rank_id = paddle.distributed.get_rank() + paddle.seed(rank_id) + + def train_one_step(model, inp, optimizer): + out = model(inp) + loss = out.mean() + loss.backward() + # Check gradients are split to different trainers + if rank_id == 0: + assert model.bias.grad is None and model.weight.grad is not None + elif rank_id == 1: + assert model.weight.grad is None and model.bias.grad is not None + optimizer.step() + optimizer.clear_grad() + return loss + + for _ in range(5): + inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype) + with te.fp8_autocast(enabled=False): + loss_te = train_one_step(model_te, inp, optimizer_te) + loss_pd = train_one_step(model_pd, inp, optimizer_pd) + assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) + + assert len(optimizer_te.state_dict()) == 4, \ + "Expect each rank to hold 4 optimizer state entries." + + def test_group_sharding_stage3(self): + """Tests group sharding training""" + set_random_seed(1024) + model_te = te.Linear(self.in_channels, self.out_channels) + model_pd = paddle.nn.Linear(self.in_channels, self.out_channels) + model_pd.weight.copy_(model_te.weight.T, True) + model_pd.bias.copy_(model_te.bias, True) + + model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=3) + model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=3) + + rank_id = paddle.distributed.get_rank() + paddle.seed(rank_id) + + def train_one_step(model, inp, optimizer): + out = model(inp) + loss = out.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss + + for _ in range(5): + inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype) + with te.fp8_autocast(enabled=False): + loss_te = train_one_step(model_te, inp, optimizer_te) + loss_pd = train_one_step(model_pd, inp, optimizer_pd) + assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) + + for name, value in optimizer_te.state_dict().items(): + if name.endswith('w_0_moment1_0'): + assert value.numel() == \ + self.in_channels * self.out_channels // self.sharding_degree, \ + "Expect optimizer state to be sharded across trainers." + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/paddle/parallel_tests/layernorm_linear_tp.py b/tests/paddle/parallel_tests/layernorm_linear_tp.py new file mode 100644 index 0000000000..1034fb26fc --- /dev/null +++ b/tests/paddle/parallel_tests/layernorm_linear_tp.py @@ -0,0 +1,119 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""Unittest for LayerNormLinear layer in tensor parallel""" + +import unittest + +import paddle +from paddle.distributed import fleet +from paddle.distributed.fleet.layers.mpu import mp_ops + +from utils import assert_allclose, assert_shape, set_random_seed +import transformer_engine.paddle as te + + +class TestLayerNormLinearTp(unittest.TestCase): + """Tests LayerNormLinear layer with column/row parallelism in BF16""" + + def setUp(self): + self.set_attr() + self.init_dist_env() + paddle.set_default_dtype(self.global_dtype) + + def init_dist_env(self): + """Init Paddle Fleet environment""" + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": self.model_parallel_size, + "pp_degree": 1, + } + fleet.init(is_collective=True, strategy=strategy) + self.hcg = fleet.get_hybrid_communicate_group() + self.tp_group = self.hcg.get_model_parallel_group() + + def set_attr(self): + """Set test configs""" + self.batch_size = 16 + self.in_features = 32 + self.out_features = 64 + self.global_dtype = 'bfloat16' + self.rtol = 1e-3 + self.atol = 1e-3 + self.eps = 1e-3 + self.fp8 = False + + def test_column_parallel_layer(self): + """Tests column parallel LayerNormLinear""" + set_random_seed(1024) + layer_te = te.LayerNormLinear( + self.in_features, + self.out_features, + eps=self.eps, + parallel_mode='column', + ) + layer_pd = te.LayerNormLinear( + self.in_features, + self.out_features, + eps=self.eps, + backend='paddle', + ) + # Get total weight + total_weight = [] + partial_weight = layer_te.weight.clone().detach() + paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group) + total_weight = paddle.concat(total_weight, axis=0) + layer_pd.weight.copy_(total_weight.T, True) + + assert_shape(layer_te.weight, + [self.out_features // self.model_parallel_size, self.in_features]) + assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size]) + + optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) + optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters()) + + layer_te = fleet.distributed_model(layer_te) + optimizer_te = fleet.distributed_optimizer(optimizer_te) + + def train_one_step(layer, inp, optimizer, gather=False): + inp = paddle.to_tensor(inp) + inp.stop_gradient = False + out = layer(inp) + if gather: + total_out = mp_ops._c_concat(out, group=self.tp_group) + else: + total_out = out + loss = total_out.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss, inp.grad + + for _ in range(5): + inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype) + with te.fp8_autocast(enabled=self.fp8): + loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te, gather=True) + loss_ref, grad_input_ref = train_one_step(layer_pd, inp, optimizer_pd) + assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) + assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) + + +class TestLayerNormLinearTpFp8(TestLayerNormLinearTp): + """Tests LayernormLinear layer with column/row parallelism in FP8""" + + def set_attr(self): + """Set test configs""" + self.batch_size = 16 + self.in_features = 32 + self.out_features = 64 + self.global_dtype = 'bfloat16' + self.rtol = 1e-2 + self.atol = 1e-2 + self.eps = 1e-3 + self.fp8 = True + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/paddle/parallel_tests/layernorm_mlp_tp.py 
b/tests/paddle/parallel_tests/layernorm_mlp_tp.py new file mode 100644 index 0000000000..f579f5f371 --- /dev/null +++ b/tests/paddle/parallel_tests/layernorm_mlp_tp.py @@ -0,0 +1,125 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Unittest for LayerNormMLP layer in tensor parallel""" + +import unittest + +import paddle +from paddle.distributed import fleet + +from utils import assert_allclose, assert_shape, set_random_seed +import transformer_engine.paddle as te + + +class TestLayerNormMLPTp(unittest.TestCase): + """Tests LayerNormMLP layer with model parallel in BF16""" + + def setUp(self): + self.set_attr() + self.init_dist_env() + paddle.set_default_dtype(self.global_dtype) + + def init_dist_env(self): + """Init Paddle Fleet environment""" + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": self.model_parallel_size, + "pp_degree": 1, + } + fleet.init(is_collective=True, strategy=strategy) + self.hcg = fleet.get_hybrid_communicate_group() + self.tp_group = self.hcg.get_model_parallel_group() + + def set_attr(self): + """Set test configs""" + self.batch_size = 16 + self.hidden_size = 32 + self.ffn_hidden_size = 64 + self.global_dtype = 'bfloat16' + self.rtol = 1e-3 + self.atol = 1e-3 + self.eps = 1e-3 + self.fp8 = False + + def test_parallel_layer(self): + """Tests parallel LayerNormMLP""" + set_random_seed(1024) + layer_te = te.LayerNormMLP( + hidden_size=self.hidden_size, + ffn_hidden_size=self.ffn_hidden_size, + eps=self.eps, + set_parallel_mode=True, + ) + layer_pd = te.LayerNormMLP( + hidden_size=self.hidden_size, + ffn_hidden_size=self.ffn_hidden_size, + eps=self.eps, + set_parallel_mode=False, + backend='paddle', + ) + + def _get_total_weight(local_weight, tp_group, axis): + total_weight = [] + partial_weight = local_weight.clone().detach() + paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group) + total_weight = paddle.concat(total_weight, axis=axis) + return total_weight + + # Get total weight + total_fc1_weight = _get_total_weight(layer_te.fc1_weight, tp_group=self.tp_group, axis=0) + total_fc2_weight = _get_total_weight(layer_te.fc2_weight, tp_group=self.tp_group, axis=1) + layer_pd.fc1_weight.copy_(total_fc1_weight.T, True) + layer_pd.fc2_weight.copy_(total_fc2_weight.T, True) + + assert_shape(layer_te.fc1_weight, + [self.ffn_hidden_size // self.model_parallel_size, self.hidden_size]) + assert_shape(layer_te.fc1_bias, [self.ffn_hidden_size // self.model_parallel_size]) + assert_shape(layer_te.fc2_weight, + [self.hidden_size, self.ffn_hidden_size // self.model_parallel_size]) + assert_shape(layer_te.fc2_bias, [self.hidden_size]) + + optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) + optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters()) + + layer_te = fleet.distributed_model(layer_te) + optimizer_te = fleet.distributed_optimizer(optimizer_te) + + def train_one_step(layer, inp, optimizer): + inp = paddle.to_tensor(inp) + inp.stop_gradient = False + out = layer(inp) + loss = out.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss, inp.grad + + for _ in range(5): + inp = paddle.uniform([self.batch_size, self.hidden_size], self.global_dtype) + with te.fp8_autocast(enabled=self.fp8): + loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te) + loss_ref, grad_input_ref = 
train_one_step(layer_pd, inp, optimizer_pd) + assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) + assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) + + +class TestLayerNormMLPTpFp8(TestLayerNormMLPTp): + """Tests LayerNormMLP layer with tensor parallelism in FP8""" + + def set_attr(self): + """Set test configs""" + self.batch_size = 16 + self.hidden_size = 32 + self.ffn_hidden_size = 64 + self.global_dtype = 'bfloat16' + self.rtol = 1e-2 + self.atol = 1e-2 + self.eps = 1e-3 + self.fp8 = True + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/paddle/parallel_tests/linear_pp.py b/tests/paddle/parallel_tests/linear_pp.py new file mode 100644 index 0000000000..994e15ba7d --- /dev/null +++ b/tests/paddle/parallel_tests/linear_pp.py @@ -0,0 +1,192 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Unittest for Linear layer in pipeline parallel""" + +import unittest + +import numpy as np + +import paddle +from paddle.distributed import fleet + +from paddle.distributed.fleet.meta_parallel import ( + LayerDesc, + PipelineLayer, +) + +from utils import assert_allclose, set_random_seed +import transformer_engine.paddle as te + + +class TEPipelineModel(PipelineLayer): + """Model for pipeline parallel test""" + + def __init__(self, + in_features, + hidden_features, + weight_attrs, + use_te=True, + use_fp8=False, + **kwargs): + self.in_features = in_features + self.hidden_features = hidden_features + self.fp8 = use_fp8 + hcg = fleet.get_hybrid_communicate_group() + self.dp_group = hcg.get_data_parallel_group() + + Linear = te.Linear if use_te else paddle.nn.Linear + model_desc = [ + LayerDesc(Linear, self.in_features, self.hidden_features, weight_attr=weight_attrs[0]), + LayerDesc(Linear, self.hidden_features, self.in_features, weight_attr=weight_attrs[1]), + ] + super().__init__(layers=model_desc, loss_fn=paddle.nn.CrossEntropyLoss(), **kwargs) + + def forward(self, *args, **kwargs): + with te.fp8_autocast(enabled=self.fp8, fp8_group=self.dp_group): + return super().forward(*args, **kwargs) + + +class StandaloneModel(paddle.nn.Layer): + """Model for pipeline parallel test""" + + def __init__(self, in_features, hidden_features, weight_attrs): + super().__init__() + self.in_features = in_features + self.hidden_features = hidden_features + Linear = paddle.nn.Linear + self.layer = paddle.nn.Sequential( + Linear(self.in_features, self.hidden_features, weight_attr=weight_attrs[0]), + Linear(self.hidden_features, self.in_features, weight_attr=weight_attrs[1]), + ) + self.loss = paddle.nn.CrossEntropyLoss() + + def forward(self, inp): + out = self.layer(inp[0]) + loss = self.loss(out, inp[1]) + return loss + + +class TestLinearPipelineParallel(unittest.TestCase): + """Tests Linear layer with pipeline parallel""" + + def setUp(self): + self.set_attr() + self.init_dist_env() + paddle.set_default_dtype(self.global_dtype) + + def init_dist_env(self): + """Init Paddle Fleet environment""" + strategy = fleet.DistributedStrategy() + self.pipeline_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": 1, + "pp_degree": self.pipeline_parallel_size, + } + strategy.pipeline_configs = { + "accumulate_steps": self.batch_size // self.micro_batch_size, + "micro_batch_size": self.micro_batch_size, + } + fleet.init(is_collective=True, strategy=strategy) + self.rank = fleet.worker_index() + self.hcg = fleet.get_hybrid_communicate_group() + + def set_attr(self): + 
"""Set test configs""" + self.batch_size = 32 + self.micro_batch_size = 16 + self.in_features = 32 + self.hidden_features = 64 + self.global_dtype = 'float32' + self.rtol = 1e-5 + self.atol = 1e-5 + self.iter = 10 + self.fp8 = False + + def test_pipeline_train(self): + """Test pipeline parallel training""" + set_random_seed(1024) + + weight1_np = np.random.normal(size=[self.in_features, self.hidden_features]) + weight2_np = np.random.normal(size=[self.hidden_features, self.in_features]) + weight_attrs = [ + paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight1_np)), + paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight2_np)), + ] + weight_attrs_transposed = [ + paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight1_np.T)), + paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight2_np.T)), + ] + + pipe_model = TEPipelineModel( + self.in_features, + self.hidden_features, + weight_attrs_transposed, + use_te=True, + use_fp8=self.fp8, + seg_method="layer:Linear", + num_stages=self.pipeline_parallel_size, + ) + + # Check if model is split across ranks as expected + for name, sublayer in pipe_model.named_sublayers(): + if name in ('_loss_fn', 'shared_layers'): + continue + if self.rank == 0: + assert tuple(sublayer.weight.shape) == weight1_np.T.shape, \ + f"Shape does not match, expect: {weight1_np.T.shape}, " \ + f"actual: {tuple(sublayer.weight.shape)}" + elif self.rank == 1: + assert tuple(sublayer.weight.shape) == weight2_np.T.shape, \ + f"Shape does not match, expect: {weight2_np.T.shape}, " \ + f"actual: {tuple(sublayer.weight.shape)}" + + standalone_model = StandaloneModel( + self.in_features, + self.hidden_features, + weight_attrs, + ) + + optimizer_te = paddle.optimizer.SGD(learning_rate=0.1, parameters=pipe_model.parameters()) + optimizer_pd = paddle.optimizer.SGD(learning_rate=0.1, + parameters=standalone_model.parameters()) + + pipe_model = fleet.distributed_model(pipe_model) + optimizer_te = fleet.distributed_optimizer(optimizer_te) + + def train_one_step(layer, inp, optimizer): + loss = layer(inp) + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss + + for i in range(self.iter): + inp = paddle.to_tensor(np.random.normal(size=[self.batch_size, self.in_features]), + dtype=self.global_dtype) + label = paddle.to_tensor(np.random.randint(self.in_features, size=[self.batch_size, 1])) + loss_te = pipe_model.train_batch([inp, label], optimizer_te) + loss_pd = train_one_step(standalone_model, [inp, label], optimizer_pd) + print(f"Iter: {i}, loss_te: {loss_te.item()}, loss_pd: {loss_pd.item()}") + assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) + + +class TestLinearPipelineParallelFP8(TestLinearPipelineParallel): + """Tests Linear layer with column/row parallelism in FP8""" + + def set_attr(self): + """Set test configs""" + self.batch_size = 32 + self.micro_batch_size = 16 + self.in_features = 32 + self.hidden_features = 64 + self.global_dtype = 'float32' + self.rtol = 5e-2 + self.atol = 5e-2 + self.iter = 10 + self.fp8 = True + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/paddle/parallel_tests/linear_tp.py b/tests/paddle/parallel_tests/linear_tp.py new file mode 100644 index 0000000000..fe0aeddccd --- /dev/null +++ b/tests/paddle/parallel_tests/linear_tp.py @@ -0,0 +1,180 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""Unittest for Linear layer in tensor parallel""" + +import unittest + +import paddle +from paddle.distributed import fleet +from paddle.distributed.fleet.layers.mpu import mp_ops + +from utils import assert_allclose, assert_shape, set_random_seed +import transformer_engine.paddle as te + + +class TestLinearTp(unittest.TestCase): + """Tests Linear layer with column/row parallelism in BF16""" + + def setUp(self): + self.set_attr() + self.init_dist_env() + paddle.set_default_dtype(self.global_dtype) + + def init_dist_env(self): + """Init Paddle Fleet environment""" + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": self.model_parallel_size, + "pp_degree": 1, + } + fleet.init(is_collective=True, strategy=strategy) + self.rank = fleet.worker_index() + self.hcg = fleet.get_hybrid_communicate_group() + self.tp_group = self.hcg.get_model_parallel_group() + self.world_size = self.hcg.get_model_parallel_world_size() + + def set_attr(self): + """Set test configs""" + self.batch_size = 16 + self.in_features = 32 + self.out_features = 64 + self.global_dtype = 'bfloat16' + self.rtol = 1e-3 + self.atol = 1e-3 + self.fp8 = False + + def test_column_parallel_layer(self): + """Tests column parallel linear""" + set_random_seed(1024) + layer_te = te.Linear( + self.in_features, + self.out_features, + parallel_mode='column', + ) + layer_pd = te.Linear( + self.in_features, + self.out_features, + backend='paddle', + ) + # Get total weight + total_weight = [] + partial_weight = layer_te.weight.clone().detach() + paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group) + total_weight = paddle.concat(total_weight, axis=0) + layer_pd.weight.copy_(total_weight.T, True) + + assert_shape(layer_te.weight, + [self.out_features // self.model_parallel_size, self.in_features]) + assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size]) + + optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) + optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters()) + + layer_te = fleet.distributed_model(layer_te) + optimizer_te = fleet.distributed_optimizer(optimizer_te) + + def train_one_step(layer, inp, optimizer, gather=False): + inp = paddle.to_tensor(inp) + inp.stop_gradient = False + out = layer(inp) + if gather: + total_out = mp_ops._c_concat(out, group=self.tp_group) + else: + total_out = out + loss = total_out.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss, inp.grad + + for _ in range(5): + inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype) + with te.fp8_autocast(enabled=self.fp8): + loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te, gather=True) + loss_ref, grad_input_ref = train_one_step(layer_pd, inp, optimizer_pd) + assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) + assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) + + def test_row_parallel_layer(self): + """Tests row parallel linear""" + set_random_seed(1024) + layer_te = te.Linear( + self.in_features, + self.out_features, + parallel_mode='row', + ) + layer_pd = te.Linear( + self.in_features, + self.out_features, + backend='paddle', + ) + # Get total weight + total_weight = [] + partial_weight = layer_te.weight.clone().detach() + paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group) + total_weight = paddle.concat(total_weight, 
axis=1) + layer_pd.weight.copy_(total_weight.T, True) + + assert_shape(layer_te.weight, + [self.out_features, self.in_features // self.model_parallel_size]) + assert_shape(layer_te.bias, [self.out_features]) + + optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) + optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters()) + + # Note(tizheng): For this test, we cannot wrap model with fleet.distributed_model, + # because it will broadcast inputs across mp group. However, RPL expects splitted + # inputs, which is different on each rank. + + def train_one_step(layer, inp, optimizer, split=False): + inp = paddle.to_tensor(inp, stop_gradient=True) + if split: + # TODO(tizheng): Why not working? + # issue: https://github.com/PaddlePaddle/Paddle/issues/55565 + # input_parallel = mp_ops._c_split(inp, group=layer.tp_group) + split_size = inp.shape[1] // self.world_size + input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)] + else: + input_parallel = inp + input_parallel.stop_gradient = False + out = layer(input_parallel) + loss = out.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + if split: + grad_input = [] + paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group) + grad_input = paddle.concat(grad_input, axis=1) + else: + grad_input = input_parallel.grad + return loss, grad_input + + for _ in range(5): + inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype) + with te.fp8_autocast(enabled=self.fp8): + loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te, split=True) + loss_ref, grad_input_ref = train_one_step(layer_pd, inp, optimizer_pd) + assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) + assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) + + +class TestLinearTpFP8(TestLinearTp): + """Tests Linear layer with column/row parallelism in FP8""" + + def set_attr(self): + """Set test configs""" + self.batch_size = 16 + self.in_features = 32 + self.out_features = 64 + self.global_dtype = 'bfloat16' + self.rtol = 1e-2 + self.atol = 1e-2 + self.fp8 = True + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/paddle/parallel_tests/transformer_tp.py b/tests/paddle/parallel_tests/transformer_tp.py new file mode 100644 index 0000000000..69fef08d56 --- /dev/null +++ b/tests/paddle/parallel_tests/transformer_tp.py @@ -0,0 +1,151 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
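# Note: in the tensor-parallel TransformerLayer, the fused QKV projection and the MLP's
# fc1 are column-parallel (weights split along the output dimension), while the attention
# output projection and fc2 are row-parallel (split along the input dimension). The
# copy_weight() calls in the test below therefore gather shards along axis 0 for 'column'
# weights and axis 1 for 'row' weights before copying them into the single-GPU reference.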
+"""Unittest for Transformer layer in tensor parallel""" + +import unittest + +import paddle +from paddle.distributed import fleet + +from utils import assert_allclose, set_random_seed +import transformer_engine.paddle as te + + +class TestTransformerTp(unittest.TestCase): + """Tests Transformer layer with model parallel in BF16""" + + def setUp(self): + self.set_attr() + self.init_dist_env() + paddle.set_default_dtype(self.global_dtype) + + def init_dist_env(self): + """Init Paddle Fleet environment""" + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": self.model_parallel_size, + "pp_degree": 1, + } + fleet.init(is_collective=True, strategy=strategy) + self.hcg = fleet.get_hybrid_communicate_group() + self.tp_group = self.hcg.get_model_parallel_group() + + def set_attr(self): + """Set test configs""" + self.batch_size = 16 + self.hidden_size = 1024 + self.num_heads = 16 + self.ffn_hidden_size = 4096 + self.q_seqlen = 128 + self.kv_seqlen = 128 + self.mask_type = 'padding' + self.layer_type = 'encoder' + self.global_dtype = 'bfloat16' + self.rtol = 5e-2 + self.atol = 5e-2 + self.eps = 1e-3 + self.fp8 = False + + def test_parallel_layer(self): + """Tests parallel Transformer""" + set_random_seed(1024) + common_args = [ + self.hidden_size, + self.ffn_hidden_size, + self.num_heads, + ] + common_kwargs = { + 'layernorm_epsilon': self.eps, + 'hidden_dropout': 0.0, + 'attention_dropout': 0.0, + 'self_attn_mask_type': self.mask_type, + 'layer_type': self.layer_type, + } + layer_tp = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=True) + layer_single = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=False) + + def _get_total_weight(local_weight, tp_group, axis): + total_weight = [] + partial_weight = local_weight.clone().detach() + paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group) + total_weight = paddle.concat(total_weight, axis=axis) + return total_weight + + def _get_weight(obj, weight_names): + for name in weight_names: + obj = getattr(obj, name) + return obj + + def copy_weight(layer_src, layer_dst, partition_mode, weight_names): + weight_src = _get_weight(layer_src, weight_names) + weight_dst = _get_weight(layer_dst, weight_names) + if partition_mode is None: + total_weight = weight_src + elif partition_mode == 'column': + total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=0) + elif partition_mode == 'row': + total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1) + else: + raise ValueError(f"Partition Mode {partition_mode} is not supported.") + assert weight_dst.shape == total_weight.shape, \ + f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match." 
+ weight_dst.copy_(total_weight, True) + + copy_weight(layer_tp, layer_single, None, ['self_attention', 'layernorm_qkv', 'ln_weight']) + copy_weight(layer_tp, layer_single, 'column', ['self_attention', 'layernorm_qkv', 'weight']) + copy_weight(layer_tp, layer_single, 'row', ['self_attention', 'proj', 'weight']) + copy_weight(layer_tp, layer_single, None, ['layernorm_mlp', 'ln_weight']) + copy_weight(layer_tp, layer_single, 'column', ['layernorm_mlp', 'fc1_weight']) + copy_weight(layer_tp, layer_single, 'row', ['layernorm_mlp', 'fc2_weight']) + + optimizer_tp = paddle.optimizer.SGD(learning_rate=0.1, parameters=layer_tp.parameters()) + optimizer_single = paddle.optimizer.SGD(learning_rate=0.1, + parameters=layer_single.parameters()) + + layer_tp = fleet.distributed_model(layer_tp) + optimizer_tp = fleet.distributed_optimizer(optimizer_tp) + + def train_one_step(layer, inp_list, optimizer, fp8_enabled): + with te.fp8_autocast(enabled=fp8_enabled): + out = layer(*inp_list) + loss = out.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss + + for _ in range(5): + inp = paddle.uniform([self.batch_size, self.q_seqlen, self.hidden_size], + self.global_dtype) + mask = paddle.zeros(shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), + dtype='bool') + loss_tp = train_one_step(layer_tp, [inp, mask], optimizer_tp, self.fp8) + loss_single = train_one_step(layer_single, [inp, mask], optimizer_single, self.fp8) + assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol) + + +class TestTransformerTpFp8(TestTransformerTp): + """Tests Transformer layer with tensor parallelism in FP8""" + + def set_attr(self): + """Set test configs""" + self.batch_size = 16 + self.hidden_size = 1024 + self.num_heads = 16 + self.ffn_hidden_size = 4096 + self.q_seqlen = 128 + self.kv_seqlen = 128 + self.mask_type = 'padding' + self.layer_type = 'encoder' + self.global_dtype = 'bfloat16' + self.rtol = 5e-2 + self.atol = 5e-2 + self.eps = 1e-3 + self.fp8 = True + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/paddle/test_layers.py b/tests/paddle/test_layers.py index 171b9233e7..bb93458230 100644 --- a/tests/paddle/test_layers.py +++ b/tests/paddle/test_layers.py @@ -98,8 +98,8 @@ def test_linear_bf16(bs, in_features, out_features, has_bias, no_dbias, no_dgrad """ Test BF16 Linear """ - rtol = 1e-2 - atol = 1e-2 + rtol = 5e-2 + atol = 5e-2 input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) input_tensor.stop_gradient = no_dgrad @@ -258,8 +258,8 @@ def test_layernorm_linear_bf16(bs, in_features, out_features, has_bias, no_dbias Test BF16 LayerNormLinear Layer """ paddle.set_default_dtype(activation_dtype) - rtol = 1e-2 - atol = 1e-2 + rtol = 5e-2 + atol = 5e-2 input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) input_tensor.stop_gradient = no_dgrad @@ -905,7 +905,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, """ paddle.set_default_dtype(math_dtype) rtol = 5e-2 - atol = 5e-2 + atol = 6e-2 eps = 1e-3 encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype) diff --git a/tests/paddle/test_operators.py b/tests/paddle/test_operators.py index 662978086a..241f96214b 100644 --- a/tests/paddle/test_operators.py +++ b/tests/paddle/test_operators.py @@ -728,8 +728,8 @@ def _get_fused_attention_out(self): return out, q_grad, k_grad, v_grad - @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0), - reason="cuDNN fMHA requires Ampere+ GPU") + 
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() not in ((8, 0), (9, 0)), + reason="cuDNN fMHA requires Ampere and Hopper GPU") @pytest.mark.parametrize('b, s, h, d', SELF_ATTN_CASES) @pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) @pytest.mark.parametrize('is_causal_masking', [True, False]) @@ -745,8 +745,8 @@ def test_self_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking): assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) - @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0), - reason="cuDNN fMHA requires Ampere+ GPU") + @pytest.mark.skipif(paddle.device.cuda.get_device_capability() not in ((8, 0), (9, 0)), + reason="cuDNN fMHA requires Ampere and Hopper GPU") @pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_ATTN_CASES) @pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype): diff --git a/tests/paddle/test_parallel.py b/tests/paddle/test_parallel.py new file mode 100644 index 0000000000..d6e02747d1 --- /dev/null +++ b/tests/paddle/test_parallel.py @@ -0,0 +1,89 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Test TE Paddle Parallel""" + +from pathlib import Path +import unittest + +from dist_launcher import TestDistributed +from utils import is_devices_enough + +from transformer_engine.paddle.fp8 import is_fp8_available + +test_root = Path(__file__).resolve().parent +gpu_has_fp8, reason = is_fp8_available() + + +class TestParallelLinear(TestDistributed): + """Test Linear in Parallel mode""" + + @unittest.skipIf(not is_devices_enough(2), "TestParallelLinear needs 2 GPUs") + @unittest.skipIf(not gpu_has_fp8, reason) + def test_linear_tp(self): + """Tests linear with tensor parallel in BF16""" + self.run_2gpu(str(test_root / 'parallel_tests' / 'linear_tp.py')) + + +class TestParallelLayerNormLinear(TestDistributed): + """Test LayerNormLinear in Parallel mode""" + + @unittest.skipIf(not is_devices_enough(2), "TestParallelLayerNormLinear needs 2 GPUs") + @unittest.skipIf(not gpu_has_fp8, reason) + def test_layernorm_linear_tp(self): + """Tests layernorm_linear with tensor parallel in BF16""" + self.run_2gpu(str(test_root / 'parallel_tests' / 'layernorm_linear_tp.py')) + + +class TestParallelLayerNormMLP(TestDistributed): + """Test LayerNormMLP in Parallel mode""" + + @unittest.skipIf(not is_devices_enough(2), "TestParallelLayerNormMLP needs 2 GPUs") + @unittest.skipIf(not gpu_has_fp8, reason) + def test_layernorm_mlp_tp(self): + """Tests layernorm_mlp with tensor parallel in BF16""" + self.run_2gpu(str(test_root / 'parallel_tests' / 'layernorm_mlp_tp.py')) + + +class TestAmaxReduction(TestDistributed): + """Test amax reduction in dp mode""" + + @unittest.skipIf(not is_devices_enough(2), "TestAmaxReduction needs 2 GPUs") + @unittest.skipIf(not gpu_has_fp8, reason) + def test_amax_reduction(self): + """Tests amax reduction""" + self.run_2gpu(str(test_root / 'parallel_tests' / 'amax_reduction.py')) + + +class TestPipelineParallel(TestDistributed): + """Test pipeline parallel""" + + @unittest.skipIf(not is_devices_enough(2), "TestPipelineParallel needs 2 GPUs") + @unittest.skipIf(not gpu_has_fp8, reason) + def test_pipeline_parallel(self): + """Tests pipeline parallel""" + self.run_2gpu(str(test_root / 'parallel_tests' / 'linear_pp.py')) + + +class TestGroupSharding(TestDistributed): + """Test group sharding""" + 
+ @unittest.skipIf(not is_devices_enough(2), "TestGroupSharding needs 2 GPUs") + @unittest.skipIf(not gpu_has_fp8, reason) + def test_group_sharding(self): + """Tests group sharding""" + self.run_2gpu(str(test_root / 'parallel_tests' / 'group_sharding.py')) + + +class TestParallelTransformerLayer(TestDistributed): + """Test Transformer Layer in Parallel mode""" + + @unittest.skipIf(not is_devices_enough(2), "TestParallelTransformerLayer needs 2 GPUs") + @unittest.skipIf(not gpu_has_fp8, reason) + def test_transformer_tp(self): + """Tests Transformer Layer with tensor parallel in BF16""" + self.run_2gpu(str(test_root / 'parallel_tests' / 'transformer_tp.py')) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/paddle/utils.py b/tests/paddle/utils.py index 432b39c2e0..5960cccd3d 100644 --- a/tests/paddle/utils.py +++ b/tests/paddle/utils.py @@ -34,3 +34,21 @@ def assert_allclose(actual, if isinstance(desired, paddle.Tensor): desired = paddle.cast(desired, 'float32').numpy() np.testing.assert_allclose(actual, desired, rtol, atol, equal_nan, err_msg, verbose) + + +def assert_shape(inp, expected_shape): + """Assert the shape of input tensor equals to expected shape""" + assert inp.shape == expected_shape, f"Expected tensor shape: {expected_shape} != " \ + f"actual tensor shape: {inp.shape}" + + +def is_devices_enough(required): + """If the number of device is enough""" + return paddle.device.cuda.device_count() >= required + + +def set_random_seed(seed): + """Set random seed for reproducability.""" + np.random.seed(seed) + paddle.seed(seed) + paddle.distributed.fleet.meta_parallel.model_parallel_random_seed(seed) diff --git a/transformer_engine/paddle/constants.py b/transformer_engine/paddle/constants.py index eac161ec60..cfecd39564 100644 --- a/transformer_engine/paddle/constants.py +++ b/transformer_engine/paddle/constants.py @@ -46,3 +46,7 @@ class FP8BwdTensors(Enum): AttnTypes = ("self", "cross") LayerTypes = ("encoder", "decoder") + +GemmParallelModes = ("row", "column", None) + +dist_group_type = paddle.distributed.collective.Group diff --git a/transformer_engine/paddle/distributed.py b/transformer_engine/paddle/distributed.py new file mode 100644 index 0000000000..5bf51c9274 --- /dev/null +++ b/transformer_engine/paddle/distributed.py @@ -0,0 +1,100 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
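# Note: identity() and allreduce() below implement the usual Megatron-style communication
# pattern for tensor parallelism. A rough sketch of how a caller would typically use them
# (hypothetical layer code, not part of this module):
#
#     # column-parallel GEMM: identity() is a no-op forward, input grads all-reduced in backward
#     y_local = paddle.matmul(identity(x, tp_group), w_col, transpose_y=True)
#
#     # row-parallel GEMM: partial outputs are summed across the tp group
#     y = allreduce(paddle.matmul(x_local, w_row, transpose_y=True), tp_group)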
+"""Methods needed for distributed training.""" + +from contextlib import contextmanager +from typing import Optional, Union, Tuple + +import paddle + +import paddle.distributed.fleet.base.topology as tp +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.distributed.fleet.layers.mpu import mp_ops + +from .constants import dist_group_type + +_weight_split_axis = { + 'transformer_engine': { + 'row': 1, + 'column': 0 + }, + 'paddle': { + 'row': 0, + 'column': 1 + } +} + + +def get_tp_group_and_world_size(tp_group: Union[dist_group_type, None], + enable_tp: bool = True) -> Tuple[Union[dist_group_type, None], int]: + """Get TP group and world size using Fleet API""" + if not (paddle.distributed.is_initialized() and enable_tp): + return None, 1 + model_parallel_group = (tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group() + if tp_group is None else tp_group) + world_size = (tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size() + if tp_group is None else tp_group.nranks) + return model_parallel_group, world_size + + +@contextmanager +def track_rng_state(enable: bool) -> None: + """ + Applies get_rng_state_tracker().rng_state() to the context. + If not enabled, it does nothing. + """ + if enable: + with get_rng_state_tracker().rng_state(): + yield + else: + yield + + +def set_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool, axis: int) -> None: + """Set distributed attributes for the input tensor""" + tensor.is_distributed = is_parallel + if is_parallel: + tensor.split_axis = axis + + +def set_weight_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool, + parallel_mode: Optional[str], backend: str) -> None: + """Set distributed attributes for the weight tensor""" + if not is_parallel or parallel_mode is None: + return + set_tensor_dist_attr(tensor, is_parallel, axis=_weight_split_axis[backend][parallel_mode]) + + +def allreduce( + input_: paddle.Tensor, + tp_group: Optional[dist_group_type] = None, +) -> paddle.Tensor: + """All-reduce the input tensor across model parallel group.""" + + # Bypass the function if we are using only 1 GPU. + if tp_group is None or tp_group.nranks == 1: + return input_ + + # All-reduce. + output = mp_ops._mp_allreduce( + input_, + group=tp_group, + use_calc_stream=True, + use_model_parallel=True, + ) + + return output + + +def identity( + input_: paddle.Tensor, + tp_group: Optional[dist_group_type] = None, +) -> paddle.Tensor: + """ + Identity when forward. + Allreduce across model parallel group when backward. + """ + output = mp_ops._c_identity(input_, group=tp_group) + + return output diff --git a/transformer_engine/paddle/fp8.py b/transformer_engine/paddle/fp8.py index bcd7ae2b22..576b8d859c 100644 --- a/transformer_engine/paddle/fp8.py +++ b/transformer_engine/paddle/fp8.py @@ -3,9 +3,8 @@ # See LICENSE for license information. 
"""FP8 utilities for TransformerEngine""" -import copy from contextlib import contextmanager -from typing import Tuple, Optional, Dict, Any +from typing import Tuple, Optional, Dict, Any, Union import numpy as np @@ -13,6 +12,9 @@ import transformer_engine_paddle as tex from transformer_engine.common.recipe import DelayedScaling, Format +from .constants import dist_group_type +from .fp8_buffer import FP8MetaFwdBuffer, FP8MetaBwdBuffer + # FP8 support _is_fp8_available = None _reason_for_no_fp8 = "" @@ -50,21 +52,27 @@ class FP8State: """Stores FP8 state""" def __init__(self): - self.fp8_enabled = False - self.fp8_calibration = False - self.fp8_recipe = None + self._fp8_enabled = False + self._fp8_calibration = False + self._fp8_recipe = None + self._fp8_distributed_group = None + self._is_first_fp8_module = False + self._fp8_autocast_counter = 0 + self._fp8_autocast_depth = 0 + self._fp8_fwd_buffer = FP8MetaFwdBuffer() + self._fp8_bwd_buffer = FP8MetaBwdBuffer() def is_fp8_enabled(self) -> bool: """Is FP8 enabled""" - return self.fp8_enabled + return self._fp8_enabled def is_fp8_calibration(self) -> bool: """Is FP8 calibration""" - return self.fp8_calibration + return self._fp8_calibration def get_fp8_recipe(self) -> DelayedScaling: """Return the fp8 recipe""" - return self.fp8_recipe + return self._fp8_recipe @staticmethod def get_default_fp8_recipe() -> DelayedScaling: @@ -73,6 +81,63 @@ def get_default_fp8_recipe() -> DelayedScaling: """ return DelayedScaling() + def get_autocast_id(self) -> int: + """Returns the number of times of entering the `fp8_autocast` context. + as a unique ID for different training steps.""" + return self._fp8_autocast_counter + + def is_first_fp8_module(self): + """Returns `True` only the first time when called multiple + times from within the same `fp8_autocast` context. 
+ """ + tmp = self._is_first_fp8_module + self._is_first_fp8_module = False + return tmp + + def get_fp8_group(self) -> Union[dist_group_type, None]: + """Return the fp8 group for scale/amax comm""" + return self._fp8_distributed_group + + def get_fp8_fwd_buffer(self) -> FP8MetaFwdBuffer: + """Returns global fp8 forward buffer.""" + return self._fp8_fwd_buffer + + def get_fp8_bwd_buffer(self) -> FP8MetaBwdBuffer: + """Returns global fp8 backward buffer.""" + return self._fp8_bwd_buffer + + def enter( + self, + enabled: bool, + calibrating: bool, + fp8_recipe: Optional[DelayedScaling], + fp8_group: Optional[dist_group_type], + ) -> None: + """Called when entering 'fp8_autocast'""" + self.saved_states = (self._fp8_enabled, self._fp8_calibration, self._fp8_recipe, + self._fp8_distributed_group, self._is_first_fp8_module) + + self._fp8_enabled = enabled + self._fp8_calibration = calibrating + self._fp8_recipe = self.get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe + self._fp8_distributed_group = fp8_group + + if self._fp8_autocast_depth == 0: + self._is_first_fp8_module = True + self._fp8_autocast_counter += 1 + self._fp8_autocast_depth += 1 + + def exit(self): + """Called when exiting 'fp8_autocast'""" + # Restore saved states + (self._fp8_enabled, self._fp8_calibration, self._fp8_recipe, self._fp8_distributed_group, + self._is_first_fp8_module) = self.saved_states + + self._fp8_autocast_depth -= 1 + + if self._fp8_autocast_depth == 0: + self._fp8_fwd_buffer.finalize() + _global_fp8_state = FP8State() @@ -87,25 +152,20 @@ def fp8_autocast( enabled: bool = False, calibrating: bool = False, fp8_recipe: Optional[DelayedScaling] = None, + fp8_group: Optional[dist_group_type] = None, ) -> None: """ Context manager for FP8 usage. """ - - global _global_fp8_state - saved_fp8_state = copy.deepcopy(_global_fp8_state) try: - _global_fp8_state.fp8_enabled = enabled - _global_fp8_state.fp8_calibration = calibrating - _global_fp8_state.fp8_recipe = FP8State.get_default_fp8_recipe( - ) if fp8_recipe is None else fp8_recipe + _global_fp8_state.enter(enabled, calibrating, fp8_recipe, fp8_group) if enabled: fp8_available, reason_for_no_fp8 = is_fp8_available() assert fp8_available, reason_for_no_fp8 yield finally: - _global_fp8_state = saved_fp8_state + _global_fp8_state.exit() def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType: diff --git a/transformer_engine/paddle/fp8_buffer.py b/transformer_engine/paddle/fp8_buffer.py new file mode 100644 index 0000000000..76b0c9db59 --- /dev/null +++ b/transformer_engine/paddle/fp8_buffer.py @@ -0,0 +1,257 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""FP8 meta buffer for FP8 amax reduction""" + +from abc import ABC, abstractmethod +from functools import partial +import os +from typing import Dict, Any, List, Union + +import numpy as np +import paddle + +from .constants import dist_group_type + + +class FP8MetaBufferBase(ABC): + """ + A global buffer that holds FP8 meta for reduction across trainers. 
+ """ + + def __init__(self): + self._data = {} + self._buffer_delete_key = None + self._amax_reduce_wait_func = None + self._dp_amax_reduce_interval = None + self._dp_amax_reduce_idx = 0 + + @staticmethod + @abstractmethod + def _get_meta_tensor_key(): + """Returns scaling key in `fp8_meta`.""" + + @staticmethod + @abstractmethod + def _get_buffer_position_key(): + """Returns module position key in `fp8_meta`.""" + + @staticmethod + @abstractmethod + def _get_autocast_key(): + """Returns autocast id key in `fp8_meta`.""" + + def _get_amax_buffer_key(self, fp8_meta: Dict[str, Any]) -> str: + """Return a key in `_data` for the AMAX storage.""" + return f"AMAX_{fp8_meta[self._get_autocast_key()]}" + + def _execute_deletion(self) -> None: + """Delete the key from global amax buffer.""" + if (self._buffer_delete_key is not None and self._buffer_delete_key in self._data): + del self._data[self._buffer_delete_key] + + def _wait_handle_and_split( + self, + contiguous_amax: paddle.Tensor, + chunk_sizes: List[int], + amax_buffer_key: str, + wait_handle: Union[bool, None], + ) -> None: + """Wait for amax reduction to finish and then copy reduced amax to buffer""" + if wait_handle is not None: + wait_handle.wait() + self._data[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes)) + + def _global_amax_reduction( + self, + fp8_meta: Dict[str, Any], + tp_group: dist_group_type, + tp_size: int, + ) -> None: + """Concatenate, reduce, and split amaxes in the global buffer.""" + + def _reduce_tensor_across_group_op_max(tensor, group, sync_op): + if paddle.distributed.is_initialized(): + wait_handle = paddle.distributed.all_reduce( + tensor, + op=paddle.distributed.ReduceOp.MAX, + group=group, + sync_op=sync_op, + ) + return wait_handle + return None + + amax_buffer_key = self._get_amax_buffer_key(fp8_meta) + # Key already deleted. + if amax_buffer_key not in self._data: + return None + + # Reduce AMAX in DP-domain at an interval. + if self._dp_amax_reduce_interval is None: + self._dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1")) + + tp_amax_reduce = False + if self._dp_amax_reduce_idx == 0: + reduce_group = fp8_meta["fp8_group"] + else: + tp_amax_reduce = True + self._dp_amax_reduce_idx = (self._dp_amax_reduce_idx + 1) % self._dp_amax_reduce_interval + + if tp_amax_reduce: + if tp_size > 1: + reduce_group = tp_group + else: + return None + + chunk_sizes = [x.shape[0] for x in self._data[amax_buffer_key]] + contiguous_amax = paddle.concat(self._data[amax_buffer_key]) + + wait_handle = _reduce_tensor_across_group_op_max( + contiguous_amax, + reduce_group, + not fp8_meta["async_amax_reduction"], + ) + + return partial( + self._wait_handle_and_split, + contiguous_amax, + chunk_sizes, + amax_buffer_key, + wait_handle, + ) + + def add_amax(self, fp8_meta: Dict[str, Any]) -> None: + """Append `amax_history` to global buffer.""" + buffer_key = self._get_amax_buffer_key(fp8_meta) + fp8_meta_tensor_key = self._get_meta_tensor_key() + buffer_position_key = self._get_buffer_position_key() + + if buffer_key not in self._data: + self._data[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] + else: + self._data[buffer_key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) + + if buffer_position_key not in fp8_meta: + fp8_meta[buffer_position_key] = len(self._data[buffer_key]) - 1 + + # Catch incorrect fp8_autocast usage. 
+        assert fp8_meta[buffer_position_key] == len(self._data[buffer_key]) - 1, \
+            "Same module is being invoked more than once inside an `fp8_autocast` " \
+            "region when using FP8 with amax reduction. This behavior is currently " \
+            "unsupported. For more details and correct usage, please see " \
+            "https://github.com/NVIDIA/TransformerEngine/pull/93."
+
+    def copy_amax_from_buffer(self, fp8_meta: Dict[str, Any]) -> None:
+        """Populate current amax with the corresponding value from the buffer."""
+        fp8_meta_tensor_key = self._get_meta_tensor_key()
+        buffer_position_key = self._get_buffer_position_key()
+        if buffer_position_key not in fp8_meta:
+            return
+
+        amax_buffer_key = self._get_amax_buffer_key(fp8_meta)
+        assert amax_buffer_key in self._data, "TE internal error."
+
+        fp8_meta[fp8_meta_tensor_key].amax_history[0] = self._data[amax_buffer_key][
+            fp8_meta[buffer_position_key]]
+
+    def set_for_deletion(self, fp8_meta: Dict[str, Any]) -> None:
+        """Mark this amax key for deletion from the global buffer at autocast end."""
+        if self._get_autocast_key() not in fp8_meta:
+            return
+        self._buffer_delete_key = self._get_amax_buffer_key(fp8_meta)
+
+    def get_amax_reduce_handle(self) -> Union[bool, None]:
+        """Return AMAX reduction wait handle."""
+        return self._amax_reduce_handle
+
+    def wait(self) -> None:
+        """Wait for reduced amax to be available in buffer."""
+        if self._amax_reduce_wait_func is not None:
+            self._amax_reduce_wait_func()    # pylint: disable=not-callable
+            self._amax_reduce_wait_func = None
+
+    def to_numpy(self) -> Dict[str, List[np.ndarray]]:
+        """Convert to numpy arrays"""
+        out = {}
+        for k, v in self._data.items():
+            out[k] = [tensor.numpy() for tensor in v]
+        return out
+
+    def from_numpy(self, buffer: Dict[str, List[np.ndarray]]) -> None:
+        """Set buffer values from numpy arrays"""
+        for k, v in buffer.items():
+            self._data[k] = [paddle.to_tensor(arr) for arr in v]
+
+
+class FP8MetaFwdBuffer(FP8MetaBufferBase):
+    """FP8Meta Buffer for forward"""
+
+    @staticmethod
+    def _get_meta_tensor_key() -> str:
+        """Returns scaling key in `fp8_meta`."""
+        return "scaling_fwd"
+
+    @staticmethod
+    def _get_buffer_position_key() -> str:
+        """Returns module position key in `fp8_meta`."""
+        return "global_fp8_buffer_pos_fwd"
+
+    @staticmethod
+    def _get_autocast_key() -> str:
+        """Returns autocast id key in `fp8_meta`."""
+        return "autocast_id_fwd"
+
+    def set_for_amax_reduction(
+        self,
+        fp8_meta: Dict[str, Any],
+        tp_group: dist_group_type,
+        tp_size: int,
+    ) -> None:
+        """Sets up the function to call during autocast exit."""
+        self._amax_global_reduce_func = partial(
+            self._global_amax_reduction,
+            fp8_meta,
+            tp_group,
+            tp_size,
+        )
+
+    def finalize(self) -> None:
+        """
+        Called at FP8 autocast end.
+        Performs AMAX reduction and deletes unused buffer entries.
+ """ + if hasattr(self, '_amax_global_reduce_func') and callable(self._amax_global_reduce_func): + self._amax_reduce_wait_func = self._amax_global_reduce_func() + self._execute_deletion() + + +class FP8MetaBwdBuffer(FP8MetaBufferBase): + """FP8Meta Buffer for backward""" + + @staticmethod + def _get_meta_tensor_key() -> str: + """Returns scaling key in `fp8_meta`.""" + return "scaling_bwd" + + @staticmethod + def _get_buffer_position_key() -> str: + """Returns module position key in `fp8_meta`.""" + return "global_fp8_buffer_pos_bwd" + + @staticmethod + def _get_autocast_key() -> str: + """Returns module position key in `fp8_meta`.""" + return "autocast_id_bwd" + + def finalize( + self, + fp8_meta: Dict[str, Any], + tp_group: dist_group_type, + tp_size: int, + ) -> None: + """ + Called at FP8 autocast end in backward. + Performs AMAX reduction and delete unused buffer entries. + """ + self._amax_reduce_wait_func = self._global_amax_reduction(fp8_meta, tp_group, tp_size) + self._execute_deletion() diff --git a/transformer_engine/paddle/layer/attention.py b/transformer_engine/paddle/layer/attention.py index a5aac3566f..565321baad 100644 --- a/transformer_engine/paddle/layer/attention.py +++ b/transformer_engine/paddle/layer/attention.py @@ -4,27 +4,25 @@ """Attntion API""" import math +import os import warnings from typing import Optional, Tuple, Union import paddle import paddle.nn.functional as F -from transformer_engine.paddle.constants import ( - AttnTypes, - TE_DType, -) -from transformer_engine.paddle.cpp_extensions import ( +from .layernorm_linear import LayerNormLinear +from .linear import Linear +from .softmax import FusedScaleMaskSoftmax +from ..constants import AttnTypes, TE_DType, dist_group_type +from ..cpp_extensions import ( fused_attn_fwd_qkvpacked, fused_attn_bwd_qkvpacked, fused_attn_fwd_kvpacked, fused_attn_bwd_kvpacked, ) -from transformer_engine.paddle.utils import (attention_mask_func, mask_to_cu_seqlens) -from .base import TransformerEngineBaseLayer -from .layernorm_linear import LayerNormLinear -from .linear import Linear -from .softmax import FusedScaleMaskSoftmax +from ..distributed import get_tp_group_and_world_size, track_rng_state +from ..utils import attention_mask_func, divide, mask_to_cu_seqlens class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer): @@ -161,9 +159,20 @@ def __init__(self, self.attn_mask_type = attn_mask_type self.attention_dropout = attention_dropout self.attention_type = attention_type - self.backend = backend self.rng_state = paddle.zeros((2,), dtype='int64') self.rng_state.persistable = True + + self.backend = backend + + arch = paddle.device.cuda.get_device_capability() + self.is_fused_attn_supported = arch in ((8, 0), (9, 0)) + self.enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", + "0")) and self.is_fused_attn_supported + + if not self.enable_fused_attn and backend == 'transformer_engine': + # FMHA is not enabled, falling back to Paddle backend + self.backend = 'paddle' + if self.backend != 'transformer_engine': self.scale_mask_softmax = FusedScaleMaskSoftmax(attn_mask_type, attention_mask_func, @@ -343,7 +352,7 @@ def _pd_forward( return out -class MultiHeadAttention(TransformerEngineBaseLayer): +class MultiHeadAttention(paddle.nn.Layer): """Attention w/ QKV and Proj Gemms Parameters @@ -390,6 +399,8 @@ def __init__( input_layernorm: bool = False, attention_type: str = "self", zero_centered_gamma: bool = False, + set_parallel_mode: bool = False, + tp_group: Optional[dist_group_type] = None, backend: str = 'transformer_engine', ) -> 
None: super().__init__() @@ -403,11 +414,19 @@ def __init__( assert attention_type in AttnTypes, f"attention_type {attention_type} not supported" + self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group, + enable_tp=set_parallel_mode) + self.tensor_parallel = self.tp_size > 1 + self.hidden_size_per_attention_head = hidden_size // num_attention_heads self.num_attention_heads = num_attention_heads norm_factor = math.sqrt(self.hidden_size_per_attention_head) + self.set_parallel_mode = set_parallel_mode self.backend = backend + self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size) + qkv_parallel_mode = "column" if set_parallel_mode else None + if self.attention_type == "self": if self.input_layernorm: self.layernorm_qkv = LayerNormLinear( @@ -418,6 +437,8 @@ def __init__( bias_attr=self.bias_attr, return_layernorm_output=return_layernorm_output, zero_centered_gamma=zero_centered_gamma, + parallel_mode=qkv_parallel_mode, + tp_group=self.tp_group, backend=self.backend, ) else: @@ -426,6 +447,8 @@ def __init__( 3 * hidden_size, self.weight_attr, self.bias_attr, + parallel_mode=qkv_parallel_mode, + tp_group=self.tp_group, backend=self.backend, ) @@ -439,6 +462,8 @@ def __init__( bias_attr=self.bias_attr, return_layernorm_output=return_layernorm_output, zero_centered_gamma=zero_centered_gamma, + parallel_mode=qkv_parallel_mode, + tp_group=self.tp_group, backend=self.backend, ) else: @@ -447,6 +472,8 @@ def __init__( hidden_size, self.weight_attr, self.bias_attr, + parallel_mode=qkv_parallel_mode, + tp_group=self.tp_group, backend=self.backend, ) self.key_value = Linear( @@ -454,6 +481,8 @@ def __init__( 2 * hidden_size, self.weight_attr, self.bias_attr, + parallel_mode=qkv_parallel_mode, + tp_group=self.tp_group, backend=self.backend, ) @@ -472,6 +501,8 @@ def __init__( hidden_size, self.weight_attr, self.bias_attr, + parallel_mode="row" if set_parallel_mode else None, + tp_group=self.tp_group, backend=self.backend, ) @@ -520,23 +551,26 @@ def forward( mixed_qkv_layer = self.qkv(hidden_states) # [b, s_q, 3 * hidden_size] --> [b, s_q, 3, num_heads, head_size] - mixed_qkv_layer = mixed_qkv_layer.reshape( - shape=[0, 0, 3, self.num_attention_heads, self.hidden_size_per_attention_head]) - - context_layer = self.core_attention( - query_layer=mixed_qkv_layer, - key_value_layer=None, - attention_mask=attention_mask, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - set_zero=set_zero, - ) + mixed_qkv_layer = mixed_qkv_layer.reshape(shape=[ + 0, 0, 3, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head + ]) + + with track_rng_state(enable=self.tensor_parallel): + context_layer = self.core_attention( + query_layer=mixed_qkv_layer, + key_value_layer=None, + attention_mask=attention_mask, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias=core_attention_bias, + set_zero=set_zero, + ) else: # cross attention mixed_kv_layer = self.key_value(encoder_output) # [b, s_kv, 2 * hidden_size] --> [b, s_kv, 2, num_heads, head_size] - mixed_kv_layer = mixed_kv_layer.reshape( - shape=[0, 0, 2, self.num_attention_heads, self.hidden_size_per_attention_head]) + mixed_kv_layer = mixed_kv_layer.reshape(shape=[ + 0, 0, 2, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head + ]) if self.input_layernorm: layernorm_query_outputs = self.layernorm_query(hidden_states) @@ -547,16 +581,18 @@ def forward( else: query_layer = self.query_layer(hidden_states) - query_layer = 
query_layer.reshape( - shape=[0, 0, self.num_attention_heads, self.hidden_size_per_attention_head]) - context_layer = self.core_attention( - query_layer=query_layer, - key_value_layer=mixed_kv_layer, - attention_mask=attention_mask, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - set_zero=set_zero, - ) + query_layer = query_layer.reshape(shape=[ + 0, 0, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head + ]) + with track_rng_state(enable=self.tensor_parallel): + context_layer = self.core_attention( + query_layer=query_layer, + key_value_layer=mixed_kv_layer, + attention_mask=attention_mask, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias=core_attention_bias, + set_zero=set_zero, + ) context_layer = paddle.reshape(context_layer, [0, 0, context_layer.shape[2] * context_layer.shape[3]]) diff --git a/transformer_engine/paddle/layer/base.py b/transformer_engine/paddle/layer/base.py index 5e16fda098..0f5a1af65c 100644 --- a/transformer_engine/paddle/layer/base.py +++ b/transformer_engine/paddle/layer/base.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from contextlib import contextmanager +import os import pickle from typing import Generator, Dict, Tuple, Union, Any @@ -14,7 +15,7 @@ from paddle.fluid import core from paddle.fluid.framework import _dygraph_tracer -from ..constants import FP8BwdTensors +from ..constants import FP8BwdTensors, dist_group_type from ..cpp_extensions import cast_transpose, cast_transpose_bgrad, cast_to_fp8 from ..fp8 import ( FP8State, @@ -24,7 +25,6 @@ get_fp8_te_dtype, ) from ..profile import nvtx_range -from ..utils import get_bias_dtype, cast_if_needed _2X_ACC_FPROP = False _2X_ACC_DGRAD = True @@ -61,9 +61,15 @@ def __init__(self) -> None: self.fp8_calibration = False self.fp8_meta = {} self.fp8_meta["fp8_checkpoint"] = False + self.fp8_meta["fp8_group"] = None self.fp8_meta["recipe"] = FP8State.get_default_fp8_recipe() self.fp8_meta["scaling_fwd"] = FP8TensorMeta(is_forward=True) self.fp8_meta["scaling_bwd"] = FP8TensorMeta(is_forward=False) + self.tp_group = None + self.tp_size = 1 + self.fp8_meta["autocast_id_fwd_stack"] = [] + self.fp8_meta["async_amax_reduction"] = bool( + int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0"))) def set_activation_dtype(self, inp: paddle.Tensor) -> None: """Get activation data type for AMP.""" @@ -102,18 +108,20 @@ def set_activation_dtype(self, inp: paddle.Tensor) -> None: # assume FP8 execution. def fp8_init(self, num_gemms: int = 1) -> None: """Initialize fp8 related metadata and tensors during fprop.""" - state = get_global_fp8_state() - self.fp8_enabled = state.is_fp8_enabled() - self.fp8_calibration = state.is_fp8_calibration() + global_fp8_state = get_global_fp8_state() + self.fp8_enabled = global_fp8_state.is_fp8_enabled() + self.fp8_calibration = global_fp8_state.is_fp8_calibration() self.fp8_meta["fp8_checkpoint"] = self.fp8_enabled or self.fp8_calibration if self.fp8_enabled or self.fp8_calibration: # FP8 init has already been run and recipe is the same, don't do anything. 
- if self.fp8_initialized and state.get_fp8_recipe() == self.fp8_meta["recipe"]: + if self.fp8_initialized and global_fp8_state.get_fp8_recipe( + ) == self.fp8_meta["recipe"]: return # Set FP8, recipe, and other FP8 metadata - self.fp8_meta["recipe"] = state.get_fp8_recipe() + self.fp8_meta["recipe"] = global_fp8_state.get_fp8_recipe() + self.fp8_meta["fp8_group"] = global_fp8_state.get_fp8_group() # Set FP8_MAX per tensor according to recipe self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd @@ -136,6 +144,8 @@ def _get_fp8_state(self) -> paddle.Tensor: state = {} state["scaling_fwd"] = self.fp8_meta["scaling_fwd"].to_numpy() state["scaling_bwd"] = self.fp8_meta["scaling_bwd"].to_numpy() + state["global_fp8_fwd_buffer"] = get_global_fp8_state().get_fp8_fwd_buffer().to_numpy() + state["global_fp8_bwd_buffer"] = get_global_fp8_state().get_fp8_bwd_buffer().to_numpy() # Store other pickelable values. extra = {} for k, v in self.fp8_meta.items(): @@ -179,6 +189,12 @@ def _set_fp8_state(self, state: paddle.Tensor) -> None: self.fp8_meta["scaling_fwd"].from_numpy(state["scaling_fwd"]) self.fp8_meta["scaling_bwd"].from_numpy(state["scaling_bwd"]) + # Restore global FP8 buffer states. + global_fp8_fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer() + global_fp8_bwd_buffer = get_global_fp8_state().get_fp8_bwd_buffer() + global_fp8_fwd_buffer.from_numpy(state["global_fp8_fwd_buffer"]) + global_fp8_bwd_buffer.from_numpy(state["global_fp8_bwd_buffer"]) + # Load extra items. self.fp8_meta.update(state["extra_fp8_variables"]) self.fp8_meta["recipe"].amax_history_len = self.fp8_meta["scaling_fwd"].amax_history.shape[ @@ -210,9 +226,22 @@ def prepare_forward( # Previous iteration was grad_enabled if self.fp8_meta.get("update_amax_and_scale_fwd", False): - amax_and_scale_update(self.fp8_meta, True) + global_fp8_fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer() + global_fp8_fwd_buffer.wait() + if self.fp8_meta["recipe"].reduce_amax: + global_fp8_fwd_buffer.copy_amax_from_buffer(self.fp8_meta) + amax_and_scale_update(self.fp8_meta, True) + global_fp8_fwd_buffer.set_for_deletion(self.fp8_meta) + else: + amax_and_scale_update(self.fp8_meta, True) if self.fp8_enabled and self.training: + # Setup for amax reduction + if self.fp8_meta["recipe"].reduce_amax: + global_fp8_state = get_global_fp8_state() + self.fp8_meta["first_module"] = global_fp8_state.is_first_fp8_module() + self.fp8_meta["autocast_id_fwd"] = global_fp8_state.get_autocast_id() + self.fp8_meta["autocast_id_fwd_stack"].append(self.fp8_meta["autocast_id_fwd"]) self.fp8_meta["update_amax_and_scale_fwd"] = True else: self.fp8_meta["update_amax_and_scale_fwd"] = False @@ -220,18 +249,47 @@ def prepare_forward( with nvtx_range(self.__class__.__name__ + " forward"): yield inp + if self.fp8_enabled and self.training and self.fp8_meta["recipe"].reduce_amax: + global_fp8_state = get_global_fp8_state() + global_fp8_fwd_buffer = global_fp8_state.get_fp8_fwd_buffer() + global_fp8_fwd_buffer.add_amax(self.fp8_meta) + global_fp8_fwd_buffer.set_for_amax_reduction( + self.fp8_meta, + self.tp_group, + self.tp_size, + ) + @staticmethod @contextmanager def prepare_backward(fp8_enabled: bool, fp8_meta: Dict[str, Any], + tp_group: dist_group_type, + tp_size: int, name: str = "") -> Generator[None, None, None]: """Checks and prep for BWD.""" if fp8_enabled: - amax_and_scale_update(fp8_meta, False) + global_fp8_state = get_global_fp8_state() + global_fp8_bwd_buffer = global_fp8_state.get_fp8_bwd_buffer() + 
global_fp8_bwd_buffer.wait() + + if fp8_meta["recipe"].reduce_amax: + global_fp8_bwd_buffer.copy_amax_from_buffer(fp8_meta) + amax_and_scale_update(fp8_meta, False) + global_fp8_bwd_buffer.set_for_deletion(fp8_meta) + + # Get new backward key. + fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0) + else: + amax_and_scale_update(fp8_meta, False) with nvtx_range(name + " backward"): yield + if fp8_enabled and fp8_meta["recipe"].reduce_amax: + global_fp8_bwd_buffer.add_amax(fp8_meta) + if fp8_meta["first_module"]: + global_fp8_bwd_buffer.finalize(fp8_meta, tp_group, tp_size) + @staticmethod def grad_output_preprocess( ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: @@ -258,8 +316,6 @@ def grad_output_preprocess( FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, ) - bias_dtype = get_bias_dtype(ctx.activation_dtype) - bgrad = cast_if_needed(bgrad, bias_dtype) else: if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: grad_output_c, grad_output_t = cast_transpose( diff --git a/transformer_engine/paddle/layer/layernorm.py b/transformer_engine/paddle/layer/layernorm.py index 3f0b8c4a50..89c03ee25c 100644 --- a/transformer_engine/paddle/layer/layernorm.py +++ b/transformer_engine/paddle/layer/layernorm.py @@ -31,7 +31,7 @@ def forward( zero_centered_gamma: bool, ) -> paddle.Tensor: # Make sure input dimensions are compatible - in_features = ln_weight.numel() + in_features = ln_weight.shape[0] assert inp.shape[-1] == in_features, "LayerNorm not possible" inputmat = inp.reshape((-1, in_features)) diff --git a/transformer_engine/paddle/layer/layernorm_linear.py b/transformer_engine/paddle/layer/layernorm_linear.py index 608f02a6ff..285cf4609a 100644 --- a/transformer_engine/paddle/layer/layernorm_linear.py +++ b/transformer_engine/paddle/layer/layernorm_linear.py @@ -4,7 +4,7 @@ """LayerNormLinear API""" import os -from typing import Union, Tuple, Dict, Any +from typing import Union, Tuple, Dict, Any, Optional import paddle import paddle.nn.functional as F @@ -21,9 +21,22 @@ from .base import TransformerEngineBaseLayer from .linear import _linear_fwd, _linear_bwd -from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors +from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type +from ..distributed import ( + allreduce, + get_tp_group_and_world_size, + identity, + track_rng_state, + set_tensor_dist_attr, + set_weight_tensor_dist_attr, +) from ..fp8 import get_fp8_te_dtype -from ..utils import cast_if_needed, cast_if_needed_inplace, assert_dim_for_fp8_forward_exec +from ..utils import ( + assert_dim_for_fp8_forward_exec, + cast_if_needed, + cast_if_needed_inplace, + divide, +) __all__ = ["LayerNormLinear", "_layernorm_fwd_fp8_cast", "_layernorm_bwd"] @@ -128,9 +141,13 @@ def forward( fwd_ln_sm_margin: int, bwd_ln_sm_margin: int, zero_centered_gamma: bool, + parallel_mode: Union[str, None], + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], + tp_size: int, ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]: # Make sure input dimensions are compatible - in_features = ln_weight.numel() + in_features = ln_weight.shape[0] assert inp.shape[-1] == in_features, "GEMM not possible" inputmat = inp.reshape((-1, in_features)) if fp8_enabled: @@ -169,6 +186,9 @@ def forward( fp8_calibration, fp8_meta, activation_dtype, + parallel_mode, + tensor_parallel, + tp_group, is_grad_enabled, ) @@ -192,6 +212,10 @@ def forward( ctx.return_layernorm_output = return_layernorm_output ctx.bwd_ln_sm_margin = 
bwd_ln_sm_margin ctx.zero_centered_gamma = zero_centered_gamma + ctx.parallel_mode = parallel_mode + ctx.tensor_parallel = tensor_parallel + ctx.tp_group = tp_group + ctx.tp_size = tp_size ctx.requires_dgrad = not inp.stop_gradient ctx.requires_bgrad = use_bias and not bias.stop_gradient ctx.requires_ln_bgrad = not ln_bias.stop_gradient @@ -208,6 +232,8 @@ def backward( ...]) -> Tuple[Union[paddle.Tensor, None], ...]: with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled, ctx.fp8_meta, + ctx.tp_group, + ctx.tp_size, name="_LayerNormLinear"): ( inputmat, @@ -262,6 +288,9 @@ def backward( ctx.fp8_meta, True, # Always compute dgrad to feed into LayerNorm bwd ctx.activation_dtype, + ctx.parallel_mode, + ctx.tensor_parallel, + ctx.tp_group, ) if not ctx.fp8_enabled: @@ -307,6 +336,8 @@ def __init__( bias_attr: Union[paddle.ParamAttr, None, bool] = None, return_layernorm_output: bool = False, zero_centered_gamma: bool = False, + parallel_mode: Optional[str] = None, + tp_group: Union[dist_group_type, None] = None, backend: str = 'transformer_engine', ) -> None: super().__init__() @@ -322,9 +353,23 @@ def __init__( self._bias_attr = bias_attr self._dtype = self._helper.get_default_dtype() + # Set parallel configs + self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group, + enable_tp=parallel_mode + is not None) + self.tensor_parallel = self.tp_size > 1 + self.parallel_mode = parallel_mode + assert (self.parallel_mode + in GemmParallelModes), f"parallel_mode {parallel_mode} not supported" + + if self.parallel_mode == "column": + self.out_features = divide(self.out_features, self.tp_size) + elif self.parallel_mode == "row": + self.in_features = divide(self.in_features, self.tp_size) + # LayerNorm weights self.ln_weight = self.create_parameter( - shape=[in_features], + shape=[self.in_features], attr=paddle.ParamAttr(initializer=Constant( value=0.0 if self.zero_centered_gamma else 1.0)), dtype=self._dtype, @@ -332,34 +377,48 @@ def __init__( ) self.ln_bias = self.create_parameter( - shape=[in_features], + shape=[self.in_features], attr=paddle.ParamAttr(initializer=Constant(value=0.0)), dtype=self._dtype, is_bias=True, ) - # Linear weights - self.weight = self.create_parameter( - shape=[out_features, in_features] - if self.backend == 'transformer_engine' else [in_features, out_features], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) + # Initialize Linear weight parameter + with track_rng_state(enable=self.tensor_parallel): + # TE linear weight is in column major + self.weight = self.create_parameter( + shape=[self.out_features, self.in_features] + if self.backend == 'transformer_engine' else [self.in_features, self.out_features], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False, + ) + set_weight_tensor_dist_attr(self.weight, self.tensor_parallel, self.parallel_mode, + self.backend) + # Initialize Linear bias parameter self.has_bias = self._bias_attr is not False use_default_bias = self._bias_attr is None or self._bias_attr is True if self.has_bias: self.bias = self.create_parameter( - shape=[out_features], + shape=[self.out_features], attr=self._bias_attr if not use_default_bias else paddle.ParamAttr( initializer=Constant(value=0.0)), dtype=self._dtype, is_bias=True, ) + if parallel_mode == "column": + set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0) else: self.bias = None + # For RPL, bias has to be added after TP collectives + # So it cannot be fused with the GEMM + if self.parallel_mode == "row" and self.tensor_parallel and 
self.has_bias: + self.gemm_bias_fused_add = False + else: + self.gemm_bias_fused_add = True + # These many SMs are subtracted from the total SM count when calling forward # and backward LayerNorm C APIs. These envvars can be used to prevent the LN # kernels from using all SMs in the device. This is useful for cases such as @@ -385,8 +444,8 @@ def _te_forward( self.ln_weight, self.ln_bias, self.weight, - self.bias, - self.has_bias, + self.bias if self.gemm_bias_fused_add else None, + self.has_bias and self.gemm_bias_fused_add, self.eps, self.fp8_enabled, self.fp8_calibration, @@ -397,10 +456,19 @@ def _te_forward( self.fwd_ln_sm_margin, self.bwd_ln_sm_margin, self.zero_centered_gamma, + self.parallel_mode, + self.tensor_parallel, + self.tp_group, + self.tp_size, ) if self.return_layernorm_output: out, ln_out = out + + if not self.gemm_bias_fused_add: + out = out + cast_if_needed_inplace(self.bias, self.activation_dtype) + + if self.return_layernorm_output: return out, ln_out return out @@ -418,7 +486,12 @@ def _pd_forward( weight=self.ln_weight, bias=self.ln_bias, epsilon=self.eps) - out = F.linear(ln_out, self.weight, self.bias) + if self.parallel_mode == 'column' and self.tensor_parallel: + ln_out = identity(ln_out, self.tp_group) + out = F.linear(ln_out, self.weight, self.bias if self.gemm_bias_fused_add else None) + if self.parallel_mode == 'row' and self.tensor_parallel: + out = allreduce(out, self.tp_group) + out = out + self.bias if self.bias is not None else out if self.return_layernorm_output: return out, ln_out return out diff --git a/transformer_engine/paddle/layer/layernorm_mlp.py b/transformer_engine/paddle/layer/layernorm_mlp.py index 6d725114b0..9b89d05d47 100644 --- a/transformer_engine/paddle/layer/layernorm_mlp.py +++ b/transformer_engine/paddle/layer/layernorm_mlp.py @@ -4,25 +4,38 @@ """LayerNormMLP API""" import os -from typing import Union, Tuple, Dict, Any +from typing import Union, Tuple, Dict, Any, Optional import paddle import paddle.nn.functional as F from paddle.nn.initializer import Constant +from .base import TransformerEngineBaseLayer +from .layernorm_linear import _layernorm_fwd_fp8_cast, _layernorm_bwd +from .linear import _linear_fwd_fp8, _linear_fwd_non_fp8, _linear_bwd_fp8, _linear_bwd_non_fp8 +from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors, dist_group_type from ..cpp_extensions import ( cast_from_fp8, dgelu_cast_transpose_bgrad_fp8, gelu_fp8, transpose, ) - -from .base import TransformerEngineBaseLayer -from .layernorm_linear import _layernorm_fwd_fp8_cast, _layernorm_bwd -from .linear import _linear_fwd_fp8, _linear_fwd_non_fp8, _linear_bwd_fp8, _linear_bwd_non_fp8 -from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors +from ..distributed import ( + allreduce, + get_tp_group_and_world_size, + identity, + track_rng_state, + set_tensor_dist_attr, + set_weight_tensor_dist_attr, +) from ..fp8 import get_fp8_te_dtype -from ..utils import cast_if_needed, assert_dim_for_fp8_forward_exec, get_paddle_act_func +from ..utils import ( + assert_dim_for_fp8_forward_exec, + cast_if_needed, + cast_if_needed_inplace, + divide, + get_paddle_act_func, +) __all__ = ["LayerNormMLP"] @@ -43,7 +56,11 @@ def _mlp_forward( fp8_calibration: bool, fp8_meta: Dict[str, Any], activation_dtype: paddle.dtype, + activation: str, is_grad_enabled: bool, + set_parallel_mode: bool, + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], ): if fp8_enabled: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) @@ -56,6 +73,9 @@ def 
_mlp_forward( use_fc1_bias, fp8_meta, activation_dtype, + 'column' if set_parallel_mode else None, + tensor_parallel, + tp_group, is_grad_enabled, ) @@ -75,6 +95,9 @@ def _mlp_forward( use_fc2_bias, fp8_meta, activation_dtype, + 'row' if set_parallel_mode else None, + tensor_parallel, + tp_group, is_grad_enabled, ) else: @@ -88,7 +111,10 @@ def _mlp_forward( fp8_calibration, fp8_meta, activation_dtype, - activation='gelu', + 'column' if set_parallel_mode else None, + tensor_parallel, + tp_group, + activation=activation, ) fc2_out = _linear_fwd_non_fp8( @@ -101,6 +127,9 @@ def _mlp_forward( fp8_calibration, fp8_meta, activation_dtype, + 'row' if set_parallel_mode else None, + tensor_parallel, + tp_group, ) return ( fc1_out, @@ -136,6 +165,9 @@ def _mlp_backward( requires_dgrad: bool, activation_dtype: paddle.dtype, activation: str, + set_parallel_mode: bool, + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], ): ( fc1_dgrad, @@ -179,6 +211,9 @@ def _mlp_backward( True, requires_fc2_wgrad, activation_dtype, + 'row' if set_parallel_mode else None, + tensor_parallel, + tp_group, ) # GELU Bwd @@ -193,7 +228,7 @@ def _mlp_backward( if requires_fc1_bgrad: fc1_bgrad = fc1_bgrad_ - # FC2 Bwd + # FC1 Bwd requires_fc1_wgrad = not fc1_weight.stop_gradient dgelu_no_fp8, fc1_input_no_fp8, fc1_input_t = None, None, None if requires_fc1_wgrad: @@ -231,6 +266,9 @@ def _mlp_backward( requires_dgrad, requires_fc1_wgrad, activation_dtype, + 'column' if set_parallel_mode else None, + tensor_parallel, + tp_group, ) else: dgelu, fc2_wgrad, fc2_bgrad = _linear_bwd_non_fp8( @@ -240,6 +278,9 @@ def _mlp_backward( requires_fc2_bgrad, True, activation_dtype, + 'row' if set_parallel_mode else None, + tensor_parallel, + tp_group, gelu_input=fc1_out, activation=activation, ) @@ -250,6 +291,9 @@ def _mlp_backward( requires_fc1_bgrad, requires_dgrad, activation_dtype, + 'column' if set_parallel_mode else None, + tensor_parallel, + tp_group, ) return ( fc1_dgrad, @@ -286,9 +330,13 @@ def forward( bwd_ln_sm_margin: int, zero_centered_gamma: bool, activation: str, + set_parallel_mode: bool, + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], + tp_size: int, ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]: # Make sure input dimensions are compatible - in_features = ln_weight.numel() + in_features = ln_weight.shape[0] assert inp.shape[-1] == in_features, "GEMM not possible" inputmat = inp.reshape((-1, in_features)) if fp8_enabled: @@ -341,7 +389,11 @@ def forward( fp8_calibration, fp8_meta, activation_dtype, + activation, is_grad_enabled, + set_parallel_mode, + tensor_parallel, + tp_group, ) if is_grad_enabled: @@ -369,6 +421,10 @@ def forward( ctx.return_layernorm_output = return_layernorm_output ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.zero_centered_gamma = zero_centered_gamma + ctx.set_parallel_mode = set_parallel_mode + ctx.tensor_parallel = tensor_parallel + ctx.tp_group = tp_group + ctx.tp_size = tp_size ctx.requires_dgrad = not inp.stop_gradient ctx.requires_fc1_bgrad = use_fc1_bias and not fc1_bias.stop_gradient ctx.requires_fc2_bgrad = use_fc2_bias and not fc2_bias.stop_gradient @@ -387,6 +443,8 @@ def backward( ...]) -> Tuple[Union[paddle.Tensor, None], ...]: with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled, ctx.fp8_meta, + ctx.tp_group, + ctx.tp_size, name="_LayerNormMLP"): ( inputmat, @@ -442,6 +500,9 @@ def backward( True, ctx.activation_dtype, ctx.activation, + ctx.set_parallel_mode, + ctx.tensor_parallel, + ctx.tp_group, ) if not ctx.fp8_enabled: # fc2_bias 
is fused with gemm for non-FP8 path @@ -491,6 +552,8 @@ def __init__( activation: str = "gelu", return_layernorm_output: bool = False, zero_centered_gamma: bool = False, + set_parallel_mode: bool = False, + tp_group: Optional[dist_group_type] = None, backend: str = 'transformer_engine', ) -> None: super().__init__() @@ -507,6 +570,17 @@ def __init__( self._bias_attr = bias_attr self._dtype = self._helper.get_default_dtype() + # Set parallel configs + self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group, + enable_tp=set_parallel_mode) + self.tensor_parallel = self.tp_size > 1 + self.set_parallel_mode = set_parallel_mode + + if self.set_parallel_mode: + self.size_per_partition = divide(self.ffn_hidden_size, self.tp_size) + else: + self.size_per_partition = self.ffn_hidden_size + # LayerNorm weights self.ln_weight = self.create_parameter( shape=[self.hidden_size], @@ -524,36 +598,47 @@ def __init__( ) # FC1 weights - self.fc1_weight = self.create_parameter( - shape=[self.ffn_hidden_size, self.hidden_size] - if self.backend == 'transformer_engine' else [self.hidden_size, self.ffn_hidden_size], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) + with track_rng_state(enable=self.tensor_parallel): + self.fc1_weight = self.create_parameter( + shape=[self.size_per_partition, self.hidden_size] if self.backend + == 'transformer_engine' else [self.hidden_size, self.size_per_partition], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False, + ) + set_weight_tensor_dist_attr(self.fc1_weight, + self.tensor_parallel, + parallel_mode='column', + backend=self.backend) self.has_bias = self._bias_attr is not False - if self._bias_attr is None or self._bias_attr is True: + use_default_bias = self._bias_attr is None or self._bias_attr is True + if use_default_bias: self._bias_attr = paddle.ParamAttr(initializer=Constant(value=0.0)) if self.has_bias: self.fc1_bias = self.create_parameter( - shape=[self.ffn_hidden_size], + shape=[self.size_per_partition], attr=self._bias_attr, dtype=self._dtype, is_bias=True, ) + set_tensor_dist_attr(self.fc1_bias, self.tensor_parallel, axis=0) else: self.fc1_bias = None # FC2 weights self.fc2_weight = self.create_parameter( - shape=[self.hidden_size, self.ffn_hidden_size] - if self.backend == 'transformer_engine' else [self.ffn_hidden_size, self.hidden_size], + shape=[self.hidden_size, self.size_per_partition] if self.backend + == 'transformer_engine' else [self.size_per_partition, self.hidden_size], attr=self._weight_attr, dtype=self._dtype, is_bias=False, ) + set_weight_tensor_dist_attr(self.fc2_weight, + self.tensor_parallel, + parallel_mode='row', + backend=self.backend) if self.has_bias: self.fc2_bias = self.create_parameter( @@ -565,6 +650,13 @@ def __init__( else: self.fc2_bias = None + # For RPL, bias has to be added after TP collectives + # So it cannot be fused with the GEMM + if self.set_parallel_mode and self.tensor_parallel and self.has_bias: + self.gemm_bias_fused_add = False + else: + self.gemm_bias_fused_add = True + # These many SMs are subtracted from the total SM count when calling forward # and backward LayerNorm C APIs. These envvars can be used to prevent the LN # kernels from using all SMs in the device. 
This is useful for cases such as @@ -606,12 +698,20 @@ def _te_forward( self.bwd_ln_sm_margin, self.zero_centered_gamma, self.activation, + self.set_parallel_mode, + self.tensor_parallel, + self.tp_group, + self.tp_size, ) if self.return_layernorm_output: out, ln_out = out - return out, ln_out + if not self.gemm_bias_fused_add: + out = out + cast_if_needed_inplace(self.fc2_bias, self.activation_dtype) + + if self.return_layernorm_output: + return out, ln_out return out def _pd_forward( @@ -628,11 +728,16 @@ def _pd_forward( weight=self.ln_weight, bias=self.ln_bias, epsilon=self.eps) + if self.set_parallel_mode and self.tensor_parallel: + ln_out = identity(ln_out, self.tp_group) fc1_out = F.linear(ln_out, self.fc1_weight, self.fc1_bias) act_func = get_paddle_act_func(self.activation) act_out = act_func(fc1_out) - out = F.linear(act_out, self.fc2_weight, self.fc2_bias) - + out = F.linear(act_out, self.fc2_weight, + self.fc2_bias if self.gemm_bias_fused_add else None) + if self.set_parallel_mode and self.tensor_parallel: + out = allreduce(out, self.tp_group) + out = out + self.fc2_bias if self.fc2_bias is not None else out if self.return_layernorm_output: return out, ln_out return out diff --git a/transformer_engine/paddle/layer/linear.py b/transformer_engine/paddle/layer/linear.py index dc9863e062..ff164067a7 100644 --- a/transformer_engine/paddle/layer/linear.py +++ b/transformer_engine/paddle/layer/linear.py @@ -3,7 +3,7 @@ # See LICENSE for license information. """Linear API""" -from typing import Union, Tuple, Dict, Any +from typing import Union, Tuple, Dict, Any, Optional import paddle import paddle.nn.functional as F @@ -17,13 +17,22 @@ _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype -from ..constants import FP8FwdTensors, FP8BwdTensors +from ..constants import FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type from ..cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_transpose +from ..distributed import ( + allreduce, + get_tp_group_and_world_size, + identity, + track_rng_state, + set_tensor_dist_attr, + set_weight_tensor_dist_attr, +) +from ..fp8 import get_fp8_te_dtype from ..utils import ( + assert_dim_for_fp8_forward_exec, cast_if_needed, cast_if_needed_inplace, - assert_dim_for_fp8_forward_exec, + divide, get_bias_dtype, ) @@ -39,12 +48,15 @@ def _linear_fwd_fp8( use_bias: bool, fp8_meta: Dict[str, Any], activation_dtype: paddle.dtype, + parallel_mode: Union[str, None], + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], is_grad_enabled: bool, ): """FP8 path of Linear Fwd""" fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) bias_dtype = get_bias_dtype(activation_dtype) - bias = cast_if_needed_inplace(bias, bias_dtype) + bias = cast_if_needed(bias, bias_dtype) if is_grad_enabled: weight_fp8, weight_t_fp8 = cast_transpose( @@ -78,6 +90,10 @@ def _linear_fwd_fp8( use_split_accumulator=_2X_ACC_FPROP, ) + # Row Parallel Linear + if parallel_mode == "row" and tensor_parallel: + out = allreduce(out, tp_group) + return out, weight_t_fp8 @@ -91,6 +107,9 @@ def _linear_fwd_non_fp8( fp8_calibration: bool, fp8_meta: Dict[str, Any], activation_dtype: paddle.dtype, + parallel_mode: Union[str, None], + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], activation: str = "", ): """Non-FP8 path of Linear Fwd""" @@ -123,6 +142,9 @@ def _linear_fwd_non_fp8( return out, gelu_out out, _, _ = outputs + # Row Parallel Linear + if parallel_mode == "row" and tensor_parallel: + out = allreduce(out, tp_group) return out @@ 
-137,6 +159,9 @@ def _linear_fwd( fp8_calibration: bool, fp8_meta: Dict[str, Any], activation_dtype: paddle.dtype, + parallel_mode: Union[str, None], + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], is_grad_enabled: bool, ): if fp8_enabled: @@ -149,6 +174,9 @@ def _linear_fwd( use_bias, fp8_meta, activation_dtype, + parallel_mode, + tensor_parallel, + tp_group, is_grad_enabled, ) else: @@ -162,6 +190,9 @@ def _linear_fwd( fp8_calibration, fp8_meta, activation_dtype, + parallel_mode, + tensor_parallel, + tp_group, ) return ( out, @@ -184,6 +215,9 @@ def _linear_bwd_fp8( requires_dgrad: bool, requires_wgrad: bool, activation_dtype: paddle.dtype, + parallel_mode: Union[str, None], + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], ): dgrad, wgrad = None, None fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) @@ -202,6 +236,9 @@ def _linear_bwd_fp8( get_workspace(), use_split_accumulator=_2X_ACC_DGRAD, ) + if parallel_mode == "column" and tensor_parallel: + dgrad = allreduce(dgrad, tp_group) + if requires_wgrad: if not fp8_meta["recipe"].override_linear_precision.wgrad: wgrad = fp8_gemm( @@ -236,6 +273,9 @@ def _linear_bwd_non_fp8( requires_bgrad: bool, requires_dgrad: bool, activation_dtype: paddle.dtype, + parallel_mode: Union[str, None], + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], gelu_input: Union[paddle.Tensor, None] = None, activation: str = "", ): @@ -255,6 +295,9 @@ def _linear_bwd_non_fp8( gelu_input=gelu_input, grad=True, ) + if parallel_mode == "column" and tensor_parallel: + dgrad = allreduce(dgrad, tp_group) + if requires_wgrad: wgrad, bgrad, _ = gemm( inputmat, @@ -288,6 +331,9 @@ def _linear_bwd( fp8_meta: Dict[str, Any], requires_dgrad: bool, activation_dtype: paddle.dtype, + parallel_mode: Union[str, None], + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], ): dgrad, wgrad, bgrad = None, None, None requires_wgrad = not weight.stop_gradient @@ -307,6 +353,9 @@ def _linear_bwd( requires_dgrad, requires_wgrad, activation_dtype, + parallel_mode, + tensor_parallel, + tp_group, ) else: dgrad, wgrad, bgrad = _linear_bwd_non_fp8( @@ -316,6 +365,9 @@ def _linear_bwd( requires_bgrad, requires_dgrad, activation_dtype, + parallel_mode, + tensor_parallel, + tp_group, ) return dgrad, wgrad, bgrad @@ -335,6 +387,10 @@ def forward( fp8_meta: Dict[str, Any], activation_dtype: paddle.dtype, is_grad_enabled: bool, + parallel_mode: Union[str, None], + tensor_parallel: bool, + tp_group: Union[dist_group_type, None], + tp_size: int, ) -> paddle.Tensor: # Make sure input dimensions are compatible in_features = weight.shape[-1] @@ -385,6 +441,9 @@ def forward( fp8_calibration, fp8_meta, activation_dtype, + parallel_mode, + tensor_parallel, + tp_group, is_grad_enabled, ) @@ -402,6 +461,10 @@ def forward( ctx.fp8_meta = fp8_meta ctx.use_bias = use_bias ctx.inp_shape = inp.shape + ctx.parallel_mode = parallel_mode + ctx.tensor_parallel = tensor_parallel + ctx.tp_group = tp_group + ctx.tp_size = tp_size ctx.requires_dgrad = not inp.stop_gradient ctx.requires_bgrad = use_bias and not bias.stop_gradient @@ -411,6 +474,8 @@ def forward( def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled, ctx.fp8_meta, + ctx.tp_group, + ctx.tp_size, name="_Linear"): ( inputmat, @@ -444,6 +509,9 @@ def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None ctx.fp8_meta, ctx.requires_dgrad, 
ctx.activation_dtype, + ctx.parallel_mode, + ctx.tensor_parallel, + ctx.tp_group, ) if not ctx.fp8_enabled: @@ -474,6 +542,8 @@ def __init__( out_features: int, weight_attr: Union[paddle.ParamAttr, None] = None, bias_attr: Union[paddle.ParamAttr, None, bool] = None, + parallel_mode: Optional[str] = None, + tp_group: Union[dist_group_type, None] = None, backend: str = 'transformer_engine', ) -> None: super().__init__() @@ -484,28 +554,56 @@ def __init__( self._bias_attr = bias_attr self._dtype = self._helper.get_default_dtype() - # TE linear weight is in column major - self.weight = self.create_parameter( - shape=[out_features, in_features] - if self.backend == 'transformer_engine' else [in_features, out_features], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) + # Set parallel configs + self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group, + enable_tp=parallel_mode + is not None) + self.tensor_parallel = self.tp_size > 1 + self.parallel_mode = parallel_mode + assert (self.parallel_mode + in GemmParallelModes), f"parallel_mode {parallel_mode} not supported" + + if self.parallel_mode == "column": + self.out_features = divide(self.out_features, self.tp_size) + elif self.parallel_mode == "row": + self.in_features = divide(self.in_features, self.tp_size) + + # Initialize weight parameter + with track_rng_state(enable=self.tensor_parallel): + # TE linear weight is in column major + self.weight = self.create_parameter( + shape=[self.out_features, self.in_features] + if self.backend == 'transformer_engine' else [self.in_features, self.out_features], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False, + ) + set_weight_tensor_dist_attr(self.weight, self.tensor_parallel, self.parallel_mode, + self.backend) + # Initialize bias parameter self.has_bias = self._bias_attr is not False use_default_bias = self._bias_attr is None or self._bias_attr is True if self.has_bias: self.bias = self.create_parameter( - shape=[out_features], + shape=[self.out_features], attr=self._bias_attr if not use_default_bias else paddle.ParamAttr( initializer=Constant(value=0.0)), dtype=self._dtype, is_bias=True, ) + if parallel_mode == "column": + set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0) else: self.bias = None + # For RPL, bias has to be added after TP collectives + # So it cannot be fused with the GEMM + if self.parallel_mode == "row" and self.tensor_parallel and self.has_bias: + self.gemm_bias_fused_add = False + else: + self.gemm_bias_fused_add = True + def _te_forward( self, inp: paddle.Tensor, @@ -521,15 +619,22 @@ def _te_forward( out = _Linear.apply( self.weight, inp, - self.bias, - self.has_bias, + self.bias if self.gemm_bias_fused_add else None, + self.has_bias and self.gemm_bias_fused_add, self.fp8_enabled, self.fp8_calibration, self.fp8_meta, self.activation_dtype, paddle.is_grad_enabled(), + self.parallel_mode, + self.tensor_parallel, + self.tp_group, + self.tp_size, ) + if not self.gemm_bias_fused_add: + out = out + cast_if_needed_inplace(self.bias, self.activation_dtype) + return out def _pd_forward( @@ -537,7 +642,13 @@ def _pd_forward( inp: paddle.Tensor, ) -> paddle.Tensor: """Calls Paddle OP""" - return F.linear(inp, self.weight, self.bias) + if self.parallel_mode == 'column' and self.tensor_parallel: + inp = identity(inp, self.tp_group) + out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None) + if self.parallel_mode == 'row' and self.tensor_parallel: + out = allreduce(out, self.tp_group) + out = out + self.bias if 
self.bias is not None else out + return out def forward(self, *args, **kwargs): """forward""" diff --git a/transformer_engine/paddle/layer/transformer.py b/transformer_engine/paddle/layer/transformer.py index 6e6afd4ca2..a95b9fcfe1 100644 --- a/transformer_engine/paddle/layer/transformer.py +++ b/transformer_engine/paddle/layer/transformer.py @@ -7,15 +7,11 @@ import paddle -from transformer_engine.paddle.constants import ( - AttnMaskTypes, - LayerTypes, -) -from transformer_engine.paddle.layer import (LayerNormMLP, LayerNorm, MultiHeadAttention) -from .base import TransformerEngineBaseLayer +from . import LayerNormMLP, LayerNorm, MultiHeadAttention +from ..constants import AttnMaskTypes, LayerTypes, dist_group_type -class TransformerLayer(TransformerEngineBaseLayer): +class TransformerLayer(paddle.nn.Layer): r""" TransformerLayer is made up of an attention block and a feedforward network (MLP). This standard layer is based on the paper "Attention Is All You Need". @@ -64,6 +60,16 @@ class TransformerLayer(TransformerEngineBaseLayer): it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. + + Parallelism parameters + ---------------------- + set_parallel_mode : bool, default = `False` + if set to `True`, QKV and FC1 layers are used as Column Parallel + whereas PROJ and FC2 is used as Row Parallel as described + `here `_. + tp_group : ProcessGroup, default = `None` + tensor parallel process group. + """ def __init__(self, @@ -82,6 +88,8 @@ def __init__(self, layer_type: str = "encoder", zero_centered_gamma: bool = False, activation: str = 'gelu', + set_parallel_mode: bool = False, + tp_group: Optional[dist_group_type] = None, backend: str = 'transformer_engine') -> None: super().__init__() @@ -90,6 +98,8 @@ def __init__(self, self.layer_type = layer_type self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm self.self_attn_mask_type = self_attn_mask_type + self.set_parallel_mode = set_parallel_mode + self.tp_group = tp_group assert (self_attn_mask_type in AttnMaskTypes), f"self_attn_mask_type {self_attn_mask_type} not supported" @@ -107,6 +117,8 @@ def __init__(self, "params_dtype": params_dtype, "return_layernorm_output": apply_residual_connection_post_layernorm, "zero_centered_gamma": zero_centered_gamma, + "set_parallel_mode": set_parallel_mode, + "tp_group": tp_group, "backend": backend, } @@ -136,6 +148,8 @@ def __init__(self, activation=activation, return_layernorm_output=apply_residual_connection_post_layernorm, zero_centered_gamma=zero_centered_gamma, + set_parallel_mode=set_parallel_mode, + tp_group=tp_group, backend=backend, )
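Usage sketch (not part of the patch): a minimal example of how the new tensor-parallel arguments (`parallel_mode`/`set_parallel_mode`, `tp_group`) and the new `fp8_group` argument of `fp8_autocast` introduced above are meant to combine. It assumes two FP8-capable GPUs launched with `python -m paddle.distributed.launch --gpus 0,1`, and that `Linear` and `fp8_autocast` are exposed as `te.Linear` and `te.fp8_autocast` from `transformer_engine.paddle`; sizes and variable names are illustrative only.

import paddle
import paddle.distributed as dist
import transformer_engine.paddle as te

dist.init_parallel_env()
# Use all launched ranks as a single tensor-parallel group.
tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

# Column-parallel linear: the 4096 output features are split across the group,
# so each rank holds a [4096 // tp_size, 1024] weight shard.
layer = te.Linear(1024, 4096, parallel_mode="column", tp_group=tp_group)

inp = paddle.randn([16, 1024])

# fp8_group selects the ranks whose amax/scale statistics are reduced by
# FP8MetaFwdBuffer / FP8MetaBwdBuffer when the autocast region finalizes.
with te.fp8_autocast(enabled=True, fp8_group=tp_group):
    out = layer(inp)

out.mean().backward()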