Update the documentation according to the comments.
breakds committed Dec 1, 2021
1 parent cd9c8e4 commit 06b1c97
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion alf/utils/distributed.py
@@ -52,7 +52,7 @@ def __init__(self, module: torch.nn.Module, perform: Callable[..., Any]):
         # DDP will panic if the wrapped module has member in its state_dict()
         # that is not a Tensor. Here such state_dict members are picked and
         # thrown into _ddp_params_and_buffers_to_ignore. By contract this
-        # implicitly instruct DDP wrapper to not include them in its
+        # implicitly instructs DDP wrapper to not include them in its
         # parameter/buffer synchronization.
         self._ddp_params_and_buffers_to_ignore = []
         for name, value in self.state_dict().items():
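For illustration, a minimal sketch of the mechanism the comment above describes, assuming a hypothetical ``WrappedModule`` wrapper class; the class and the loop body are illustrative and not copied from the repository:

    import torch
    import torch.nn as nn

    class WrappedModule(nn.Module):
        """Collects names of non-Tensor state_dict entries so DDP skips them."""

        def __init__(self, module: nn.Module):
            super().__init__()
            self._module = module
            # Names listed here are consulted by DistributedDataParallel and
            # excluded from its parameter/buffer synchronization.
            self._ddp_params_and_buffers_to_ignore = []
            for name, value in self.state_dict().items():
                if not isinstance(value, torch.Tensor):
                    self._ddp_params_and_buffers_to_ignore.append(name)

Wrapping an instance of such a module with ``torch.nn.parallel.DistributedDataParallel`` is then expected to leave the listed entries out of its synchronization.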
@@ -70,6 +70,16 @@ def data_distributed(method):
     This is to provide a simple and transparent way to enable DDP for specific
     code logics.
+    When the method is wrapped by @data_distributed, the outputs (tensors) of
+    this method will have gradient synchronization hooks attached to them.
+    Later, when those outputs are used in ``backward()`` to compute gradients,
+    the hooks will be called to synchronize across all processes. As a result,
+    the corresponding parameters receive not only the gradients from this
+    process but also the gradients from the other processes. Note that each
+    single process will be TRAPPED at the call to ``backward()`` that involves
+    those output tensors, until all processes have finished the back
+    propagation and have the gradients synchronized.
     Example usage:
     .. code-block:: python
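For illustration, a hedged sketch of how ``@data_distributed`` might be applied, assuming the decorated method lives on an ``nn.Module`` subclass; ``MyAlgorithm`` and ``compute_loss`` are hypothetical names, and whether gradient synchronization actually happens depends on a DDP process group being set up elsewhere:

    import torch
    import torch.nn as nn

    from alf.utils.distributed import data_distributed

    class MyAlgorithm(nn.Module):
        def __init__(self):
            super().__init__()
            self._net = nn.Linear(8, 1)

        @data_distributed
        def compute_loss(self, batch: torch.Tensor) -> torch.Tensor:
            # Under DDP, the returned tensor carries gradient synchronization
            # hooks, so the later backward() call waits until every process
            # has contributed and received the combined gradients.
            return self._net(batch).mean()

    # Typical call site: each process is held at backward() until all
    # processes finish back propagation and the gradients are synchronized.
    # loss = algorithm.compute_loss(batch)
    # loss.backward()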
