
Commit

find_unused_parameters (#1117)
breakds authored Dec 8, 2021
1 parent 89061cc commit 0b4ec5b
Showing 2 changed files with 36 additions and 3 deletions.
2 changes: 2 additions & 0 deletions alf/examples/ppg_conf.py
@@ -30,3 +30,5 @@
    algorithm_ctor=Agent,
    whole_replay_buffer_training=True,
    clear_replay_buffer=True)

alf.config('make_ddp_performer', find_unused_parameters=True)
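
For context, this one-line override works because ``alf.config`` rebinds the default keyword arguments of an ``@alf.configurable`` function. A minimal sketch of that pattern, assuming the usual ``alf.config`` behavior of overriding a configurable's defaults (``make_performer`` is a hypothetical name, not part of this commit):

import alf

@alf.configurable
def make_performer(find_unused_parameters: bool = False):
    return find_unused_parameters

# A conf file overrides the default; later calls pick up the configured value.
alf.config('make_performer', find_unused_parameters=True)

assert make_performer() is True  # the configured value is used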
37 changes: 34 additions & 3 deletions alf/utils/distributed.py
@@ -18,6 +18,7 @@
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

import alf
from alf.experience_replayers.replay_buffer import ReplayBuffer


@@ -82,10 +83,41 @@ def __init__(self, module: torch.nn.Module, perform: Callable[..., Any]):
            if buf in ignored_named_buffers:
                self._ddp_params_and_buffers_to_ignore.append(name)

        # TODO(breakds): In the future when needed, we can do explicit filtering
        # if the wrapped module is an Algorithm. All parameters and buffers that
        # are not within the optimizer can be added to the ignore list.

    def forward(self, *args, **kwargs):
        return self._perform(*args, **kwargs)


@alf.configurable
def make_ddp_performer(module: torch.nn.Module,
                        method,
                        ddp_rank: int,
                        find_unused_parameters: bool = False):
    """Creates a DDP-wrapped ``_MethodPerformer``.

    This function is ``alf.configurable`` and is used by the ``@data_distributed``
    family of decorators below. Override it in your configuration with

        alf.config('make_ddp_performer', find_unused_parameters=True)

    to enable ``find_unused_parameters``. This asks DDP to ignore parameters
    that are not used for computing the output of ``forward()`` when waiting
    for the synchronization of gradients and parameters upon ``backward()``.
    Normally you do not need to worry about this. It is useful for algorithms
    such as PPG, where some of the model's parameters do not always contribute
    to the network output.
    """
    print(f'find_unused_parameters={find_unused_parameters}')
    return DDP(
        _MethodPerformer(module=module, perform=method),
        device_ids=[ddp_rank],
        find_unused_parameters=find_unused_parameters)
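
# Illustrative sketch, not part of this commit: why ``find_unused_parameters``
# matters. In a model like the hypothetical ``_TwoHeadNet`` below, only one
# head contributes to a given ``forward()`` output, so DDP's default reducer
# would wait indefinitely for gradients of the unused head unless it is told
# to detect and skip unused parameters.


class _TwoHeadNet(torch.nn.Module):
    """Hypothetical network used only for illustration."""

    def __init__(self):
        super().__init__()
        self._policy_head = torch.nn.Linear(8, 4)
        self._aux_head = torch.nn.Linear(8, 4)

    def forward(self, x, use_aux: bool = False):
        # Similar to PPG's alternating policy/auxiliary phases, only one of
        # the two heads participates in each call.
        return self._aux_head(x) if use_aux else self._policy_head(x)


# Inside an initialized process group one would wrap it as, e.g.:
#   DDP(_TwoHeadNet().to(ddp_rank), device_ids=[ddp_rank],
#       find_unused_parameters=True)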


def data_distributed(method):
"""This decorator makes a target method of a module capable of being data
distributed via DDP.
@@ -170,9 +202,8 @@ def wrapped(*args, **kwargs):
performer = module_to_wrap._ddp_performer_map.get(
method.__name__, None)
if performer is None:
- performer = DDP(
- _MethodPerformer(module=module_to_wrap, perform=method),
- device_ids=[ddp_rank])
+ performer = make_ddp_performer(module_to_wrap, method,
+ ddp_rank)
module_to_wrap._ddp_performer_map[method.__name__] = performer
return performer(*args[1:], **kwargs)

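For context, a hypothetical usage sketch of the ``data_distributed`` decorator touched above (the class and method names are illustrative, not from the ALF codebase, and it assumes the wrapped module provides whatever attributes the decorator expects):

import torch

from alf.utils.distributed import data_distributed


class MyAlgorithm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._net = torch.nn.Linear(8, 4)

    @data_distributed
    def train_step(self, inputs):
        # When the process participates in DDP training, calls to this method
        # are routed through a cached, DDP-wrapped ``_MethodPerformer`` built
        # by ``make_ddp_performer``; otherwise the method runs as-is.
        return self._net(inputs)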
