Implement @data_distributed
Initial implementation of data_distributed
breakds committed Dec 1, 2021
1 parent 5dd5900 commit cd9c8e4
Showing 3 changed files with 148 additions and 76 deletions.
13 changes: 13 additions & 0 deletions alf/algorithms/algorithm.py
@@ -133,6 +133,8 @@ def __init__(self,
        self._metrics = []
        self._replay_buffer = None

        self._ddp_activated_rank = -1

        # These 3 parameters are only set when ``set_replay_buffer()`` is called.
        self._replay_buffer_num_envs = None
        self._replay_buffer_max_length = None
@@ -213,6 +215,17 @@ def use_rollout_state(self):
"""
return self._use_rollout_state

    def activate_ddp(self, rank: int):
        """Prepare this algorithm for the DistributedDataParallel (DDP) wrapper.

        The rank is recorded in ``self._ddp_activated_rank`` and is later used
        by methods decorated with ``@data_distributed`` to construct their DDP
        wrappers on the right device.

        Args:
            rank (int): the rank of the GPU device on which this module's
                parameters and buffers are supposed to reside.
        """
        self._ddp_activated_rank = rank

    @use_rollout_state.setter
    def use_rollout_state(self, flag):
        self._use_rollout_state = flag
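The new ``activate_ddp()`` hook is the only thing a trainer has to call per process to make the decorated methods DDP-synchronized. A minimal sketch of a per-rank worker follows; the worker function, the ``make_algorithm`` factory and the NCCL/rendezvous setup are illustrative assumptions, not part of this commit — only ``activate_ddp`` and ``_ddp_activated_rank`` are.

import os
import torch.distributed as dist
import torch.multiprocessing as mp

def _per_rank_worker(rank: int, world_size: int, make_algorithm):
    # Sketch only: one process per GPU; the process group must exist before
    # any @data_distributed method runs with DDP activated.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

    # ``make_algorithm`` is a hypothetical factory for an Algorithm instance.
    algorithm = make_algorithm().to(f'cuda:{rank}')
    # Record the device rank; every @data_distributed method on this
    # algorithm now runs through a DDP wrapper built on cuda:{rank}.
    algorithm.activate_ddp(rank)
    # ... the usual training loop follows, e.g. calling algorithm.unroll().

# mp.spawn(_per_rank_worker, args=(world_size, make_algorithm), nprocs=world_size)

In the single-node case the process rank doubles as the GPU index, which matches the ``device_ids=[rank]`` that the decorator's DDP wrapper uses.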
79 changes: 3 additions & 76 deletions alf/algorithms/rl_algorithm.py
@@ -26,49 +26,11 @@
from alf.data_structures import AlgStep, Experience, make_experience, TimeStep
from alf.utils import common, dist_utils, summary_utils
from alf.utils.summary_utils import record_time
from alf.utils.distributed import data_distributed
from alf.tensor_specs import TensorSpec
from .config import TrainerConfig


class _UnrollPerformer(torch.nn.Module):
    """Wraps RLAlgorithm.unroll() as a forward().

    In fact, UnrollPerformer.forward() is just a delegation to
    RLAlgorithm.unroll(). The reason that we need this wrapper is to trick
    DDP into adding the necessary callbacks to the unroll process. The DDP
    wrapper only recognizes and hijacks the method named "forward()".
    """

    def __init__(self, algorithm: 'RLAlgorithm'):
        """Construct the unroll() wrapper.

        Args:
            algorithm (RLAlgorithm): an RLAlgorithm instance whose unroll() is
                wrapped by UnrollPerformer's forward().
        """
        super().__init__()
        self.inner_algorithm = algorithm

        # DDP will panic if the wrapped module (i.e. UnrollPerformer here) has
        # members in its state_dict() that are not Tensors. Here such
        # state_dict members are picked and thrown into
        # _ddp_params_and_buffers_to_ignore. By contract this implicitly
        # instructs the DDP wrapper to not include them in its
        # parameter/buffer synchronization.
        self._ddp_params_and_buffers_to_ignore = []
        for name, value in self.state_dict().items():
            if type(value) is not torch.Tensor:
                self._ddp_params_and_buffers_to_ignore.append(name)

    def forward(self, unroll_length: int):
        """This is simply a delegation to the underlying algorithm's unroll().

        Args:
            unroll_length (int): the number of steps to unroll
        """
        return self.inner_algorithm._unroll(unroll_length)


def adjust_replay_buffer_length(config: TrainerConfig,
                                num_earliest_frames_ignored: int = 0) -> int:
    """Adjust the replay buffer length for whole replay buffer training.
@@ -244,13 +206,6 @@ def __init__(self,
            assert reward_weights is None, (
                "reward_weights cannot be used for one dimensional reward")

        # When in DDP (distributed data parallel) mode, self._unroll_performer
        # will be set to a DDP-wrapped nn.Module that does the unroll(), so
        # that gradients derived in the next backward() will be aggregated and
        # synchronized across all DDP processes. The caller decides whether to
        # activate DDP (and as a result the unroll performer too) by calling
        # RLAlgorithm.activate_ddp().
        self._unroll_performer = None

        self._rollout_info_spec = None

        self._current_time_step = None
@@ -325,21 +280,6 @@ def action_spec(self):
"""Return the action spec."""
return self._action_spec

    def activate_ddp(self, rank: int):
        """Prepare the RLAlgorithm with the DistributedDataParallel wrapper.

        Note that RLAlgorithm does not need to remember the rank of the
        device.

        Args:
            rank (int): DDP wrapper needs to know on which GPU device this
                module's parameters and buffers are supposed to be.
        """
        # The DDP-wrapped module is still a module. To prevent it from being
        # taken into the state_dict of this RLAlgorithm, it is put into a
        # tuple to avoid triggering automatic state_dict inclusion by
        # __setattr__.
        self._unroll_performer = (DDP(
            _UnrollPerformer(self), device_ids=[rank]), )

    @torch.no_grad()
    def set_reward_weights(self, reward_weights):
        """Update reward weights; this function can be called at any step during
@@ -491,22 +431,9 @@ def _rollout_step(self, time_step: TimeStep, state):
            self._rollout_info_spec = dist_utils.extract_spec(policy_step.info)
        return policy_step

    def unroll(self, unroll_length: int):
        """Unroll ``unroll_length`` steps using the current policy.

        This is a delegation over _unroll() which decides whether to call it
        with the DDP wrapper on or off.
        """
        if self._unroll_performer is not None:
            # If DDP is on, self._unroll_performer will be activated. In this
            # case, call its forward().
            return self._unroll_performer[0](unroll_length)
        else:
            return self._unroll(unroll_length)

    @common.mark_rollout
    def _unroll(self, unroll_length: int):
    @data_distributed
    def unroll(self, unroll_length: int):
        r"""Unroll ``unroll_length`` steps using the current policy.

        Because the ``self._env`` is a batched environment. The total number of
132 changes: 132 additions & 0 deletions alf/utils/distributed.py
@@ -0,0 +1,132 @@
# Copyright (c) 2021 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable
import functools

import torch
from torch.nn.parallel import DistributedDataParallel as DDP


class _MethodPerformer(torch.nn.Module):
    """A nn.Module wrapper whose forward() performs a specified method of
    the wrapped module.

    The end goal is to make a TARGET METHOD data distributed.

    We need this delegation so that DDP can then wrap over this module. When
    DDP hijacks the forward() of _MethodPerformer to inject synchronization
    hooks, it effectively does so for the target method of the wrapped module.
    """

    def __init__(self, module: torch.nn.Module, perform: Callable[..., Any]):
        """Constructs a _MethodPerformer.

        Args:
            module: an instance of the module whose method is going to be
                delegated to. The _MethodPerformer instance needs to access
                and inherit the parameters from the module, so that DDP knows
                what parameters to cover.
            perform: the target method of the module.
        """
        super().__init__()

        self._wrapped_module = module  # Register and inherit the parameters
        self._perform = functools.partial(perform, self._wrapped_module)

        # DDP will panic if the wrapped module has members in its state_dict()
        # that are not Tensors. Here such state_dict members are picked and
        # thrown into _ddp_params_and_buffers_to_ignore. By contract this
        # implicitly instructs the DDP wrapper to not include them in its
        # parameter/buffer synchronization.
        self._ddp_params_and_buffers_to_ignore = []
        for name, value in self.state_dict().items():
            if type(value) is not torch.Tensor:
                self._ddp_params_and_buffers_to_ignore.append(name)

    def forward(self, *args, **kwargs):
        return self._perform(*args, **kwargs)


def data_distributed(method):
    """This decorator makes a target method of a module capable of being data
    distributed via DDP.

    This provides a simple and transparent way to enable DDP for specific
    code logic.

    Example usage:

    .. code-block:: python

        class A(nn.Module):
            # ...
            @data_distributed
            def compute_something(self, input):
                return self._network1(input), self._network2(input)
            # ...

    In the above code, after applying the decorator, the method
    ``compute_something`` will be made data distributed if the following
    conditions are met:

    1. Multiple processes within the same process group create A's instances
       and call ``compute_something()`` individually.
    2. All such A instances have ``self._ddp_activated_rank`` set to the
       correct rank of the GPU device that belongs to them.

    Otherwise the method ``compute_something()`` will behave normally.
    """

    @functools.wraps(method)
    def wrapped(*args, **kwargs):
        # The first argument to the method is going to be ``self``, i.e. the
        # instance that the method belongs to. By accessing it we get the
        # reference of the module to wrap.
        module_to_wrap = args[0]
        assert isinstance(module_to_wrap, torch.nn.Module), (
            f'Cannot apply @data_distributed on {type(module_to_wrap)}')

        ddp_rank = getattr(module_to_wrap, '_ddp_activated_rank', -1)

        # A ddp_rank of -1 means DDP is not activated for this module. In this
        # case, just perform the normal method call.
        if ddp_rank == -1:
            return method(*args, **kwargs)

        # Create a DDP-wrapped _MethodPerformer instance if there is not one
        # yet. All the _MethodPerformer instances are registered in a map
        # called _ddp_performer_map, which belongs to the module to wrap.
        if not hasattr(module_to_wrap, '_ddp_performer_map'):
            setattr(module_to_wrap, '_ddp_performer_map', {})

        performer = module_to_wrap._ddp_performer_map.get(
            method.__name__, None)
        if performer is None:
            performer = DDP(
                _MethodPerformer(module=module_to_wrap, perform=method),
                device_ids=[ddp_rank])
            module_to_wrap._ddp_performer_map[method.__name__] = performer

        return performer(*args[1:], **kwargs)

    return wrapped
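For a picture of the two conditions in the docstring above outside of ALF's algorithms, here is a self-contained sketch; the toy ``Trainer`` module, the two-GPU assumption and the rendezvous settings are illustrative, and only ``data_distributed`` itself comes from this file.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from alf.utils.distributed import data_distributed


class Trainer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._net = torch.nn.Linear(4, 1)
        self._ddp_activated_rank = -1  # -1 means DDP is off (condition 2)

    @data_distributed
    def compute_loss(self, x):
        return self._net(x).mean()


def _worker(rank):
    # Condition 1: one process per GPU, all in the same process group.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('nccl', rank=rank, world_size=2)

    trainer = Trainer().to(f'cuda:{rank}')
    trainer._ddp_activated_rank = rank  # condition 2: record the GPU rank

    loss = trainer.compute_loss(torch.randn(8, 4))
    loss.backward()  # gradients are all-reduced across the two processes
    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(_worker, nprocs=2)

The first DDP-activated call builds a DDP-wrapped ``_MethodPerformer`` for ``compute_loss`` and caches it in ``trainer._ddp_performer_map``; subsequent calls reuse the cached performer.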
