Skip to content

Commit

Permalink
use main_grad for higher precision gradient accumulation; update amax…
Browse files Browse the repository at this point in the history
… during post_backward_hook
  • Loading branch information
jspark1105 committed Oct 5, 2023
1 parent db6a1c7 commit bd70153
Showing 1 changed file with 93 additions and 42 deletions.
135 changes: 93 additions & 42 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,8 @@
from fairscale.utils.params import calc_grad_norm, recursive_copy_to_device
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.state_dict import replace_by_prefix_
from transformer_engine.pytorch.cpp_extensions import (
cast_to_fp8,
DType,
FP8FwdTensors,
)
from transformer_engine.pytorch.cpp_extensions import cast_to_fp8, DType, FP8FwdTensors
from transformer_engine.pytorch.fp8 import amax_and_scale_update, FP8GlobalStateManager

from . import fsdp_optim_utils as ou

Expand Down Expand Up @@ -123,6 +120,14 @@ class OffloadConfig:
dir: Optional[str] = None


def _is_te_module_with_weights(m: nn.Module) -> bool:
return isinstance(m, (te.Linear, te.LayerNormLinear, te.LayerNormMLP))


def _is_fp8_dtype(dtype: torch.dtype) -> bool:
return dtype in [torch.float8_e5m2, torch.float8_e4m3fn]


class FullyShardedDataParallel(nn.Module):
"""
A wrapper for sharding Module parameters across data parallel workers. This
Expand Down Expand Up @@ -462,20 +467,29 @@ def __init__(
non_flatten_params = params
param_name_groups = [[n] for n in param_names]
if self.flatten_parameters:
to_be_flatten_params = [
[
# don't flatten norm_weights since we need to handle them
# separately during fp8 training
if self._is_fp8_dtype():
to_be_flatten_params = [
[
params[i]
for i in range(len(params))
if "norm_weight" not in param_names[i]
]
]
non_flatten_params = [
params[i]
for i in range(len(params))
if "norm_weight" not in param_names[i]
if "norm_weight" in param_names[i]
]
]
non_flatten_params = [
params[i] for i in range(len(params)) if "norm_weight" in param_names[i]
]
param_name_groups = [
[n for n in param_names if "norm_weight" not in n],
[n for n in param_names if "norm_weight" in n],
]
param_name_groups = [
[n for n in param_names if "norm_weight" not in n],
[n for n in param_names if "norm_weight" in n],
]
else:
to_be_flatten_params = params
non_flatten_params = []
param_name_groups = [param_names]
del param_names

self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
Expand Down Expand Up @@ -543,6 +557,9 @@ def __init__(
if self.zero2_process_group is not None:
assert not self.move_params_to_cpu

def _is_fp8_dtype(self) -> bool:
return _is_fp8_dtype(self)

def _get_gradient_predivide_factor(self, world_size: int) -> float:
factor: int = 1
while world_size % factor == 0 and world_size / factor > factor:
Expand Down Expand Up @@ -775,7 +792,10 @@ def _shard_parameters_(self) -> None:
assert p.dtype == torch.float32

# If world_size is 1, then we all-reduce grads instead of sharding.
p._is_sharded = self.world_size > 1 and isinstance(p, FlatParameter)
# An exception is norm weights during fp8 training.
p._is_sharded = self.world_size > 1 and (
not self._is_fp8_dtype() or isinstance(p, FlatParameter)
)
p._orig_size = p.data.size()

if not p._is_sharded:
Expand Down Expand Up @@ -1279,11 +1299,9 @@ def _init_param_attributes(self, p: Parameter) -> None:
# storage to size 0 at init (here) and re-materialize (by copying
# from _fp32_shard) as needed. If offloading params to CPU, the
# dtype of the fp16 shard will depend on the *`compute_dtype`*.
if self.compute_dtype in [
torch.float8_e5m2,
torch.float8_e4m3fn,
] and not isinstance(p, FlatParameter):
if self._is_fp8_dtype() and not isinstance(p, FlatParameter):
# assume non flattened are precision critical like norm
assert not p._is_sharded
dtype = torch.bfloat16
else:
dtype = self.compute_dtype
Expand All @@ -1307,11 +1325,9 @@ def _init_param_attributes(self, p: Parameter) -> None:
# world_size, although these padding elements will be removed before the
# relevant computation.
if p._is_sharded:
if self.compute_dtype in [
torch.float8_e5m2,
torch.float8_e4m3fn,
] and not isinstance(p, FlatParameter):
if self._is_fp8_dtype() and not isinstance(p, FlatParameter):
# assume non flattened are precision critical like norm
assert not p._is_sharded
dtype = torch.bfloat16
else:
dtype = self.compute_dtype
Expand Down Expand Up @@ -1690,7 +1706,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# then subsequent hook callbacks will see POST state.
self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
self.training_state = TrainingState.BACKWARD_POST
if param.grad is None:
grad_or_main_grad = (
param.main_grad if hasattr(param, "main_grad") else param.grad
)
if grad_or_main_grad is None:
return

if hasattr(param, "_linked_param"):
Expand All @@ -1703,10 +1722,31 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
if hasattr(param._linked_param, "_is_shared") and param._linked_param._is_shared:
param = param._linked_param

assert param.grad is not None, param.shape
if param.grad.requires_grad:
assert grad_or_main_grad is not None, param.shape
if grad_or_main_grad.requires_grad:
raise RuntimeError("FSDP only works with gradients that don't require gradients")

if (
self._require_backward_grad_sync
and self._is_fp8_dtype()
and isinstance(param, FlatParameter)
):
# Need to update amax and scale here before
# _cast_fp32_param_shards_to_fp16
for m in set(info[1] for info in param._param_infos):
# Previous iteration was grad_enabled
if m.fp8_meta.get("update_amax_and_scale_fwd", False):
if m.fp8_meta["recipe"].reduce_amax:
FP8GlobalStateManager.copy_amax_from_global_buffer(
m.fp8_meta, forward=True
)
amax_and_scale_update(m.fp8_meta, True)
FP8GlobalStateManager.set_amax_buffer_key_deletion(
m.fp8_meta, forward=True
)
else:
amax_and_scale_update(m.fp8_meta, True)

if self._require_backward_grad_sync or self.reshard_after_forward:
# Free full params. As a special case, we don't free the full params
# when in a ``no_sync`` context (as inversely indicated by
Expand All @@ -1731,27 +1771,31 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# reductions in post_backward stream.
self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._streams["post_backward"]):
orig_grad_data = param.grad.data
orig_grad_data = grad_or_main_grad.data

if self.mixed_precision:
if self.fp32_reduce_scatter:
# Cast grad to FP32.
param.grad.data = param.grad.data.to(param.dtype)
elif self.compute_dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
grad_or_main_grad.data = grad_or_main_grad.data.to(param.dtype)
elif self._is_fp8_dtype():
# Use bf16 wgrad for fp8 weights (TODO: handle fp8 wgrad)
param.grad.data = param.grad.data.to(torch.bfloat16)
grad_or_main_grad.data = grad_or_main_grad.data.to(torch.bfloat16)

if self.gradient_predivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
param.grad.data.div_(self.gradient_predivide_factor)
grad_or_main_grad.data.div_(self.gradient_predivide_factor)

if param._is_sharded:
assert self._reducer is not None
# Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
# param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
# gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
# matter, neglecting rounding.
grad = param.grad.data
if hasattr(param, "main_grad"):
grad = param.main_grad.data
param.main_grad = None
else:
grad = param.grad.data
# Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
#
# The effect on memory consumption is not usually significant. No extra memory is allocated if this
Expand All @@ -1774,7 +1818,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# world_size == 1. This could be relaxed in the future, in which
# case grads should be all-reduced here.
# assert self.world_size == 1
self._post_reduction_hook(param, param.grad.data)
self._post_reduction_hook(param, grad_or_main_grad.data)

# After _post_backward_hook returns, orig_grad_data will eventually
# go out of scope, at which point it could otherwise be freed for
Expand Down Expand Up @@ -2212,6 +2256,17 @@ def _prep_grads_for_backward(self) -> None:
right shape, device, accumulated values, etc.
"""
for p in self.params:
if isinstance(p, FlatParameter) and all(
_is_te_module_with_weights(info[1]) for info in p._param_infos
):
if getattr(p, "main_grad", None) is None:
p.main_grad = torch.zeros_like(
p, dtype=torch.bfloat16 if self._is_fp8_dtype() else torch.float
)
main_grad_views = p.get_param_views(p.main_grad)
for (_, m, n), main_grad in zip(p._param_infos, main_grad_views):
getattr(m, n).main_grad = main_grad

if p.grad is not None:
if p.grad.device != p.data.device:
p.grad = None
Expand Down Expand Up @@ -2428,10 +2483,8 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = No
for p in params:
assert p._fp16_shard is not None
alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
if self.compute_dtype in [
torch.float8_e5m2,
torch.float8_e4m3fn,
] and p._fp16_shard.dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
if self._is_fp8_dtype() and _is_fp8_dtype(p._fp16_shard.dtype):
# fp8 quantization
assert isinstance(p, FlatParameter)
assert len(p._param_infos) == len(p._param_numels)
numel_per_shard = p.numel()
Expand All @@ -2442,9 +2495,7 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = No
if offset + numel <= 0 or offset >= numel_per_shard:
offset += numel
continue
assert isinstance(
m, (te.Linear, te.LayerNormLinear, te.LayerNormMLP)
)
assert _is_te_module_with_weights(m)
fp8_dtype_forward = te.fp8.get_fp8_te_dtype(
m.fp8_meta["recipe"], fprop_tensor=True
)
Expand Down

0 comments on commit bd70153

Please sign in to comment.