DDP for off-policy training branch
breakds committed Dec 2, 2021
1 parent de2c6a1 commit 58496fa
Showing 3 changed files with 28 additions and 7 deletions.
21 changes: 18 additions & 3 deletions alf/algorithms/algorithm.py
@@ -32,6 +32,7 @@
from alf.utils import common, dist_utils, spec_utils, summary_utils
from alf.utils.summary_utils import record_time
from alf.utils.math_ops import add_ignore_empty
from alf.utils.distributed import data_distributed
from alf.experience_replayers.replay_buffer import ReplayBuffer
from .algorithm_interface import AlgorithmInterface
from .config import TrainerConfig
@@ -1473,17 +1474,31 @@ def _collect_train_info_parallelly(self, experience):
info = dist_utils.params_to_distributions(info, self.train_info_spec)
return info

def _update(self, experience, batch_info, weight):
@data_distributed
def _compute_train_info_and_loss_info(self, experience):
"""Compute train_info and loss_info based on the experience.
This function has data distributed support. This means that if the
Algorithm instance has DDP activated, the output will have a hook to
synchronize gradients across processes upon the call to backward()
that involves the output (i.e. train_info and loss_info).
"""
length = alf.nest.get_nest_size(experience, dim=0)
if self._config.temporally_independent_train_step or length == 1:
train_info = self._collect_train_info_parallelly(experience)
else:
train_info = self._collect_train_info_sequentially(experience)

experience = dist_utils.params_to_distributions(
experience, self.processed_experience_spec)

loss_info = self.calc_loss(train_info)

return train_info, loss_info

def _update(self, experience, batch_info, weight):
train_info, loss_info = self._compute_train_info_and_loss_info(
experience)

if loss_info.priority != ():
priority = (
loss_info.priority**self._config.priority_replay_alpha() +
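
Note on the hunk above: the @data_distributed decorator is what gives _compute_train_info_and_loss_info its gradient-synchronization behavior described in the docstring. Below is a minimal, self-contained sketch of that pattern, assuming a small nn.Module shim wrapped in DistributedDataParallel; the names _MethodPerformer, _ddp_rank and _ddp_performer are placeholders for illustration, not ALF's actual implementation.

import functools
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class _MethodPerformer(nn.Module):
    # Exposes a bound method of the algorithm as forward() so that DDP can
    # wrap it and register its gradient-averaging hooks on the outputs.
    def __init__(self, module: nn.Module, perform):
        super().__init__()
        self._module = module      # shares parameters with the algorithm
        self._perform = perform

    def forward(self, *args, **kwargs):
        return self._perform(*args, **kwargs)


def data_distributed(method):
    # Decorator sketch: route the call through a DDP-wrapped shim when the
    # owner has a non-negative _ddp_rank and a process group is initialized;
    # otherwise just call the method directly.
    @functools.wraps(method)
    def wrapped(self, *args, **kwargs):
        rank = getattr(self, '_ddp_rank', -1)
        if rank < 0 or not dist.is_available() or not dist.is_initialized():
            return method(self, *args, **kwargs)
        if not hasattr(self, '_ddp_performer'):
            shim = _MethodPerformer(self, functools.partial(method, self))
            self._ddp_performer = DDP(shim, device_ids=[rank])
        return self._ddp_performer(*args, **kwargs)
    return wrapped

Because such a shim registers the whole algorithm as a submodule, DDP would also see the replay-buffer buffers, which is what the alf/utils/distributed.py change further below tells DDP to ignore.
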
4 changes: 0 additions & 4 deletions alf/trainers/policy_trainer.py
@@ -352,10 +352,6 @@ def __init__(self, config: TrainerConfig, ddp_rank: int = -1):
debug_summaries=self._debug_summaries)
self._algorithm.set_path('')
if ddp_rank >= 0:
if not self._algorithm.on_policy:
raise RuntimeError(
'Mutli-GPU with DDP does not support off-policy training yet'
)
# Activate the DDP training
self._algorithm.activate_ddp(ddp_rank)

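
With the guard above removed, a trainer constructed with a non-negative ddp_rank now activates DDP for off-policy algorithms as well. For context, here is a generic sketch of how such a per-process rank is usually produced: it is the standard torch.distributed / torch.multiprocessing pattern, not ALF's launcher, and Trainer/config are hypothetical placeholders.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int):
    # One process per GPU; every process joins the same process group.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    # trainer = Trainer(config, ddp_rank=rank)   # hypothetical construction
    # trainer.train()
    dist.destroy_process_group()


if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(_worker, args=(world_size,), nprocs=world_size)
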
10 changes: 10 additions & 0 deletions alf/utils/distributed.py
@@ -59,6 +59,16 @@ def __init__(self, module: torch.nn.Module, perform: Callable[..., Any]):
if type(value) is not torch.Tensor:
self._ddp_params_and_buffers_to_ignore.append(name)

# We also need to ignore all the buffers that are under the replay buffer
# of the module (e.g. when the module is an Algorithm), because we do not
# want DDP to synchronize replay buffers across processes. Those buffers
# are not registered in the state_dict() due to Alf's special treatment,
# but they do show up under named_buffers().
for name, _ in self.named_buffers():
if '_replay_buffer' in name.split('.'):
self._ddp_params_and_buffers_to_ignore.append(name)

def forward(self, *args, **kwargs):
return self._perform(*args, **kwargs)

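
The filter above only matches a buffer whose dotted name contains '_replay_buffer' as a full path component, and the collected names go into the _ddp_params_and_buffers_to_ignore attribute that DistributedDataParallel checks on the wrapped module. A small standalone check of the matching rule, with a made-up module layout purely for illustration:

import torch
import torch.nn as nn


class ReplayBufferLike(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('current_size', torch.zeros(1, dtype=torch.int64))


class AlgorithmLike(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 4)
        self._replay_buffer = ReplayBufferLike()
        self.register_buffer('my_replay_buffer_stats', torch.zeros(1))


algo = AlgorithmLike()
ignored = [
    name for name, _ in algo.named_buffers()
    if '_replay_buffer' in name.split('.')
]
print(ignored)
# ['_replay_buffer.current_size'] -- 'my_replay_buffer_stats' is kept,
# because the check requires an exact path component, not a substring.
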
