More reliable way to ignore replay buffer's buffers
breakds committed Dec 2, 2021
1 parent 58496fa commit ef9064d
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions alf/utils/distributed.py
@@ -18,6 +18,8 @@
 import torch
 from torch.nn.parallel import DistributedDataParallel as DDP
 
+from alf.experience_replayers.replay_buffer import ReplayBuffer
+
 
 class _MethodPerformer(torch.nn.Module):
     """A nn.Module wrapper whose forward() performs a specified method of
@@ -62,11 +64,22 @@ def __init__(self, module: torch.nn.Module, perform: Callable[..., Any]):
         # We also need to ignore all the buffers that are under the replay
         # buffer of the module (e.g. when the module is an Algorithm) for
         # DDP, because we do not want DDP to synchronize replay buffers
         # across processes.
         #
         # Those buffers are not registered in the state_dict() because of
         # Alf's special treatment, but they can be found via named_buffers().
-        for name, _ in self.named_buffers():
-            if '_replay_buffer' in name.split('.'):
+        ignored_named_buffers = []
+        for sub_module in module.modules():
+            if isinstance(sub_module, ReplayBuffer):
+                for _, buf in sub_module.named_buffers():
+                    # Find all the buffers that are registered under a
+                    # ReplayBuffer submodule.
+                    ignored_named_buffers.append(buf)
+
+        for name, buf in self.named_buffers():
+            # If the buffer is in ignored_named_buffers (address-wise equal,
+            # i.e. ``is``), add its name to DDP's ignore list.
+            if any([buf is x for x in ignored_named_buffers]):
                 self._ddp_params_and_buffers_to_ignore.append(name)
 
     def forward(self, *args, **kwargs):
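To illustrate the pattern outside the ALF codebase: the previous check keyed on the literal name component '_replay_buffer', which silently misses a replay buffer registered under any other attribute name, whereas matching the buffer tensors themselves with ``is`` is robust to how they are named. Below is a minimal, runnable sketch of that identity-matching step; ToyBuffer and ToyAlgorithm are hypothetical stand-ins for ALF's ReplayBuffer and Algorithm, not ALF code.

    import torch
    import torch.nn as nn


    class ToyBuffer(nn.Module):
        """Hypothetical stand-in for ALF's ReplayBuffer."""

        def __init__(self):
            super().__init__()
            # Simulates replay storage registered as a buffer that should
            # NOT be synchronized across processes.
            self.register_buffer('storage', torch.zeros(128, 4))


    class ToyAlgorithm(nn.Module):
        """Hypothetical stand-in for an ALF Algorithm module."""

        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(4, 4)
            self.replay = ToyBuffer()
            # A buffer that DDP SHOULD still synchronize.
            self.register_buffer('step_counter', torch.zeros(1))


    module = ToyAlgorithm()

    # Step 1: collect the buffer tensors owned by replay-buffer submodules.
    ignored_named_buffers = [
        buf for m in module.modules() if isinstance(m, ToyBuffer)
        for _, buf in m.named_buffers()
    ]

    # Step 2: recover their fully qualified names as seen from the top-level
    # module, comparing tensors by identity (``is``), not by name.
    to_ignore = [
        name for name, buf in module.named_buffers()
        if any(buf is x for x in ignored_named_buffers)
    ]
    print(to_ignore)  # ['replay.storage'] -- 'step_counter' is kept.

The collected names go into _ddp_params_and_buffers_to_ignore, the private module attribute that DistributedDataParallel consults when deciding which parameters and buffers to skip during synchronization; recent PyTorch versions also expose the (likewise private) static helper DistributedDataParallel._set_params_and_buffers_to_ignore_for_model for setting it.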
