[Paddle] Add parallel support (#357)
* [Paddle] Add TP, DP, PP, FSDP
* Minor fix
* Fix CI failure
* Remove set_nccl_overlap_warning_if_tp
* Improve variable naming
* Refactor FP8 Buffer
* Stylistic changes
* Fix FP32 parallel training
* Fix numel performance issue
* Squashed commit of the following:
  * 79e2e5fd774e67dcdda9aae01a9f31a6479c5d70 (2023-08-20): Add TP test
  * 1d40ad60540490f97ed82ba877cc6eda8902cbf6 (2023-08-20): Fix tp_size when disabled
  * 6632f735a0c8251862355fc74622af59fae3a509 (2023-08-20): Add TP for attention and transformer layer
* Add shape check
* Add FSDP check for stages 1, 2, 3
* Review changes
* Fix group_sharding test
* Support NVTE_FUSE_ATTN
* Fix CI errors

Signed-off-by: Tian Zheng (Engrg-Hardware 1) <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Showing 24 changed files with 2,248 additions and 140 deletions.
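As context for the diffs below, a minimal sketch of the data-parallel FP8 path this commit enables. It mirrors the patterns used in the new tests; the sizes and the learning rate are illustrative, not taken from the change itself:

import paddle
from paddle.distributed import fleet
import transformer_engine.paddle as te

# Two-way data parallelism; tensor and pipeline parallelism disabled.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 2, "mp_degree": 1, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

model = fleet.distributed_model(paddle.nn.Sequential(te.Linear(16, 16)))
optimizer = fleet.distributed_optimizer(
    paddle.optimizer.SGD(learning_rate=0.1, parameters=model.parameters()))

inp = paddle.uniform([16, 16])
with te.fp8_autocast(enabled=True):
    # FP8 amaxes/scales are kept consistent across data-parallel ranks.
    loss = model(inp).mean()
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()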
@@ -0,0 +1,140 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Helper functions to launch distributed tests"""

import copy
import os
from pathlib import Path
import subprocess
import time
import unittest

from paddle import fluid
from paddle.distributed.utils.launch_utils import (
    TrainerProc,
    find_free_ports,
    get_cluster,
    watch_local_trainers,
)

__all__ = ['TestDistributed']


def get_cluster_from_args(selected_gpus):
    """Get node information from selected GPUs"""
    cluster_node_ips = '127.0.0.1'
    node_ip = '127.0.0.1'

    node_ips = [x.strip() for x in cluster_node_ips.split(',')]

    # Raises ValueError if the current node is not in the cluster list.
    node_ips.index(node_ip)

    free_ports = find_free_ports(len(selected_gpus))
    if free_ports is not None:
        free_ports = list(free_ports)

    trainer_endpoints = []
    for ip in node_ips:
        trainer_endpoints.append([f"{ip}:{port}" for port in free_ports])
    return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus)


def get_gpus(selected_gpus):
    """Get selected GPU string"""
    selected_gpus = [x.strip() for x in selected_gpus.split(',')]
    return selected_gpus


def start_local_trainers(
    cluster,
    pod,
    training_script,
    training_script_args,
    allocator_strategy="auto_growth",
):
    """Launch trainers"""
    current_env = copy.copy(os.environ)
    # Paddle broadcasts ncclUniqueId over sockets, and proxies may make the
    # trainers unreachable. Setting the proxy variables to "" makes grpc log
    # a "bad uri" error, so delete them instead.
    current_env.pop("http_proxy", None)
    current_env.pop("https_proxy", None)

    procs = []
    for t in pod.trainers:
        proc_env = {
            "FLAGS_selected_gpus": ",".join([str(g) for g in t.gpus]),
            "PADDLE_TRAINER_ID": f"{t.rank}",
            "PADDLE_CURRENT_ENDPOINT": f"{t.endpoint}",
            "PADDLE_TRAINERS_NUM": f"{cluster.trainers_nranks()}",
            "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()),
            "PYTHONPATH": str(Path(__file__).resolve().parent),
        }

        proc_env["FLAGS_allocator_strategy"] = allocator_strategy
        if allocator_strategy == "auto_growth":
            proc_env["FLAGS_fraction_of_gpu_memory_to_use"] = "0.1"

        current_env.update(proc_env)

        print(f"trainer proc env:{current_env}")

        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
            cmd = "python -m coverage run --branch -p " + training_script
        else:
            cmd = "python -u " + training_script

        print(f"start trainer proc:{cmd} env:{proc_env}")

        proc = subprocess.Popen(cmd.split(" ") + training_script_args,
                                env=current_env)    # pylint: disable=consider-using-with

        tp = TrainerProc()
        tp.proc = proc
        tp.rank = t.rank
        tp.log_fn = None
        tp.cmd = cmd

        procs.append(tp)

    return procs


class TestDistributed(unittest.TestCase):
    """Base class for distributed tests"""

    @staticmethod
    def run_2gpu(
        target_file_name,
        allocator_strategy="auto_growth",
    ):
        """Run the target file in subprocesses, one per GPU"""
        if not fluid.core.is_compiled_with_cuda() or fluid.core.get_cuda_device_count() == 0:
            return

        selected_gpus = get_gpus('0,1')
        cluster, pod = get_cluster_from_args(selected_gpus)

        procs = start_local_trainers(
            cluster,
            pod,
            allocator_strategy=allocator_strategy,
            training_script=target_file_name,
            training_script_args=[],
        )

        while True:
            alive = watch_local_trainers(procs, cluster.trainers_endpoints())

            if not alive:
                print(f"Local procs complete, POD info:{pod}")
                break
            time.sleep(3)
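A concrete test module is expected to subclass TestDistributed and point run_2gpu at a parallel test script, as in this minimal sketch (the script name amax_reduction.py is a hypothetical placeholder):

class TestAmaxReductionLaunch(TestDistributed):
    """Launches the amax-reduction test below on 2 GPUs"""

    def test_amax_reduction(self):
        # Hypothetical path; substitute the actual parallel test file.
        self.run_2gpu('amax_reduction.py')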
@@ -0,0 +1,87 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for FP8 amax reduction in data parallel"""

import unittest

import paddle
from paddle.distributed import fleet

from utils import assert_allclose, set_random_seed
import transformer_engine.paddle as te


def assert_allclose_across_ranks(tensor, group=None):
    """Assert tensor is identical on all ranks"""
    gathered_list = []
    paddle.distributed.all_gather(gathered_list, tensor, group=group)
    assert len(gathered_list) > 1
    for gathered_tensor in gathered_list:
        assert_allclose(tensor, gathered_tensor)


class TestAmaxReduction(unittest.TestCase):
    """Tests Amax reduction"""

    def setUp(self):
        self.data_parallel_size = 2
        self.init_dist_env()
        self.global_dtype = 'bfloat16'
        paddle.set_default_dtype(self.global_dtype)

    def init_dist_env(self):
        """Init Paddle Fleet environment"""
        strategy = fleet.DistributedStrategy()
        strategy.hybrid_configs = {
            "dp_degree": self.data_parallel_size,
            "mp_degree": 1,
            "pp_degree": 1,
        }
        fleet.init(is_collective=True, strategy=strategy)

    def test_amax_reduction(self):
        """Tests amax reduction across data-parallel ranks"""
        set_random_seed(1024)
        layer1 = te.Linear(16, 16)
        layer2 = te.Linear(16, 16)
        model = paddle.nn.Sequential(layer1, layer2)
        model = fleet.distributed_model(model)

        # Reseed with the rank id so that training inputs differ across ranks.
        rank_id = paddle.distributed.get_rank()
        set_random_seed(rank_id)

        optimizer = paddle.optimizer.SGD(learning_rate=10.0, parameters=model.parameters())
        optimizer = fleet.distributed_optimizer(optimizer)

        def train_one_step(layer, inp, optimizer):
            inp = paddle.to_tensor(inp)
            inp.stop_gradient = False
            out = layer(inp)
            loss = out.mean()
            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            return loss

        for _ in range(5):
            inp = paddle.uniform([16, 16], self.global_dtype)
            with te.fp8_autocast(enabled=True):
                train_one_step(model, inp, optimizer)

        # The FP8 scaling state (amax history, scale, scale_inv) must end up
        # identical on every rank, for both the forward and backward meta of
        # both layers.
        for layer in (layer1, layer2):
            for meta_key in ("scaling_fwd", "scaling_bwd"):
                fp8_meta = layer.fp8_meta[meta_key]
                assert_allclose_across_ranks(fp8_meta.amax_history[-1])
                assert_allclose_across_ranks(fp8_meta.scale)
                assert_allclose_across_ranks(fp8_meta.scale_inv)


if __name__ == '__main__':
    unittest.main()
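Because each rank trains on different data (the per-rank reseed above), matching scaling state across ranks is evidence that the amax reduction added in this commit works. To run this file directly on two GPUs instead of going through the TestDistributed helper, one assumed invocation (the file name is a placeholder; Paddle's launcher sets the same PADDLE_* environment variables that the dist_launcher helper constructs by hand):

python -m paddle.distributed.launch --gpus 0,1 test_amax_reduction.py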