
Commit

update documentation
chrisxcai committed May 15, 2024
1 parent c91cb72 commit fd3f3fc
Showing 2 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -339,9 +339,9 @@ class FullyShardedDataParallel(nn.Module):
             skip the GPU -> CPU copy on non-rank 0 altogether and prevent OOM.
             Default: False
         optimize_backward_concat (bool):
-            If True, only trigger the self._fsdp_wrapped_module.flat_params backward(), which will
-            invoke the _post_backward_hook() and concat() op,
-            when self._require_backward_grad_sync is True (e.g. last microbatch)
+            If True, only let the backward pass propagate to self.params, which will
+            invoke the _post_backward_hook() and concat() op, when self._require_backward_grad_sync
+            is True (e.g. the last microbatch)
             NOTE: this likely will incur more GPU memory usage
     """

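The behavior described by the new wording is easiest to see in a microbatch gradient-accumulation loop, where no_sync() keeps _require_backward_grad_sync False on every microbatch except the last. The following is a minimal, hypothetical sketch, not part of this commit: the Linear layer, SGD optimizer, random microbatches, and the assumption of an initialized process group and a GPU are all illustrative.

# Hypothetical usage sketch (not part of this commit); assumes torch.distributed
# is already initialized (e.g. via torchrun) and a GPU is available.
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

model = FSDP(
    torch.nn.Linear(16, 16).cuda(),
    flatten_parameters=True,        # the option acts on the flattened params
    optimize_backward_concat=True,  # the flag documented above
)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
microbatches = [torch.randn(4, 16, device="cuda") for _ in range(4)]

# All but the last microbatch run under no_sync(), which leaves
# _require_backward_grad_sync False: per the docstring above, backward does not
# yet propagate to the flat params, so the concat() and _post_backward_hook()
# are deferred while fp32 gradients are accumulated.
for mb in microbatches[:-1]:
    with model.no_sync():
        model(mb).sum().backward()

# Last microbatch: _require_backward_grad_sync is True again, so the backward
# pass reaches the flat params and the deferred concat() / gradient reduction run.
model(microbatches[-1]).sum().backward()
optim.step()
optim.zero_grad()

The only difference from a standard fairscale gradient-accumulation loop is the optimize_backward_concat flag; the no_sync() pattern itself is unchanged.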
12 changes: 6 additions & 6 deletions fairscale/nn/misc/flatten_params_wrapper.py
@@ -151,9 +151,9 @@ class FlattenParamsWrapper(nn.Module):
             originally, give each flat_param a unique name. Note a "flat_param_"
             prefix will be added to those names.
         optimize_backward_concat (bool):
-            If True, only trigger the self.flat_params backward(), which will
-            invoke the parent FSDP module's _post_backward_hook() and concat() op,
-            when self._require_backward_grad_sync is True (e.g. last microbatch)
+            If True, only let the backward pass propagate to the corresponding FSDP.params, which will
+            invoke the FSDP._post_backward_hook() and concat() op, when _require_backward_grad_sync
+            is True (e.g. the last microbatch)
             NOTE: this likely will incur more GPU memory usage
     """

@@ -170,10 +170,10 @@ def __init__(
         self._fpw_module = module
         self.is_flattened = False
         self.optimize_backward_concat = optimize_backward_concat
-        # If self.optimize_backward_concat == True, used to propagate the
-        # parent FSDP modules's _require_backward_grad_sync flag
+        # If optimize_backward_concat == True, used to propagate the
+        # corresponding FSDP module's _require_backward_grad_sync flag
         self._require_backward_grad_sync = True
-        # If self.optimize_backward_concat == True, used to accumulate the
+        # If optimize_backward_concat == True, used to accumulate the
         # fp32 gradients for the flattened parameters
         self.fp32_grads = []

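Because the same option is documented here from the wrapper's side, a short hypothetical sketch of constructing FlattenParamsWrapper directly shows the attributes the updated comments refer to. This is not from the commit; in normal use FSDP creates the wrapper internally, and the inner module and names below are illustrative.

# Hypothetical standalone construction (not part of this commit); FSDP normally
# builds this wrapper itself.
import torch
from fairscale.nn.misc import FlattenParamsWrapper

inner = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
wrapped = FlattenParamsWrapper(inner, optimize_backward_concat=True)

# The individual parameters are replaced by flattened ones carrying the
# "flat_param_" prefix mentioned in the docstring.
print([name for name, _ in wrapped.named_parameters()])

# Attributes set in __init__ above: the flag itself, the mirror of the owning
# FSDP module's _require_backward_grad_sync, and the fp32 gradient accumulator.
print(wrapped.optimize_backward_concat)      # True
print(wrapped._require_backward_grad_sync)   # True until FSDP toggles it
print(wrapped.fp32_grads)                    # [] before any backward pass

Nothing here issues a backward pass; flushing the accumulated fp32 gradients into the flat params is driven by the owning FSDP module's _post_backward_hook(), as described in the first file's docstring.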
