Enable DDP for Off-Policy training branch #1099

Merged 3 commits on Dec 5, 2021
21 changes: 18 additions & 3 deletions alf/algorithms/algorithm.py
@@ -32,6 +32,7 @@
from alf.utils import common, dist_utils, spec_utils, summary_utils
from alf.utils.summary_utils import record_time
from alf.utils.math_ops import add_ignore_empty
from alf.utils.distributed import data_distributed
from alf.experience_replayers.replay_buffer import ReplayBuffer
from .algorithm_interface import AlgorithmInterface
from .config import TrainerConfig
@@ -1473,17 +1474,31 @@ def _collect_train_info_parallelly(self, experience):
info = dist_utils.params_to_distributions(info, self.train_info_spec)
return info

def _update(self, experience, batch_info, weight):
@data_distributed
def _compute_train_info_and_loss_info(self, experience):
"""Compute train_info and loss_info based on the experience.

This function has data distributed support. This means that if the
Algorithm instance has DDP activated, the output will have a hook to
synchronize gradients across processes upon the call to the backward()
that involes the output (i.e. train_info and loss_info).

"""
length = alf.nest.get_nest_size(experience, dim=0)
if self._config.temporally_independent_train_step or length == 1:
train_info = self._collect_train_info_parallelly(experience)
else:
train_info = self._collect_train_info_sequentially(experience)

experience = dist_utils.params_to_distributions(
experience, self.processed_experience_spec)

loss_info = self.calc_loss(train_info)

return train_info, loss_info

def _update(self, experience, batch_info, weight):
train_info, loss_info = self._compute_train_info_and_loss_info(
experience)

if loss_info.priority != ():
priority = (
loss_info.priority**self._config.priority_replay_alpha() +
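For context on the decorator used above: ``data_distributed`` (defined in alf/utils/distributed.py) lets an ordinary Algorithm method participate in DDP gradient synchronization. The sketch below illustrates the general mechanism only and is not ALF's actual implementation; the attribute names ``_ddp_activated_rank`` and ``_ddp_performer_cache`` are assumptions made for the example, while ``_MethodPerformer`` comes from the diff further down.

import functools

from torch.nn.parallel import DistributedDataParallel as DDP

from alf.utils.distributed import _MethodPerformer


def data_distributed_sketch(method):
    """Illustrative stand-in for the data_distributed decorator."""

    @functools.wraps(method)
    def wrapped(self, *args, **kwargs):
        rank = getattr(self, '_ddp_activated_rank', -1)  # assumed attribute
        if rank < 0:
            # DDP not activated: behave exactly like the undecorated method.
            return method(self, *args, **kwargs)
        # Lazily build and cache a DDP-wrapped performer whose forward()
        # runs ``method``; outputs produced through it carry the hooks that
        # make backward() synchronize gradients across processes.
        cache = self.__dict__.setdefault('_ddp_performer_cache', {})
        performer = cache.get(method.__name__)
        if performer is None:
            performer = DDP(
                _MethodPerformer(
                    module=self, perform=functools.partial(method, self)),
                device_ids=[rank])
            cache[method.__name__] = performer
        return performer(*args, **kwargs)

    return wrapped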
4 changes: 0 additions & 4 deletions alf/trainers/policy_trainer.py
@@ -352,10 +352,6 @@ def __init__(self, config: TrainerConfig, ddp_rank: int = -1):
debug_summaries=self._debug_summaries)
self._algorithm.set_path('')
if ddp_rank >= 0:
if not self._algorithm.on_policy:
raise RuntimeError(
'Multi-GPU with DDP does not support off-policy training yet'
)
# Activate the DDP training
self._algorithm.activate_ddp(ddp_rank)

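With the guard removed, a trainer constructed with ddp_rank >= 0 now calls activate_ddp regardless of whether the algorithm is on-policy. Below is a generic sketch of how such per-rank worker processes are typically launched with plain PyTorch; ALF's own launcher may differ, and the trainer construction is deliberately left as a comment since this PR does not show the trainer's instantiation.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(ddp_rank: int, world_size: int):
    # Each worker joins the same process group before any DDP wrapper is
    # created; ddp_rank is the value a trainer built in this process would
    # receive (ddp_rank >= 0 activates DDP per the diff above).
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('nccl', rank=ddp_rank, world_size=world_size)
    # ... construct the trainer with ddp_rank=ddp_rank and run training ...
    dist.destroy_process_group()


if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(_worker, args=(world_size, ), nprocs=world_size)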
23 changes: 23 additions & 0 deletions alf/utils/distributed.py
@@ -18,6 +18,8 @@
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

from alf.experience_replayers.replay_buffer import ReplayBuffer


class _MethodPerformer(torch.nn.Module):
"""A nn.Module wrapper whose forward() performs a specified method of
@@ -59,6 +61,27 @@ def __init__(self, module: torch.nn.Module, perform: Callable[..., Any]):
if type(value) is not torch.Tensor:
self._ddp_params_and_buffers_to_ignore.append(name)

# We also need to ignore, for DDP purposes, all the buffers that are
# registered under a ReplayBuffer of the module (e.g. when the module is
# an Algorithm), because we do not want DDP to synchronize replay
# buffers across processes.
#
# Those buffers are not registered in the state_dict() because of Alf's
# special treatment, but they do appear under named_buffers().
ignored_named_buffers = set()
for sub_module in module.modules():
if isinstance(sub_module, ReplayBuffer):
for _, buf in sub_module.named_buffers():
# Find all the buffers that are registered under a
# ReplayBuffer submodule.
ignored_named_buffers.add(buf)

for name, buf in self.named_buffers():
# If the buffer is in the ignored_named_buffers (address-wise equal,
# i.e. ``is``), add its name to DDP's ignore list.
if buf in ignored_named_buffers:
self._ddp_params_and_buffers_to_ignore.append(name)

def forward(self, *args, **kwargs):
return self._perform(*args, **kwargs)

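The ignore list built above only takes effect once the performer is handed to PyTorch's DDP: at construction time DDP reads the module attribute _ddp_params_and_buffers_to_ignore and excludes the named entries from broadcasting and synchronization, which is what keeps per-process replay buffers independent. A minimal sketch of that hand-off, assuming the wrapping happens elsewhere in alf/utils/distributed.py (the call site is not part of this diff); the helper name wrap_performer is made up for illustration.

import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_performer(performer: torch.nn.Module, device_id: int) -> DDP:
    # ``performer`` is assumed to be a _MethodPerformer that has already
    # populated _ddp_params_and_buffers_to_ignore in its __init__ (see the
    # diff above). Recent PyTorch versions read this attribute when DDP is
    # constructed and skip the listed parameters/buffers, so replay-buffer
    # contents stay local to each process.
    assert hasattr(performer, '_ddp_params_and_buffers_to_ignore')
    return DDP(performer, device_ids=[device_id])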