Commit dfe122b: fp8 allgather
jspark1105 committed Sep 17, 2023 (1 parent: 0b77de4)
Showing 2 changed files with 123 additions and 20 deletions.
fairscale/nn/data_parallel/fully_sharded_data_parallel.py (114 additions, 19 deletions)
@@ -41,6 +41,7 @@
 from torch.nn.parameter import Parameter
 
 from fairscale.nn.misc import FlattenParamsWrapper
+from fairscale.nn.misc.flatten_params_wrapper import FlatParameter
 from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
 from fairscale.utils.containers import apply_to_tensors
 from fairscale.utils.parallel import (
@@ -455,9 +456,20 @@ def __init__(
         non_flatten_params = params
         param_name_groups = [[n] for n in param_names]
         if self.flatten_parameters:
-            to_be_flatten_params = [params]
-            non_flatten_params = []
-            param_name_groups = [param_names]
+            to_be_flatten_params = [
+                [
+                    params[i]
+                    for i in range(len(params))
+                    if "norm_weight" not in param_names[i]
+                ]
+            ]
+            non_flatten_params = [
+                params[i] for i in range(len(params)) if "norm_weight" in param_names[i]
+            ]
+            param_name_groups = [
+                [n for n in param_names if "norm_weight" not in n],
+                [n for n in param_names if "norm_weight" in n],
+            ]
         del param_names
 
         self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
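This hunk keeps every parameter whose name contains "norm_weight" out of the flat buffer, so those weights can later be handled at higher precision (see the _init_param_attributes hunks below) while the rest are flattened and all-gathered as fp8 bytes. A standalone sketch of the grouping rule, with made-up parameter names:

# Sketch only (not part of the commit): reproduce the name-based split on example names.
def split_params_for_flattening(param_names, marker="norm_weight"):
    flatten = [n for n in param_names if marker not in n]
    keep_out = [n for n in param_names if marker in n]
    return flatten, keep_out

names = ["layers.0.wq_weight", "layers.0.attention_norm_weight", "layers.0.fc2_weight"]
print(split_params_for_flattening(names))
# (['layers.0.wq_weight', 'layers.0.fc2_weight'], ['layers.0.attention_norm_weight'])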
@@ -1261,7 +1273,11 @@ def _init_param_attributes(self, p: Parameter) -> None:
             # storage to size 0 at init (here) and re-materialize (by copying
             # from _fp32_shard) as needed. If offloading params to CPU, the
             # dtype of the fp16 shard will depend on the *`compute_dtype`*.
-            p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
+            if self.compute_dtype == torch.int8 and not isinstance(p, FlatParameter):
+                dtype = torch.bfloat16
+            else:
+                dtype = self.compute_dtype
+            p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=dtype)
             free_storage_(p._fp16_shard)
 
         if self.mixed_precision:
@@ -1279,8 +1295,13 @@ def _init_param_attributes(self, p: Parameter) -> None:
         # world_size, although these padding elements will be removed before the
         # relevant computation.
         if p._is_sharded:
+            if self.compute_dtype == torch.int8 and not isinstance(p, FlatParameter):
+                # assume non flattened are precision critical like norm
+                dtype = torch.bfloat16
+            else:
+                dtype = self.compute_dtype
             p._full_param_padded = torch.zeros(
-                p.data.numel() * self.world_size, device=self.compute_device, dtype=self.compute_dtype
+                p.data.numel() * self.world_size, device=self.compute_device, dtype=dtype
             )
             free_storage_(p._full_param_padded)
 
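This hunk and the previous one apply the same dtype rule: when compute_dtype is torch.int8 (the commit's stand-in for fp8, since this PyTorch build has no fp8 dtype), the flattened parameters get one-byte int8 buffers for their shards and all-gather destinations, while non-flattened parameters, i.e. the norm weights split out in __init__, get bf16 buffers. A minimal restatement of that rule, with the isinstance(p, FlatParameter) check simplified to a boolean flag:

import torch

def allgather_buffer_dtype(compute_dtype: torch.dtype, is_flat_param: bool) -> torch.dtype:
    # compute_dtype == torch.int8 is the commit's convention for "fp8 weights":
    # one byte per element for the flattened weights, bf16 for everything else.
    if compute_dtype == torch.int8 and not is_flat_param:
        return torch.bfloat16
    return compute_dtype

print(allgather_buffer_dtype(torch.int8, is_flat_param=True))   # torch.int8 (fp8 bytes)
print(allgather_buffer_dtype(torch.int8, is_flat_param=False))  # torch.bfloat16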
@@ -1393,7 +1414,8 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
 
         # For root and mixed precision, we convert the input to FP16 (no_grad is needed for
         # the conversion).
-        is_bf16 = self.compute_dtype == torch.bfloat16
+        # self.compute_dtype == torch.int8 means fp8 gemm and bf16 activation
+        is_bf16 = self.compute_dtype in [torch.bfloat16, torch.int8]
         if self._is_root and self.mixed_precision:
             args, kwargs = cast_floats_to_right_precision(True, True, is_bf16, *args, **kwargs)
 
@@ -1691,9 +1713,12 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         with torch.cuda.stream(self._streams["post_backward"]):
             orig_grad_data = param.grad.data
 
-            if self.mixed_precision and self.fp32_reduce_scatter:
-                # Cast grad to FP32.
-                param.grad.data = param.grad.data.to(param.dtype)
+            if self.mixed_precision:
+                if self.fp32_reduce_scatter:
+                    # Cast grad to FP32.
+                    param.grad.data = param.grad.data.to(param.dtype)
+                elif self.compute_dtype == torch.int8:
+                    param.grad.data = param.grad.data.to(torch.bfloat16)
 
             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
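Gradients under this scheme are produced in bf16 (the activation dtype), never in fp8. The hook keeps the existing fp32_reduce_scatter path and otherwise forces bf16 when compute_dtype is torch.int8, so the reduce-scatter never operates on int8 bytes. The dtype decision in isolation, as a small sketch:

import torch

def reduce_scatter_grad_dtype(fp32_reduce_scatter: bool, compute_dtype: torch.dtype) -> torch.dtype:
    if fp32_reduce_scatter:
        return torch.float32       # cast grads up before reduction
    if compute_dtype == torch.int8:
        return torch.bfloat16      # fp8 weights, but grads are reduced in bf16
    return compute_dtype

print(reduce_scatter_grad_dtype(False, torch.int8))  # torch.bfloat16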
@@ -1925,10 +1950,27 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
                     # Here p.data == p._fp32_shard, so it's not safe to free.
                     output_tensors.append((p.data, False))
             else:
-                p.data = p._full_param_padded
+                if p._full_param_padded.dtype == torch.int8:
+                    # workaround that we don't have fp8 dtype yet to avoid an error
+                    # "data set to a tensor that requires gradients must be floating"
+                    p.data = p._full_param_padded.view(p.data.dtype)
+                else:
+                    p.data = p._full_param_padded
                 output_tensors.append((p.data, True))
             # Trim any padding and reshape to match original size.
-            p.data = p.data[: p._orig_size.numel()].view(p._orig_size)
+            if (
+                p.data.dtype != p._full_param_padded.dtype
+                and p._full_param_padded.dtype == torch.int8
+                and (not self.mixed_precision or not force_full_precision)
+            ):
+                # another workaround that we don't have fp8 dtype yet
+                p.data = (
+                    p.data.view(p._full_param_padded.dtype)[: p._orig_size.numel()]
+                    .view(p._orig_size)
+                    .view(p.data.dtype)
+                )
+            else:
+                p.data = p.data[: p._orig_size.numel()].view(p._orig_size)
 
         if self.ssd_offload:
             for p in self.params:
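Both workarounds above exist because assigning p.data on a parameter that requires grad demands a floating-point tensor (the quoted error), while the fp8 all-gather buffer is typed int8. Tensor.view(dtype) reinterprets the same storage without copying, so the buffer can be presented as bf16 and viewed back to int8 whenever padding has to be trimmed in byte units. A small demonstration of the byte-reinterpretation arithmetic this relies on (sizes are made up):

import torch

payload = torch.zeros(16, dtype=torch.int8)   # e.g. 16 fp8 "bytes" gathered from all ranks
as_bf16 = payload.view(torch.bfloat16)        # same storage, 2 bytes per element
print(payload.numel(), as_bf16.numel())       # 16 8  -> numel halves, nbytes unchanged

# Trimming padding has to happen on the int8 view; the result is then re-disguised
# as bf16 so autograd accepts it as a floating-point tensor.
orig_numel = 12                               # pretend the unpadded flat param had 12 bytes
trimmed = as_bf16.view(torch.int8)[:orig_numel].view(torch.bfloat16)
print(trimmed.numel())                        # 6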
@@ -1997,7 +2039,7 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
                         if p._full_param_padded.storage().size() != p_size.numel():
                             # Allocate based on full size from all shards.
                             alloc_storage_(p._full_param_padded, size=p_size)
-                        output_tensor = p._full_param_padded
+                        output_tensor = p._full_param_padded.view(p_data.dtype)
 
                     # Fill output_tensor with (p.data for each shard in self.world_size)
                     if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
@@ -2158,7 +2200,16 @@ def _use_full_params(self) -> None:
                     p.data = p._fp16_shard
             else:
                 assert p._full_param_padded.storage().size() != 0, f"{p._orig_size} {id(self)}"
-                p.data = p._full_param_padded[: p._orig_size.numel()].view(p._orig_size)
+                if p._full_param_padded.dtype == torch.int8:
+                    p.data = (
+                        p._full_param_padded[: p._orig_size.numel()]
+                        .view(p._orig_size)
+                        .view(torch.bfloat16)
+                    )
+                else:
+                    p.data = p._full_param_padded[: p._orig_size.numel()].view(
+                        p._orig_size
+                    )
 
     @torch.no_grad()
     def _prep_grads_for_backward(self) -> None:
@@ -2382,12 +2433,56 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None:
             for p in params:
                 assert p._fp16_shard is not None
                 alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
-                p._fp16_shard.copy_(
-                    # If move_params_to_cpu is True, this will be non-blocking
-                    # because _fp32_shard is pinned, otherwise it's a no-op.
-                    p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
-                )
-                p.data = p._fp16_shard
+                if (
+                    self.compute_dtype == torch.int8
+                    and p._fp16_shard.dtype == torch.int8
+                ):
+                    assert isinstance(p, FlatParameter)
+                    import transformer_engine.pytorch as te
+                    from transformer_engine.pytorch.cpp_extensions import (
+                        cast_to_fp8,
+                        FP8FwdTensors,
+                    )
+
+                    assert len(p._param_infos) == len(p._param_numels)
+                    numel_per_shard = p.numel()
+                    offset = -numel_per_shard * self.rank
+                    for i in range(len(p._param_infos)):
+                        _, m, n = p._param_infos[i]
+                        numel = p._param_numels[i]
+                        if offset >= numel_per_shard or offset + numel <= 0:
+                            offset += numel
+                            continue
+                        fp8_dtype_forward = te.fp8.get_fp8_te_dtype(
+                            m.fp8_meta["recipe"], fprop_tensor=True
+                        )
+                        assert isinstance(
+                            m, (te.Linear, te.LayerNormLinear, te.LayerNormMLP)
+                        )
+                        if not m.fp8_initialized:
+                            m.fp8_init(
+                                num_gemms=2 if isinstance(m, te.LayerNormMLP) else 1
+                            )
+                        begin = max(offset, 0)
+                        end = min(offset + numel, numel_per_shard)
+                        cast_to_fp8(
+                            p._fp32_shard[begin:end],
+                            m.fp8_meta["scaling_fwd"],
+                            FP8FwdTensors.GEMM2_WEIGHT
+                            if n == "fc2_weight"
+                            else FP8FwdTensors.GEMM1_WEIGHT,
+                            fp8_dtype_forward,
+                            out=p._fp16_shard[begin:end],
+                        )
+                        offset += numel
+                    p.data = p._fp16_shard.view(torch.bfloat16)
+                else:
+                    p._fp16_shard.copy_(
+                        # If move_params_to_cpu is True, this will be non-blocking
+                        # because _fp32_shard is pinned, otherwise it's a no-op.
+                        p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
+                    )
+                    p.data = p._fp16_shard
         torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])
 
     @torch.no_grad()
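The new branch in _cast_fp32_param_shards_to_fp16 casts each rank's fp32 shard to fp8 with Transformer Engine's cast_to_fp8, one original parameter at a time, writing straight into the int8 shard buffer that the all-gather will send. The offset bookkeeping decides which slice of each original parameter lands inside this rank's shard of the flat buffer: offset starts at -numel_per_shard * rank and each slice is clamped to [0, numel_per_shard). A standalone rendering of just that bookkeeping (the helper name and the sizes are made up):

# Sketch of the shard/slice arithmetic used in the fp8 cast loop above.
def shard_slices(param_numels, numel_per_shard, rank):
    """Yield (param_index, begin, end) slices that fall inside this rank's shard, in shard-local coordinates."""
    offset = -numel_per_shard * rank   # position of param 0 relative to this shard's start
    for i, numel in enumerate(param_numels):
        if not (offset >= numel_per_shard or offset + numel <= 0):
            yield i, max(offset, 0), min(offset + numel, numel_per_shard)
        offset += numel

# Two ranks, a shard of 8 elements each, three original params of sizes 6, 4, 6 (total 16).
print(list(shard_slices([6, 4, 6], numel_per_shard=8, rank=0)))  # [(0, 0, 6), (1, 6, 8)]
print(list(shard_slices([6, 4, 6], numel_per_shard=8, rank=1)))  # [(1, 0, 2), (2, 2, 8)]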
fairscale/nn/misc/flatten_params_wrapper.py (9 additions, 1 deletion)
@@ -86,7 +86,15 @@ def get_param_views(self, external_data: Optional[Tensor] = None) -> Iterator[Tensor]:
             self._param_numels
         ), f"Incorrect internal state {self.data.numel()} vs. {sum(self._param_numels)}"
         data = external_data if external_data is not None else self
-        if data.numel() != sum(self._param_numels):
+        if data.numel() * 2 == sum(self._param_numels):
+            # fp8 disguised as bf16
+            return (
+                t.view(s)
+                for (t, s) in zip(
+                    data.view(torch.int8).split(self._param_numels), self._param_shapes
+                )
+            )
+        elif data.numel() != sum(self._param_numels):
             raise ValueError(
                 f"Incorrect numel of supplied data: got {data.numel()} but expected {sum(self._param_numels)}"
             )
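When the flat parameter carries fp8 bytes disguised as bf16, its numel is exactly half the sum of the original parameter numels: one byte per original element versus two bytes per bf16 element. That is what the new data.numel() * 2 == sum(self._param_numels) branch detects, and viewing back to int8 restores one element per original element so the usual split-and-reshape still works. A quick check of that arithmetic on a toy flat buffer:

import torch

param_numels = [6, 10]                               # original parameters: 6 + 10 = 16 elements
param_shapes = [(2, 3), (10,)]
disguised = torch.zeros(8, dtype=torch.bfloat16)     # 16 fp8 bytes carried as 8 bf16 elements

assert disguised.numel() * 2 == sum(param_numels)    # the condition added above
views = [
    t.view(s)
    for t, s in zip(disguised.view(torch.int8).split(param_numels), param_shapes)
]
print([v.shape for v in views])                      # [torch.Size([2, 3]), torch.Size([10])]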
