Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Make FSDPv1 only perform cat() during last microbatch backward() within FlattenParamsWrapper #1178

Draft
wants to merge 11 commits into
base: ngoyal_changes_for_pp_fp8
Choose a base branch
from
64 changes: 62 additions & 2 deletions fairscale/nn/misc/flatten_params_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401

from logging import getLogger
logger = getLogger()

class FlatParameter(nn.Parameter):
"""A parameter that is initialized from a list of parameters and can be
Expand Down Expand Up @@ -368,12 +370,70 @@ def _unflatten_params_as_views(self) -> None:
"""Unlike ``_unflatten_params``, this function unflatten into views and keep
self.flat_param unchanged.
"""
logger.info("CHRISLOG: _unflatten_params_as_views() called")
assert self.is_flattened
ps = self.get_param_views()
# ps = self.get_param_views()

"""Return a generator of views that map to the original parameters."""

"""Used to get a generator over all views from a list of external data list."""
params = self.flat_params
external_data_list = [None] * len(params)
assert len(external_data_list) == len(
params
), f"Incorrect external data list: {len(external_data_list)} vs. {len(params)}"

# Post accumulation hook so we can call backward() on original leaf params at last microbatch
import functools

def _post_accumulation_hook(new_param_stop_grad, new_param):
# TODO: make it call backward() only for the last microbatch (within FSDP no_sync)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please try to split this into three logger.info entries? I think the printing of new_param failed as the new_param we used to build partial function was gone...

# if not mpu.get_data_parallel_no_sync():
logger.info(
f"CHRISLOG: _post_accumulation_hook() called {new_param=} {new_param_stop_grad=} {new_param_stop_grad.grad=}"
)
new_param.backward(gradient=new_param_stop_grad.grad)

gens = []
for p, data in zip(params, external_data_list):
# Sanity check
assert p.data.numel() <= sum(
p._param_numels
), f"Incorrect internal state {p.data.numel()} vs. {sum(p._param_numels)}"
data = data if data is not None else p
if data.numel() != sum(p._param_numels):
raise ValueError(
f"Incorrect numel of supplied data: got {data.numel()} but expected {sum(p._param_numels)}"
)

# Split the data into views of each parameter.
param_views_stop_grad = []
for t, s in zip(data.split(p._param_numels), p._param_shapes):
# Create unflattened view for the param.
new_param = t.view(s)
# Create a new_param_stop_grad param via detaching original leaf params after .view()
# as the new leaf nodes so that autograd.backward() won't call
# grad_fn of view() (which will be cat())
new_param_stop_grad = new_param.detach().requires_grad_(True)
# Register post-accumulation hook to the new_param_stop_grad parameters so that
# we can still manually call backward() function
# to propagate gradients to the original leaf params, e.g. after last_microbatch
# backward()
new_param_stop_grad.register_post_accumulate_grad_hook(
functools.partial(_post_accumulation_hook, new_param=new_param)
)
param_views_stop_grad.append(new_param_stop_grad)

gens.append(param_views_stop_grad)
ps = chain(*gens)

# Set the param with unflattened view as the new attribute
# under original param name
param_views = []
for (_, m, n), p in zip(self._param_infos, ps):
setattr(p, '_fsdp_weight', True)
setattr(p, "_fsdp_weight", True)
setattr(m, n, p) # This will set as plain attr
# logger.info(f"CHRISLOG: {n=} {p.is_leaf=}")
param_views.append(p)

# Save param views for easy access if anyone still wants to access
Expand Down
Loading