
use main_grad for higher precision gradient accumulation; update amax during post_backward_hook
jspark1105 committed Oct 15, 2023
1 parent db6a1c7 commit 6a4d7f4
Showing 1 changed file with 106 additions and 35 deletions: fairscale/nn/data_parallel/fully_sharded_data_parallel.py
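Before the diff itself, a minimal, self-contained sketch of the main_grad pattern the commit title refers to (toy parameter, not code from this diff): weight gradients are accumulated into a persistent higher-precision buffer attached to the parameter, instead of into its low-precision .grad.

    import torch

    w = torch.nn.Parameter(torch.randn(4, 4, dtype=torch.bfloat16))
    w.main_grad = torch.zeros(w.shape, dtype=torch.float32)    # higher-precision accumulator
    for _ in range(4):                                          # pretend micro-batches
        wgrad = torch.randn(w.shape, dtype=torch.bfloat16)      # stand-in for the bf16 wgrad
        w.main_grad += wgrad.float()                            # accumulate in fp32, not in w.grad
    # _post_backward_hook below then reduce-scatters w.main_grad instead of w.grad.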
@@ -55,11 +55,8 @@
from fairscale.utils.params import calc_grad_norm, recursive_copy_to_device
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.state_dict import replace_by_prefix_
from transformer_engine.pytorch.cpp_extensions import (
cast_to_fp8,
DType,
FP8FwdTensors,
)
from transformer_engine.pytorch.cpp_extensions import cast_to_fp8, DType, FP8FwdTensors
from transformer_engine.pytorch.fp8 import amax_and_scale_update, FP8GlobalStateManager

from . import fsdp_optim_utils as ou

@@ -123,6 +120,14 @@ class OffloadConfig:
dir: Optional[str] = None


def _is_te_module_with_weights(m: nn.Module) -> bool:
return isinstance(m, (te.Linear, te.LayerNormLinear, te.LayerNormMLP))


def _is_fp8_dtype(dtype: torch.dtype) -> bool:
return dtype in [torch.float8_e5m2, torch.float8_e4m3fn]


class FullyShardedDataParallel(nn.Module):
"""
A wrapper for sharding Module parameters across data parallel workers. This
@@ -462,6 +467,8 @@ def __init__(
non_flatten_params = params
param_name_groups = [[n] for n in param_names]
if self.flatten_parameters:
# don't flatten norm_weights since we need to handle them
# separately during fp8 training
to_be_flatten_params = [
[
params[i]
@@ -470,7 +477,9 @@
]
]
non_flatten_params = [
params[i] for i in range(len(params)) if "norm_weight" in param_names[i]
params[i]
for i in range(len(params))
if "norm_weight" in param_names[i]
]
param_name_groups = [
[n for n in param_names if "norm_weight" not in n],
@@ -543,6 +552,9 @@ def __init__(
if self.zero2_process_group is not None:
assert not self.move_params_to_cpu

def _is_fp8_dtype(self) -> bool:
return _is_fp8_dtype(self.compute_dtype)

def _get_gradient_predivide_factor(self, world_size: int) -> float:
factor: int = 1
while world_size % factor == 0 and world_size / factor > factor:
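The loop body is elided by the diff view; assuming it keeps doubling factor (as in upstream fairscale), a standalone sketch of the resulting pre-division, e.g. world_size=64 yields a predivide factor of 8:

    def gradient_predivide_factor(world_size: int) -> float:
        factor = 1
        # Assumed loop body: keep doubling until factor * factor would exceed world_size.
        while world_size % factor == 0 and world_size / factor > factor:
            factor *= 2
        return float(factor)

    assert gradient_predivide_factor(64) == 8.0  # grads divided by 8 before reduction, by 64/8 after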
@@ -672,7 +684,7 @@ def _cast_buffers(
@property
def params_with_grad(self) -> List[Parameter]:
"""[p for p in self.parameters() if p.grad is not None]"""
return [p for p in self.parameters() if p.grad is not None]
return [p for p in self.parameters() if p.grad is not None or getattr(p, "main_grad", None) is not None]

@torch.no_grad()
def clip_grad_norm_(
@@ -775,7 +787,10 @@ def _shard_parameters_(self) -> None:
assert p.dtype == torch.float32

# If world_size is 1, then we all-reduce grads instead of sharding.
p._is_sharded = self.world_size > 1 and isinstance(p, FlatParameter)
# An exception is norm weights during fp8 training.
p._is_sharded = self.world_size > 1 and (
not self._is_fp8_dtype() or isinstance(p, FlatParameter)
)
p._orig_size = p.data.size()

if not p._is_sharded:
@@ -1279,11 +1294,9 @@ def _init_param_attributes(self, p: Parameter) -> None:
# storage to size 0 at init (here) and re-materialize (by copying
# from _fp32_shard) as needed. If offloading params to CPU, the
# dtype of the fp16 shard will depend on the *`compute_dtype`*.
if self.compute_dtype in [
torch.float8_e5m2,
torch.float8_e4m3fn,
] and not isinstance(p, FlatParameter):
if self._is_fp8_dtype() and not isinstance(p, FlatParameter):
# Assume non-flattened params are precision-critical, like norm weights
assert not p._is_sharded
dtype = torch.bfloat16
else:
dtype = self.compute_dtype
@@ -1307,11 +1320,9 @@ def _init_param_attributes(self, p: Parameter) -> None:
# world_size, although these padding elements will be removed before the
# relevant computation.
if p._is_sharded:
if self.compute_dtype in [
torch.float8_e5m2,
torch.float8_e4m3fn,
] and not isinstance(p, FlatParameter):
if self._is_fp8_dtype() and not isinstance(p, FlatParameter):
# Assume non-flattened params are precision-critical, like norm weights
assert not p._is_sharded
dtype = torch.bfloat16
else:
dtype = self.compute_dtype
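A condensed sketch of the dtype choice made in both _init_param_attributes hunks above: during fp8 compute, flattened weights keep the fp8 compute dtype, while unflattened, precision-critical params (norm weights) fall back to bf16. The helper name is hypothetical.

    import torch

    def shard_dtype(compute_dtype: torch.dtype, is_flat_param: bool) -> torch.dtype:
        fp8_dtypes = (torch.float8_e5m2, torch.float8_e4m3fn)
        if compute_dtype in fp8_dtypes and not is_flat_param:
            return torch.bfloat16   # e.g. norm weights stay in higher precision
        return compute_dtype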
@@ -1439,16 +1450,58 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
if self._is_root and self.mixed_precision:
args, kwargs = cast_floats_to_right_precision(True, True, is_bf16, *args, **kwargs)

just_added_to_fsdp_forward_ordering = False
if self not in self._fsdp_forward_ordering:
self._my_fsdp_instance_idx = len(self._fsdp_forward_ordering)
self._fsdp_forward_ordering.append(self)
just_added_to_fsdp_forward_ordering = True

# If enabled, convert the input to FP32 if we are in full precision.
# no_grad is not used because the input might be for a non-root instance,
# which means autograd needs to go through the conversion.
if self.force_input_to_fp32 and not self.mixed_precision:
args, kwargs = cast_floats_to_right_precision(False, False, is_bf16, *args, **kwargs)

# need to use fp32_to_fp16 stream since _cast_fp32_param_shards_to_fp16
# depends on this block.
with torch.no_grad(), torch.cuda.stream(self._streams["fp32_to_fp16"]):
# Collect parameters whose amax/scale must be updated before
# _cast_fp32_param_shards_to_fp16 uses the fp8 scale to quantize
# them ahead of the all-gather.
# This includes the params of the instance whose all-gather we prefetch.
params = []
if self._my_fsdp_instance_idx < len(self._fsdp_forward_ordering) - 1:
if self._my_fsdp_instance_idx == 0 and self._is_fp8_dtype():
# The first FSDP instance is never prefetched by a predecessor
params = self.params
if self._fsdp_forward_ordering[self._my_fsdp_instance_idx + 1]._is_fp8_dtype():
# FSDP instance whose all-gather we'll prefetch
params.extend(self._fsdp_forward_ordering[self._my_fsdp_instance_idx + 1].params)
elif just_added_to_fsdp_forward_ordering:
# In the first iteration, we haven't yet recorded the
# fsdp_instance_idx of the instance to prefetch
if self._is_fp8_dtype():
params = self.params

for p in params:
if not isinstance(p, FlatParameter):
continue
d = {info[0]: info[1] for info in p._param_infos}
for n, m in d.items():
# Previous iteration was grad_enabled
if m.fp8_meta.get("update_amax_and_scale_fwd", False):
if m.fp8_meta["recipe"].reduce_amax:
FP8GlobalStateManager.copy_amax_from_global_buffer(
m.fp8_meta, forward=True
)
# FIXME update_weight_scale_inv is only True for the first micro-batch
amax_and_scale_update(m.fp8_meta, True)
FP8GlobalStateManager.set_amax_buffer_key_deletion(
m.fp8_meta, forward=True
)
else:
amax_and_scale_update(m.fp8_meta, True)

# All-gather full parameters. This will also transfer FP32 parameters to
# ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
self._rebuild_full_params()
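A hedged, pure-Python sketch (hypothetical helper names, not the FSDP API) of the selection logic in the block above: whose amax/scale gets refreshed before the fp8 cast and all-gather, driven by the recorded forward ordering and the prefetch target. A fresh list is used here so the instance's own params list is never mutated by extend.

    def params_needing_amax_update(idx, ordering, just_added, is_fp8):
        params = []
        if idx < len(ordering) - 1:
            if idx == 0 and is_fp8(ordering[0]):
                params.extend(ordering[0].params)        # first instance is never prefetched by a predecessor
            if is_fp8(ordering[idx + 1]):
                params.extend(ordering[idx + 1].params)  # instance whose all-gather we are about to prefetch
        elif just_added and is_fp8(ordering[idx]):
            params.extend(ordering[idx].params)          # first iteration: ordering not recorded yet
        return params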
@@ -1690,7 +1743,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# then subsequent hook callbacks will see POST state.
self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
self.training_state = TrainingState.BACKWARD_POST
if param.grad is None:
grad_or_main_grad = (
param.main_grad if getattr(param, "main_grad", None) is not None else param.grad
)
if grad_or_main_grad is None:
return

if hasattr(param, "_linked_param"):
@@ -1703,8 +1759,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
if hasattr(param._linked_param, "_is_shared") and param._linked_param._is_shared:
param = param._linked_param

assert param.grad is not None, param.shape
if param.grad.requires_grad:
assert grad_or_main_grad is not None, param.shape
if grad_or_main_grad.requires_grad:
raise RuntimeError("FSDP only works with gradients that don't require gradients")

if self._require_backward_grad_sync or self.reshard_after_forward:
@@ -1731,27 +1787,32 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# reductions in post_backward stream.
self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._streams["post_backward"]):
orig_grad_data = param.grad.data
orig_grad_data = grad_or_main_grad

if self.mixed_precision:
if self.fp32_reduce_scatter:
# Cast grad to FP32.
param.grad.data = param.grad.data.to(param.dtype)
elif self.compute_dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
grad_or_main_grad.data = grad_or_main_grad.to(param.dtype)
elif self._is_fp8_dtype():
# Use bf16 wgrad for fp8 weights (TODO: handle fp8 wgrad)
param.grad.data = param.grad.data.to(torch.bfloat16)
grad_or_main_grad.data = grad_or_main_grad.to(torch.bfloat16)

if self.gradient_predivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
param.grad.data.div_(self.gradient_predivide_factor)
grad_or_main_grad.div_(self.gradient_predivide_factor)

# logging.info(f"{torch.distributed.get_rank()=} {param._is_sharded=}")
if param._is_sharded:
assert self._reducer is not None
# Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
# param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
# gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
# matter, neglecting rounding.
grad = param.grad.data
if hasattr(param, "main_grad"):
grad = param.main_grad
param.main_grad = None
else:
grad = param.grad
# Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
#
# The effect on memory consumption is not usually significant. No extra memory is allocated if this
Expand All @@ -1774,7 +1835,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# world_size == 1. This could be relaxed in the future, in which
# case grads should be all-reduced here.
# assert self.world_size == 1
self._post_reduction_hook(param, param.grad.data)
self._post_reduction_hook(param, grad_or_main_grad)

# After _post_backward_hook returns, orig_grad_data will eventually
# go out of scope, at which point it could otherwise be freed for
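A minimal sketch of the selection rule the hook above applies throughout: prefer the higher-precision main_grad buffer when a TE module populated it, otherwise fall back to .grad, and hand ownership to the reducer by clearing the attribute on the parameter. The helper name is hypothetical.

    import torch

    def take_grad_for_reduction(param: torch.nn.Parameter) -> torch.Tensor:
        main_grad = getattr(param, "main_grad", None)
        if main_grad is not None:
            param.main_grad = None   # reducer owns it now; the next backward re-creates the buffer
            return main_grad
        grad = param.grad
        param.grad = None            # mirror the hook: clear .grad so repeated backwards don't interfere
        return grad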
@@ -1886,6 +1947,9 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
p.device == p._saved_grad_shard.device,
f"WFPB: incorrect saved_grad_shard device {p.device} vs {p._saved_grad_shard.device}",
)
assert p.shape == p._saved_grad_shard.shape
assert p.dtype == p._saved_grad_shard.dtype
assert getattr(p, "main_grad", None) is None
p.grad = p._saved_grad_shard

if hasattr(p, "_saved_grad_shard"):
@@ -2212,6 +2276,17 @@ def _prep_grads_for_backward(self) -> None:
right shape, device, accumulated values, etc.
"""
for p in self.params:
if isinstance(p, FlatParameter) and all(
_is_te_module_with_weights(info[1]) for info in p._param_infos
):
if getattr(p, "main_grad", None) is None:
p.main_grad = torch.empty_like(
p, dtype=torch.bfloat16 if self._is_fp8_dtype() else torch.float
)
main_grad_views = p.get_param_views(p.main_grad)
for (_, m, n), main_grad in zip(p._param_infos, main_grad_views):
getattr(m, n).main_grad = main_grad

if p.grad is not None:
if p.grad.device != p.data.device:
p.grad = None
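A standalone sketch of what the main_grad views above amount to (assuming get_param_views splits an external buffer by the flat parameter's original numels and shapes): each TE module's weight gets a view into one shared higher-precision buffer, so wgrad kernels accumulate directly into it.

    import torch

    flat_main_grad = torch.zeros(10, dtype=torch.float32)   # stand-in for p.main_grad
    numels, shapes = [6, 4], [(2, 3), (4,)]                  # stand-ins for the original param metadata
    views = [chunk.view(shape)
             for chunk, shape in zip(flat_main_grad.split(numels), shapes)]
    views[0].add_(1.0)                                       # a module writing into its wgrad view...
    assert flat_main_grad[:6].eq(1.0).all()                  # ...lands in the shared flat buffer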
@@ -2311,8 +2386,8 @@ def local_metadata_dict(self) -> Dict[str, Any]:
backing_param_name = m.module.flat_param_names[i]
names, shapes, numels = m.module.metadata(i)
else:
assert len(m._param_name_groups[i]) == 1
backing_param_name = m._param_name_groups[i][0]
# assert len(m._param_name_groups[i]) == 1
backing_param_name = m._param_name_groups[m._num_flatten_params][i - m._num_flatten_params]
names = [backing_param_name]
shapes = [p._orig_size]
numels = [p._orig_size.numel()]
@@ -2428,10 +2503,8 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = No
for p in params:
assert p._fp16_shard is not None
alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
if self.compute_dtype in [
torch.float8_e5m2,
torch.float8_e4m3fn,
] and p._fp16_shard.dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
if self._is_fp8_dtype() and _is_fp8_dtype(p._fp16_shard.dtype):
# fp8 quantization
assert isinstance(p, FlatParameter)
assert len(p._param_infos) == len(p._param_numels)
numel_per_shard = p.numel()
@@ -2442,9 +2515,7 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = No
if offset + numel <= 0 or offset >= numel_per_shard:
offset += numel
continue
assert isinstance(
m, (te.Linear, te.LayerNormLinear, te.LayerNormMLP)
)
assert _is_te_module_with_weights(m)
fp8_dtype_forward = te.fp8.get_fp8_te_dtype(
m.fp8_meta["recipe"], fprop_tensor=True
)
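Finally, a hedged sketch of the quantization step that the truncated hunk above performs via TE's cast_to_fp8 (whose exact signature is not shown here), written with plain PyTorch ops under the usual recipe: scale, clamp to the format's representable range, then cast.

    import torch

    def quantize_to_fp8(x: torch.Tensor, scale: torch.Tensor,
                        fp8_dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
        finfo = torch.finfo(fp8_dtype)
        return (x.float() * scale).clamp(finfo.min, finfo.max).to(fp8_dtype)

    x = torch.randn(8, dtype=torch.float32)
    scale = torch.tensor(448.0) / x.abs().max()   # amax-derived scale, as the amax/scale machinery maintains
    x_fp8 = quantize_to_fp8(x, scale)             # what gets all-gathered during fp8 training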
