From bd70153c6321874593acd1d8977c3862e8e84b69 Mon Sep 17 00:00:00 2001
From: Jongsoo Park
Date: Wed, 4 Oct 2023 16:03:01 -0700
Subject: [PATCH] use main_grad for higher precision gradient accumulation;
 update amax during post_backward_hook

---
 .../fully_sharded_data_parallel.py            | 135 ++++++++++++------
 1 file changed, 93 insertions(+), 42 deletions(-)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index f1aa72842..084e6a715 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -55,11 +55,8 @@ from fairscale.utils.params import calc_grad_norm, recursive_copy_to_device
 from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
 from fairscale.utils.state_dict import replace_by_prefix_
 
-from transformer_engine.pytorch.cpp_extensions import (
-    cast_to_fp8,
-    DType,
-    FP8FwdTensors,
-)
+from transformer_engine.pytorch.cpp_extensions import cast_to_fp8, DType, FP8FwdTensors
+from transformer_engine.pytorch.fp8 import amax_and_scale_update, FP8GlobalStateManager
 
 from . import fsdp_optim_utils as ou
 
@@ -123,6 +120,14 @@ class OffloadConfig:
     dir: Optional[str] = None
 
 
+def _is_te_module_with_weights(m: nn.Module) -> bool:
+    return isinstance(m, (te.Linear, te.LayerNormLinear, te.LayerNormMLP))
+
+
+def _is_fp8_dtype(dtype: torch.dtype) -> bool:
+    return dtype in [torch.float8_e5m2, torch.float8_e4m3fn]
+
+
 class FullyShardedDataParallel(nn.Module):
     """
     A wrapper for sharding Module parameters across data parallel workers. This
@@ -462,20 +467,29 @@ def __init__(
         non_flatten_params = params
         param_name_groups = [[n] for n in param_names]
         if self.flatten_parameters:
-            to_be_flatten_params = [
-                [
+            # don't flatten norm_weights since we need to handle them
+            # separately during fp8 training
+            if self._is_fp8_dtype():
+                to_be_flatten_params = [
+                    [
+                        params[i]
+                        for i in range(len(params))
+                        if "norm_weight" not in param_names[i]
+                    ]
+                ]
+                non_flatten_params = [
                     params[i]
                     for i in range(len(params))
-                    if "norm_weight" not in param_names[i]
+                    if "norm_weight" in param_names[i]
                 ]
-            ]
-            non_flatten_params = [
-                params[i] for i in range(len(params)) if "norm_weight" in param_names[i]
-            ]
-            param_name_groups = [
-                [n for n in param_names if "norm_weight" not in n],
-                [n for n in param_names if "norm_weight" in n],
-            ]
+                param_name_groups = [
+                    [n for n in param_names if "norm_weight" not in n],
+                    [n for n in param_names if "norm_weight" in n],
+                ]
+            else:
+                to_be_flatten_params = params
+                non_flatten_params = []
+                param_name_groups = [param_names]
         del param_names
 
         self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
@@ -543,6 +557,9 @@ def __init__(
         if self.zero2_process_group is not None:
             assert not self.move_params_to_cpu
 
+    def _is_fp8_dtype(self) -> bool:
+        return _is_fp8_dtype(self.compute_dtype)
+
     def _get_gradient_predivide_factor(self, world_size: int) -> float:
         factor: int = 1
         while world_size % factor == 0 and world_size / factor > factor:
@@ -775,7 +792,10 @@ def _shard_parameters_(self) -> None:
                 assert p.dtype == torch.float32
 
             # If world_size is 1, then we all-reduce grads instead of sharding.
-            p._is_sharded = self.world_size > 1 and isinstance(p, FlatParameter)
+            # An exception is norm weights during fp8 training.
+            p._is_sharded = self.world_size > 1 and (
+                not self._is_fp8_dtype() or isinstance(p, FlatParameter)
+            )
             p._orig_size = p.data.size()
 
             if not p._is_sharded:
@@ -1279,11 +1299,9 @@ def _init_param_attributes(self, p: Parameter) -> None:
             # storage to size 0 at init (here) and re-materialize (by copying
             # from _fp32_shard) as needed. If offloading params to CPU, the
             # dtype of the fp16 shard will depend on the *`compute_dtype`*.
-            if self.compute_dtype in [
-                torch.float8_e5m2,
-                torch.float8_e4m3fn,
-            ] and not isinstance(p, FlatParameter):
+            if self._is_fp8_dtype() and not isinstance(p, FlatParameter):
                 # assume non flattened are precision critical like norm
+                assert not p._is_sharded
                 dtype = torch.bfloat16
             else:
                 dtype = self.compute_dtype
@@ -1307,11 +1325,9 @@ def _init_param_attributes(self, p: Parameter) -> None:
         # world_size, although these padding elements will be removed before the
         # relevant computation.
         if p._is_sharded:
-            if self.compute_dtype in [
-                torch.float8_e5m2,
-                torch.float8_e4m3fn,
-            ] and not isinstance(p, FlatParameter):
+            if self._is_fp8_dtype() and not isinstance(p, FlatParameter):
                 # assume non flattened are precision critical like norm
+                assert not p._is_sharded
                 dtype = torch.bfloat16
             else:
                 dtype = self.compute_dtype
@@ -1690,7 +1706,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # then subsequent hook callbacks will see POST state.
         self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
         self.training_state = TrainingState.BACKWARD_POST
-        if param.grad is None:
+        grad_or_main_grad = (
+            param.main_grad if hasattr(param, "main_grad") else param.grad
+        )
+        if grad_or_main_grad is None:
             return
 
         if hasattr(param, "_linked_param"):
@@ -1703,10 +1722,31 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
             if hasattr(param._linked_param, "_is_shared") and param._linked_param._is_shared:
                 param = param._linked_param
 
-        assert param.grad is not None, param.shape
-        if param.grad.requires_grad:
+        assert grad_or_main_grad is not None, param.shape
+        if grad_or_main_grad.requires_grad:
             raise RuntimeError("FSDP only works with gradients that don't require gradients")
 
+        if (
+            self._require_backward_grad_sync
+            and self._is_fp8_dtype()
+            and isinstance(param, FlatParameter)
+        ):
+            # Need to update amax and scale here before
+            # _cast_fp32_param_shards_to_fp16
+            for m in set(info[1] for info in param._param_infos):
+                # Previous iteration was grad_enabled
+                if m.fp8_meta.get("update_amax_and_scale_fwd", False):
+                    if m.fp8_meta["recipe"].reduce_amax:
+                        FP8GlobalStateManager.copy_amax_from_global_buffer(
+                            m.fp8_meta, forward=True
+                        )
+                        amax_and_scale_update(m.fp8_meta, True)
+                        FP8GlobalStateManager.set_amax_buffer_key_deletion(
+                            m.fp8_meta, forward=True
+                        )
+                    else:
+                        amax_and_scale_update(m.fp8_meta, True)
+
         if self._require_backward_grad_sync or self.reshard_after_forward:
             # Free full params. As a special case, we don't free the full params
             # when in a ``no_sync`` context (as inversely indicated by
@@ -1731,19 +1771,19 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["post_backward"]):
-            orig_grad_data = param.grad.data
+            orig_grad_data = grad_or_main_grad.data
 
             if self.mixed_precision:
                 if self.fp32_reduce_scatter:
                     # Cast grad to FP32.
-                    param.grad.data = param.grad.data.to(param.dtype)
-                elif self.compute_dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
+                    grad_or_main_grad.data = grad_or_main_grad.data.to(param.dtype)
+                elif self._is_fp8_dtype():
                     # Use bf16 wgrad for fp8 weights (TODO: handle fp8 wgrad)
-                    param.grad.data = param.grad.data.to(torch.bfloat16)
+                    grad_or_main_grad.data = grad_or_main_grad.data.to(torch.bfloat16)
 
             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                param.grad.data.div_(self.gradient_predivide_factor)
+                grad_or_main_grad.data.div_(self.gradient_predivide_factor)
 
             if param._is_sharded:
                 assert self._reducer is not None
                 # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
-                grad = param.grad.data
+                if hasattr(param, "main_grad"):
+                    grad = param.main_grad.data
+                    param.main_grad = None
+                else:
+                    grad = param.grad.data
                 # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                 #
                 # The effect on memory consumption is not usually significant. No extra memory is allocated if this
@@ -1774,7 +1818,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # world_size == 1. This could be relaxed in the future, in which
                 # case grads should be all-reduced here.
                 # assert self.world_size == 1
-                self._post_reduction_hook(param, param.grad.data)
+                self._post_reduction_hook(param, grad_or_main_grad.data)
 
         # After _post_backward_hook returns, orig_grad_data will eventually
         # go out of scope, at which point it could otherwise be freed for
@@ -2212,6 +2256,17 @@ def _prep_grads_for_backward(self) -> None:
         right shape, device, accumulated values, etc.
         """
         for p in self.params:
+            if isinstance(p, FlatParameter) and all(
+                _is_te_module_with_weights(info[1]) for info in p._param_infos
+            ):
+                if getattr(p, "main_grad", None) is None:
+                    p.main_grad = torch.zeros_like(
+                        p, dtype=torch.bfloat16 if self._is_fp8_dtype() else torch.float
+                    )
+                main_grad_views = p.get_param_views(p.main_grad)
+                for (_, m, n), main_grad in zip(p._param_infos, main_grad_views):
+                    getattr(m, n).main_grad = main_grad
+
             if p.grad is not None:
                 if p.grad.device != p.data.device:
                     p.grad = None
@@ -2428,10 +2483,8 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = No
             for p in params:
                 assert p._fp16_shard is not None
                 alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
-                if self.compute_dtype in [
-                    torch.float8_e5m2,
-                    torch.float8_e4m3fn,
-                ] and p._fp16_shard.dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
+                if self._is_fp8_dtype() and _is_fp8_dtype(p._fp16_shard.dtype):
+                    # fp8 quantization
                     assert isinstance(p, FlatParameter)
                     assert len(p._param_infos) == len(p._param_numels)
                     numel_per_shard = p.numel()
@@ -2442,9 +2495,7 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = No
                         if offset + numel <= 0 or offset >= numel_per_shard:
                             offset += numel
                             continue
-                        assert isinstance(
-                            m, (te.Linear, te.LayerNormLinear, te.LayerNormMLP)
-                        )
+                        assert _is_te_module_with_weights(m)
                         fp8_dtype_forward = te.fp8.get_fp8_te_dtype(
                             m.fp8_meta["recipe"], fprop_tensor=True
                         )
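
Reviewer note (illustrative, not part of the patch): the sketch below reduces the main_grad pattern this change builds on to a toy stand-alone script. A persistent higher-precision buffer is attached to each parameter and freshly computed low-precision grads are folded into it; this is a simplified stand-in for how _prep_grads_for_backward attaches per-parameter views of a FlatParameter's main_grad buffer and how _post_backward_hook reduces that buffer instead of p.grad. The toy module, sizes, and the manual accumulation loop are hypothetical; Transformer Engine modules write wgrads into main_grad directly when the attribute is present.

# Illustrative sketch only (not part of the patch): higher-precision gradient
# accumulation into a `main_grad` buffer attached to each parameter.
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 8).to(torch.bfloat16)  # hypothetical toy module

# Persistent fp32 accumulation buffer per parameter, analogous to the
# bf16/fp32 buffer _prep_grads_for_backward allocates for a FlatParameter.
for p in model.parameters():
    p.main_grad = torch.zeros_like(p, dtype=torch.float32)

def fold_grad_into_main_grad(p: nn.Parameter) -> None:
    # Accumulate the low-precision grad at fp32, then drop p.grad so the
    # reduction/optimizer path only ever sees main_grad.
    if p.grad is not None:
        p.main_grad += p.grad.float()
        p.grad = None

for _ in range(4):  # gradient-accumulation micro-batches
    x = torch.randn(4, 8, dtype=torch.bfloat16)
    model(x).sum().backward()
    for p in model.parameters():
        fold_grad_into_main_grad(p)

# main_grad now holds the fp32 sum of four bf16 micro-batch gradients.
print(next(model.parameters()).main_grad.dtype)  # torch.float32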
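
A second illustrative sketch shows how the patched wrapper might be driven during fp8 training. It assumes a CUDA device with fp8 support, an initialized process group (e.g. launched via torchrun), and that an fp8 dtype is routed through FSDP's existing compute_dtype argument; the module sizes and recipe settings are hypothetical, not taken from this patch.

# Illustrative usage sketch only (not part of the patch). Assumes torchrun has
# set the distributed environment variables and the GPU supports fp8.
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

block = te.LayerNormMLP(1024, 4096).cuda()  # hypothetical sizes
model = FSDP(
    block,
    flatten_parameters=True,            # norm weights stay unflattened for fp8
    mixed_precision=True,
    compute_dtype=torch.float8_e4m3fn,  # assumption: fp8 dtype passed via compute_dtype
)

recipe = DelayedScaling()  # default delayed-scaling recipe
x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)

for _ in range(3):
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        out = model(x)
    # With this patch, wgrads accumulate into main_grad (bf16) and each TE
    # module's amax/scale update runs inside _post_backward_hook, before the
    # fp32 shards are re-quantized to fp8 for the next step.
    out.float().sum().backward()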