Fp8 all gather hack #1136
base: ngoyal_added_zero2_shard_modelparams_multiple_gpus
Conversation
Will merge the main_grad-related changes with #1142.
# Cast grad to FP32.
grad_or_main_grad.data = grad_or_main_grad.to(param.dtype)
elif self._is_fp8_dtype():
    # Use bf16 wgrad for fp8 weights (TODO: handle fp8 wgrad)
Is this currently not working with the latest FP8 wgrad?
This is meant for future work, once we have fp8 reduce-scatter. I'll update the comment.
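For context, a hedged sketch of the grad-dtype choice the snippet above describes (the function name and structure are illustrative, not the FSDP code itself):

```python
import torch

def _wgrad_dtype(param_dtype: torch.dtype) -> torch.dtype:
    """Pick the dtype used for weight grads before reduction (illustrative only)."""
    if param_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        # fp8 weights keep bf16 wgrads for now; fp8 wgrads are left for the
        # future fp8 reduce-scatter work mentioned above.
        return torch.bfloat16
    # Otherwise grads are cast back to the (fp32) master-param dtype.
    return param_dtype
```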
@@ -1393,7 +1447,11 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:

 # For root and mixed precision, we convert the input to FP16 (no_grad is needed for
 # the conversion).
-is_bf16 = self.compute_dtype == torch.bfloat16
+is_bf16 = self.compute_dtype in [
nit: is_bf16_or_fp8
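For illustration only (the diff above is truncated, so the dtype list is an assumption, not the PR's code), the rename the nit suggests might be factored out roughly like this:

```python
import torch

def _is_bf16_or_fp8(compute_dtype: torch.dtype) -> bool:
    # bf16 and the fp8 formats all take the same low-precision compute path here.
    return compute_dtype in (torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2)
```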
@@ -2265,8 +2361,7 @@ def local_metadata_dict(self) -> Dict[str, Any]:
     backing_param_name = m.module.flat_param_names[i]
     names, shapes, numels = m.module.metadata(i)
 else:
-    assert len(m._param_name_groups[i]) == 1
-    backing_param_name = m._param_name_groups[i][0]
+    backing_param_name = m._param_name_groups[m._num_flatten_params][i - m._num_flatten_params]
Need to make sure checkpointing works properly with this.
… during post_backward_hook
This is based on ngoyal_added_zero2_shard_modelparams_multiple_gpus and adds hacks to use fp8 all-gather with NVIDIA's TransformerEngine (see the latest commit for the changes on top of the ngoyal_added_zero2_shard_modelparams_multiple_gpus branch).
This depends on TransformerEngine changes in https://github.com/facebookresearch/TransformerEngine/pull/20
See https://github.com/fairinternal/xlformers/pull/1403 for an example of how to use it.
Also depends on PyTorch changes in pytorch/pytorch#109654
To use fp8 all-gather, set compute_dtype=torch.float8_e4m3fn and mixed_precision=True.
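A minimal sketch of that configuration, assuming fairscale's FullyShardedDataParallel wrapper around a TransformerEngine layer (the module and sizes are placeholders, not from this PR):

```python
import torch
import transformer_engine.pytorch as te
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Any TE layer that carries fp8 metadata; the size is arbitrary here.
block = te.Linear(4096, 4096)

fsdp_block = FSDP(
    block,
    mixed_precision=True,                # keep fp32 master weights, cast for compute
    compute_dtype=torch.float8_e4m3fn,   # enables the fp8 all-gather path in this PR
)
```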
We separate out precision-critical parameters, such as the affine weights for norms, as non-flattened params and hard-code them to use bf16.
We update scale/scale_inv inside forward, before _rebuild_full_params calls _cast_fp32_param_shards_to_fp16, whereas the TE baseline updates scale/scale_inv in prepare_forward. This is because we need the fp8 quantization of the weights earlier, before the all-gather. (One could consider doing this in post-backward instead, but that has a problem: the bwd amax update only happens after the backward of all layers has finished, which can be later than post-backward, so we would not be using the latest bwd amax info for the scale/scale_inv update.)
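A toy, single-rank illustration of that ordering (the helper names here are hypothetical; the real code lives in FSDP's forward and _cast_fp32_param_shards_to_fp16 and uses TE's fp8 metadata):

```python
import torch

def _update_fp8_scale(amax_history: torch.Tensor, fp8_max: float = 448.0) -> torch.Tensor:
    # Recompute the scale so the largest recently seen value maps near the fp8 e4m3 max.
    amax = amax_history.max().clamp(min=1e-12)
    return fp8_max / amax

def _quantize_shard(shard_fp32: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Quantize the fp32 shard to fp8 with the freshly updated scale.
    return (shard_fp32 * scale).to(torch.float8_e4m3fn)

amax_history = torch.tensor([0.5, 2.0, 1.3])   # amaxes recorded in earlier iterations
shard = torch.randn(16)

scale = _update_fp8_scale(amax_history)    # 1) refresh scale/scale_inv first, inside forward
shard_fp8 = _quantize_shard(shard, scale)  # 2) then quantize the fp32 shard to fp8
# 3) then all-gather shard_fp8 across ranks instead of a bf16/fp16 buffer
```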
We hard-code special handling for a couple of TransformerEngine layers (Linear, LayerNormLinear, and LayerNormMLP) in _cast_fp32_param_shards_to_fp16 to access their fp8 metadata and quantize with the right scales. (TODO: we may want to extract this into user-customizable callback functions.)
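A rough sketch of that special handling, assuming TE's fp8_meta dict with a "scaling_fwd" entry holding per-GEMM scale tensors (the exact layout depends on the TransformerEngine version, and the function below is illustrative, not the PR's code):

```python
import torch
import transformer_engine.pytorch as te

_TE_FP8_MODULES = (te.Linear, te.LayerNormLinear, te.LayerNormMLP)

def quantize_param_shard(
    module: torch.nn.Module, shard_fp32: torch.Tensor, gemm_index: int = 0
) -> torch.Tensor:
    """Quantize one fp32 weight shard using the owning TE module's forward scale."""
    if isinstance(module, _TE_FP8_MODULES):
        scale = module.fp8_meta["scaling_fwd"].scale[gemm_index]
        return (shard_fp32 * scale).to(torch.float8_e4m3fn)
    # Non-TE / precision-critical params (e.g. norm affine weights) stay in bf16.
    return shard_fp32.to(torch.bfloat16)
```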
CC @awgu @ngoyal2707 @vedanuj @jiecaoyu @yf225 @GD06