Implement @data_distributed Decorator #1098

Merged (3 commits, Dec 2, 2021)
13 changes: 13 additions & 0 deletions alf/algorithms/algorithm.py
@@ -133,6 +133,8 @@ def __init__(self,
self._metrics = []
self._replay_buffer = None

self._ddp_activated_rank = -1

# These 3 parameters are only set when ``set_replay_buffer()`` is called.
self._replay_buffer_num_envs = None
self._replay_buffer_max_length = None
@@ -213,6 +215,17 @@ def use_rollout_state(self):
"""
return self._use_rollout_state

def activate_ddp(self, rank: int):
"""Prepare the Algorithm with DistributedDataParallel wrapper

Note that Algorithm does not need to remember the rank of the device.

Args:
rank (int): DDP wrapper needs to know on which GPU device this
module's parameters and buffers are supposed to be.
"""
self._ddp_activated_rank = rank

@use_rollout_state.setter
def use_rollout_state(self, flag):
self._use_rollout_state = flag
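For orientation, a hedged sketch of the intended call site: per-process trainer code would call ``activate_ddp`` once after the process group is set up, so that ``_ddp_activated_rank`` is known before any ``@data_distributed`` method runs. The ``setup_ddp`` helper and the ``nccl``/environment-variable setup below are illustrative assumptions, not code from this PR.

.. code-block:: python

    import torch.distributed as dist

    def setup_ddp(algorithm, rank: int, world_size: int):
        # Assumes one process per GPU has already been spawned and that
        # MASTER_ADDR/MASTER_PORT are set in the environment.
        dist.init_process_group('nccl', rank=rank, world_size=world_size)
        # Record the rank on the algorithm; with _ddp_activated_rank >= 0,
        # methods decorated with @data_distributed run under a DDP wrapper
        # bound to this GPU.
        algorithm.activate_ddp(rank)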
80 changes: 3 additions & 77 deletions alf/algorithms/rl_algorithm.py
@@ -17,7 +17,6 @@
import os
import time
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from typing import Callable
from absl import logging

@@ -26,49 +25,11 @@
from alf.data_structures import AlgStep, Experience, make_experience, TimeStep
from alf.utils import common, dist_utils, summary_utils
from alf.utils.summary_utils import record_time
from alf.utils.distributed import data_distributed
from alf.tensor_specs import TensorSpec
from .config import TrainerConfig


class _UnrollPerformer(torch.nn.Module):
"""Wraps RLAlgorithm.unroll() as a forward()

In fact, UnrollPerformer.forward() is just a delegation to
RLAlgorithm.unroll(). The reason we need this wrapper is to trick DDP into
adding the necessary callbacks to the unroll process. The DDP wrapper only
recognizes and hijacks the method named "forward()".

"""

def __init__(self, algorithm: 'RLAlgorithm'):
"""Construct the unroll() wrapper

Args:
algorithm (RLAlgorithm): an RLAlgorithm instance whose unroll() is
wrapped by UnrollPerformer's forward().
"""
super().__init__()
self.inner_algorithm = algorithm

# DDP will panic if the wrapped module (i.e. UnrollPerformer here) has a
# member in its state_dict() that is not a Tensor. Here such state_dict
# members are picked and thrown into _ddp_params_and_buffers_to_ignore.
# By contract this implicitly instructs the DDP wrapper not to include them
# in its parameter/buffer synchronization.
self._ddp_params_and_buffers_to_ignore = []
for name, value in self.state_dict().items():
if type(value) is not torch.Tensor:
self._ddp_params_and_buffers_to_ignore.append(name)

def forward(self, unroll_length: int):
"""This is simply a delegation to the underlying algorithm's unroll()

Args:
unroll_length (int): the number of steps to unroll
"""
return self.inner_algorithm._unroll(unroll_length)


def adjust_replay_buffer_length(config: TrainerConfig,
num_earliest_frames_ignored: int = 0) -> int:
"""Adjust the replay buffer length for whole replay buffer training.
@@ -244,13 +205,6 @@ def __init__(self,
assert reward_weights is None, (
"reward_weights cannot be used for one dimensional reward")

# When in DDP (distributed data parallel) mode, self._unroll_performer
# will be set to a DDP-wrapped nn.Module that does the unroll(), so that
# gradients derived in the next backward() will be aggregated and
# synchronized across all DDP processes. The caller decides whether to
# activate the DDP (and as a result the unroll performer too) by calling
# RLAlgorithm.activate_ddp().
self._unroll_performer = None
self._rollout_info_spec = None

self._current_time_step = None
@@ -325,21 +279,6 @@ def action_spec(self):
"""Return the action spec."""
return self._action_spec

def activate_ddp(self, rank: int):
"""Prepare the RLAlgorithm with DistributedDataParallel wrapper

Note that RLAlgorithm does not need to remember the rank of the device.

Args:
rank (int): DDP wrapper needs to know on which GPU device this
module's parameters and buffers are supposed to be.
"""
# The DDP wrapped module is still a module. To prevent it from being
# taken into the state_dict of this RLAlgorithm, it is put into a tuple
# to avoid triggering automatic state_dict inclusion by __setattr__.
self._unroll_performer = (DDP(
_UnrollPerformer(self), device_ids=[rank]), )

@torch.no_grad()
def set_reward_weights(self, reward_weights):
"""Update reward weights; this function can be called at any step during
@@ -491,22 +430,9 @@ def _rollout_step(self, time_step: TimeStep, state):
self._rollout_info_spec = dist_utils.extract_spec(policy_step.info)
return policy_step

def unroll(self, unroll_length: int):
"""Unroll ``unroll_length`` steps using the current policy.

This is a delegation over _unroll() which decides whether to call it
with DDP wrapper on or off.

"""
if self._unroll_performer is not None:
# If DDP is on, self._unroll_performer will be activated. In this
# case, call its forward().
return self._unroll_performer[0](unroll_length)
else:
return self._unroll(unroll_length)

@common.mark_rollout
def _unroll(self, unroll_length: int):
@data_distributed
def unroll(self, unroll_length: int):
r"""Unroll ``unroll_length`` steps using the current policy.

Because ``self._env`` is a batched environment, the total number of
139 changes: 139 additions & 0 deletions alf/utils/distributed.py
@@ -0,0 +1,139 @@
# Copyright (c) 2021 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable
import functools

import torch
from torch.nn.parallel import DistributedDataParallel as DDP


class _MethodPerformer(torch.nn.Module):
"""A nn.Module wrapper whose forward() performs a specified method of
the wrapped module.

The end goal is to make a TARGET METHOD data distributed.

We need this delegation so that DDP can then wrap over this module. When DDP
hijacks the forward() of _MethodPerformer to inject synchronization hooks,
it effectively does so for the target method of the wrapped module.

"""

def __init__(self, module: torch.nn.Module, perform: Callable[..., Any]):
"""Constructs a _MethodPerformer.

Args:

module: an instance of the module whose method will be delegated to
this performer. The _MethodPerformer instance needs to access and
inherit the parameters from the module, so that DDP knows what
parameters to cover.

perform: the target method of the module.

"""
super().__init__()

self._wrapped_module = module # Register and inherit the parameters
self._perform = functools.partial(perform, self._wrapped_module)

# DDP will panic if the wrapped module has a member in its state_dict()
# that is not a Tensor. Here such state_dict members are picked and
# thrown into _ddp_params_and_buffers_to_ignore. By contract this
# implicitly instructs the DDP wrapper not to include them in its
# parameter/buffer synchronization.
self._ddp_params_and_buffers_to_ignore = []
for name, value in self.state_dict().items():
if type(value) is not torch.Tensor:
self._ddp_params_and_buffers_to_ignore.append(name)

def forward(self, *args, **kwargs):
return self._perform(*args, **kwargs)


def data_distributed(method):
"""This decorator makes a target method of a module capable of being data
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that this decorator can only apply to methods of Algorithm? If so, the docstring should clarify this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is just for Algorithm, but actually all nn.Module derived classes. One of the prerequisite is that it needs to have self._ddp_activated_rank, which is stated in the docstring as a contract.

distributed via DDP.

This is to provide a simple and transparent way to enable DDP for specific
code logics.
breakds marked this conversation as resolved.
Show resolved Hide resolved

When the method is wrapped by @data_distributed, the outputs (tensors) of
this method will have gradient synchronization hooks attached to them. Later,
when those outputs are used in ``backward()`` to compute gradients, the
hooks will be called to synchronize across all processes. As a result, the
corresponding parameters receive not only the gradients from this process,
but also gradients from the other processes. Note that each process will be
BLOCKED at the call to the ``backward()`` that involves those output tensors
until all processes have finished the back propagation and the gradients are
synchronized.

Example usage:

.. code-block:: python

class A(nn.Module):
# ...
@data_distributed
def compute_something(self, input):
return self._network1(input), self._network2(input)
# ...

In the above code, after applying the decorator, the method
``compute_something`` will be made data distributed if the following
conditions are met:

1. Multiple processes within the same process group create A's instances
and call ``compute_something()`` individually.

2. All such A instances have ``self._ddp_activated_rank`` set to the correct
rank of the GPU device that belongs to them.

Otherwise the method ``compute_something()`` will behave normally.

"""

@functools.wraps(method)
def wrapped(*args, **kwargs):
# The first argument to the method is going to be ``self``, i.e. the
# instance that the method belongs to. By accessing it we get the
# reference of the module to wrap.
module_to_wrap = args[0]
assert isinstance(module_to_wrap, torch.nn.Module), (
f'Cannot apply @data_distributed on {type(module_to_wrap)}')

ddp_rank = getattr(module_to_wrap, '_ddp_activated_rank', -1)

# A ddp_rank of -1 means DDP is not activated for this module. In this
# case, just perform the normal method call.
if ddp_rank == -1:
return method(*args, **kwargs)

# Create a DDP-wrapped _MethodPerformer instance if one does not exist yet.
# All the _MethodPerformer instances are registered in a map called
# _ddp_performer_map, which belongs to the module to wrap.
if not hasattr(module_to_wrap, '_ddp_performer_map'):
setattr(module_to_wrap, '_ddp_performer_map', {})

performer = module_to_wrap._ddp_performer_map.get(
method.__name__, None)
if performer is None:
performer = DDP(
_MethodPerformer(module=module_to_wrap, perform=method),
device_ids=[ddp_rank])
module_to_wrap._ddp_performer_map[method.__name__] = performer
return performer(*args[1:], **kwargs)

return wrapped
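To make the contract concrete, here is a self-contained sketch that exercises the decorator on a plain ``nn.Module`` outside of ALF's Algorithm classes. It assumes one CUDA device per process and the ``nccl`` backend; the ``Model`` and ``worker`` names and the toy loss are illustrative, not part of this PR.

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn

    from alf.utils.distributed import data_distributed


    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self._layer = nn.Linear(4, 1)
            # -1 means DDP is off; setting a GPU rank turns it on.
            self._ddp_activated_rank = -1

        @data_distributed
        def compute_loss(self, x):
            return self._layer(x).mean()


    def worker(rank: int, world_size: int):
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        os.environ.setdefault('MASTER_PORT', '29500')
        dist.init_process_group('nccl', rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)

        model = Model().to(f'cuda:{rank}')
        model._ddp_activated_rank = rank  # fulfill the decorator's contract

        # The first call lazily builds a DDP-wrapped _MethodPerformer and
        # caches it in model._ddp_performer_map; backward() then averages
        # gradients across all processes.
        loss = model.compute_loss(torch.randn(8, 4, device=f'cuda:{rank}'))
        loss.backward()

        dist.destroy_process_group()


    if __name__ == '__main__':
        world_size = torch.cuda.device_count()
        mp.spawn(worker, args=(world_size,), nprocs=world_size)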