Skip to content

Commit

Permalink
feat: support process group (#228)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangraying authored Oct 15, 2021
1 parent 71e470d commit cd499b8
Show file tree
Hide file tree
Showing 15 changed files with 706 additions and 322 deletions.
3 changes: 2 additions & 1 deletion .buildkite/scripts/run_pytest.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ echo "$BUILDKITE_PARALLEL_JOB_COUNT"
set -euo pipefail
cp -a /upstream /workdir
export HOME=/workdir && cd $HOME && bash .buildkite/scripts/install_bagua.sh || exit 1
pytest -s -o "testpaths=tests"
pip install pytest-timeout
pytest --timeout=300 -s -o "testpaths=tests"
2 changes: 2 additions & 0 deletions bagua/torch_api/algorithms/async_model_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,12 @@ def init_operations(
bucket.append_centralized_synchronous_op(
hierarchical=False,
average=True,
group=bagua_module._bagua_process_group,
)
else:
async_op = bucket.append_asynchronous_model_average_op(
peer_selection_mode=self.peer_selection_mode,
group=bagua_module._bagua_process_group,
)
bucket._async_op = async_op

Expand Down
1 change: 1 addition & 0 deletions bagua/torch_api/algorithms/bytegrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,5 @@ def init_operations(
average=self.average,
scattergather=True,
compression="MinMaxUInt8",
group=bagua_module._bagua_process_group,
)
2 changes: 2 additions & 0 deletions bagua/torch_api/algorithms/decentralized.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def init_operations(
peer_weight=bucket._peer_weight,
hierarchical=self.hierarchical,
peer_selection_mode=self.peer_selection_mode,
group=bagua_module._bagua_process_group,
)


Expand Down Expand Up @@ -187,4 +188,5 @@ def init_operations(
right_peer_weight=bucket._right_peer_weight,
hierarchical=self.hierarchical,
compression="MinMaxUInt8",
group=bagua_module._bagua_process_group,
)
15 changes: 5 additions & 10 deletions bagua/torch_api/algorithms/gradient_allreduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,8 @@ def init_operations(
bucket: BaguaBucket,
):
bucket.clear_ops()
if self.hierarchical:
bucket.append_centralized_synchronous_op(
hierarchical=self.hierarchical,
average=self.average,
)
else:
bucket.append_centralized_synchronous_op(
hierarchical=self.hierarchical,
average=self.average,
)
bucket.append_centralized_synchronous_op(
hierarchical=self.hierarchical,
average=self.average,
group=bagua_module._bagua_process_group,
)
2 changes: 2 additions & 0 deletions bagua/torch_api/algorithms/q_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def init_operations(
bucket.append_centralized_synchronous_op(
hierarchical=False,
average=True,
group=bagua_module._bagua_process_group,
)
else:

Expand All @@ -186,6 +187,7 @@ def calculate_momentum(*args):
average=True,
scattergather=True,
compression="MinMaxUInt8",
group=bagua_module._bagua_process_group,
)

def init_backward_hook(self, bagua_module: BaguaModule):
Expand Down
64 changes: 47 additions & 17 deletions bagua/torch_api/bucket.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
#!/usr/bin/env python3

from __future__ import annotations
from bagua.torch_api.communication import get_backend
from bagua.torch_api.communication import get_backend, _get_default_group
from typing import List, Callable, Optional

import bagua_core as B
import torch

from bagua.torch_api.tensor import BaguaTensor
from bagua.torch_api.utils import check_contiguous
from bagua.torch_api.communication import broadcast
from bagua.torch_api.communication import (
broadcast,
BaguaProcessGroup,
_bagua_backend_comm,
_rank_not_in_comm,
)


class BaguaBucket:
Expand Down Expand Up @@ -157,6 +162,7 @@ def append_centralized_synchronous_op(
average: bool = True,
scattergather: bool = False,
compression: Optional[str] = None,
group: Optional[BaguaProcessGroup] = None,
):
"""
Append a centralized synchronous operation to a bucket. It will sum or average the tensors in the bucket
Expand All @@ -174,19 +180,23 @@ def append_centralized_synchronous_op(
of allreduce. This is required for using compression.
compression: If not ``None``, the tensors will be compressed for communication. Currently ``"MinMaxUInt8"`` is
supported.
group: The process group to work on. If ``None``, the default process group will be used.
"""
if group is None:
group = _get_default_group()

if hierarchical:
self.backend_bucket.append_centralized_synchronous_op(
self._bagua_backend.internode_communicator,
self._bagua_backend.intranode_communicator,
_bagua_backend_comm(group.get_inter_node_communicator()),
_bagua_backend_comm(group.get_intra_node_communicator()),
hierarchical=hierarchical,
average=average,
scattergather=scattergather,
compression=compression,
)
else:
self.backend_bucket.append_centralized_synchronous_op(
self._bagua_backend.global_communicator,
_bagua_backend_comm(group.get_global_communicator()),
None,
hierarchical=hierarchical,
average=average,
Expand All @@ -199,6 +209,7 @@ def append_decentralized_synchronous_op(
peer_weight: BaguaTensor,
hierarchical: bool = True,
peer_selection_mode: str = "all",
group: Optional[BaguaProcessGroup] = None,
):
"""
Append a decentralized synchronous operation to a bucket. It will do gossipy style model averaging among workers.
Expand All @@ -219,27 +230,33 @@ def append_decentralized_synchronous_op(
peer_selection_mode (str): Can be ``"all"`` or ``"shift_one"``. ``"all"`` means all workers' weights are averaged
in each communication step. ``"shift_one"`` means each worker selects a different peer to do weight averaging with
in each communication step.
group: The process group to work on. If ``None``, the default process group will be used.
"""
if group is None:
group = _get_default_group()

if hierarchical:
self.backend_bucket.append_decentralized_synchronous_op(
self._bagua_backend.internode_communicator,
self._bagua_backend.intranode_communicator,
_bagua_backend_comm(group.get_inter_node_communicator()),
_bagua_backend_comm(group.get_intra_node_communicator()),
hierarchical=hierarchical,
peer_selection_mode=peer_selection_mode,
peer_weight=peer_weight._bagua_backend_tensor,
)
else:
self.backend_bucket.append_decentralized_synchronous_op(
self._bagua_backend.global_communicator,
_bagua_backend_comm(group.get_global_communicator()),
None,
hierarchical=hierarchical,
peer_selection_mode=peer_selection_mode,
peer_weight=peer_weight._bagua_backend_tensor,
)

def decentralized_synchronous_op_copy_back_peer_weight(
self, peer_weight: BaguaTensor, hierarchical: bool = True
self,
peer_weight: BaguaTensor,
hierarchical: bool = True,
group: Optional[BaguaProcessGroup] = None,
):
"""
Copy :attr:`peer_weight` back to bucket weights to end a decentralized synchronous operation.
Expand All @@ -252,11 +269,15 @@ def decentralized_synchronous_op_copy_back_peer_weight(
will communicate with each other first. After that, machines do inter-node communication. This can
boost performance when the inter-node communication cost is high. Must be the same with :attr:`hierarchical` argument in
:meth:`append_decentralized_synchronous_op`.
group: The process group to work on. If ``None``, the default process group will be used.
"""
intra_comm = self._bagua_backend.intranode_communicator
inter_comm = self._bagua_backend.internode_communicator
if group is None:
group = _get_default_group()

intra_comm = group.get_intra_node_communicator()
inter_comm = group.get_inter_node_communicator()

if not hierarchical or (inter_comm is not None):
if not hierarchical or not _rank_not_in_comm(inter_comm):
self.backend_tensor.copy_(peer_weight)

if hierarchical:
Expand All @@ -269,6 +290,7 @@ def append_low_precision_decentralized_synchronous_op(
right_peer_weight: BaguaTensor,
hierarchical: bool = True,
compression: str = "MinMaxUInt8",
group: Optional[BaguaProcessGroup] = None,
):
"""
Append a low precision decentralized synchronous operation to a bucket. It will compress the difference
Expand All @@ -290,12 +312,15 @@ def append_low_precision_decentralized_synchronous_op(
will communicate with each other first. After that, machines do inter-node communication. This can
boost performance when the inter-node communication cost is high.
compression (str): The way how tensors are compressed for communication. Currently ``"MinMaxUInt8"`` is supported.
group: The process group to work on. If ``None``, the default process group will be used.
"""
if group is None:
group = _get_default_group()

if hierarchical:
self.backend_bucket.append_low_precision_decentralized_synchronous_op(
self._bagua_backend.internode_communicator,
self._bagua_backend.intranode_communicator,
_bagua_backend_comm(group.get_inter_node_communicator()),
_bagua_backend_comm(group.get_intra_node_communicator()),
hierarchical=hierarchical,
peer_selection_mode="ring",
compression=compression,
Expand All @@ -305,7 +330,7 @@ def append_low_precision_decentralized_synchronous_op(
)
else:
self.backend_bucket.append_low_precision_decentralized_synchronous_op(
self._bagua_backend.global_communicator,
_bagua_backend_comm(group.get_global_communicator()),
None,
hierarchical=hierarchical,
peer_selection_mode="ring",
Expand All @@ -315,7 +340,9 @@ def append_low_precision_decentralized_synchronous_op(
right_peer_weight=right_peer_weight._bagua_backend_tensor,
)

def append_asynchronous_model_average_op(self, peer_selection_mode: str):
def append_asynchronous_model_average_op(
self, peer_selection_mode: str, group: Optional[BaguaProcessGroup] = None
):

"""
Append an asynchronous model average operation to a bucket. This operation will enable continuous
Expand All @@ -331,12 +358,15 @@ def append_asynchronous_model_average_op(self, peer_selection_mode: str):
Args:
peer_selection_mode (str): The way workers communicate with each other. Currently ``"all"`` is supported.
``"all"`` means all workers' weights are averaged during each communication.
group: The process group to work on. If ``None``, the default process group will be used.
Returns:
The asynchronous model average operation itself.
"""
if group is None:
group = _get_default_group()

return self.backend_bucket.append_decentralized_asynchronous_op(
self._bagua_backend.global_communicator,
_bagua_backend_comm(group.get_global_communicator()),
None,
peer_selection_mode=peer_selection_mode,
torch_stream=torch.cuda.current_stream().cuda_stream,
Expand Down
Loading

0 comments on commit cd499b8

Please sign in to comment.