This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

add unit tests for FSDP2 + torch.compile(transformer block) (#321)
Summary:
TorchTitan complains about FSDP2 + float8 + torch.compile(transformer block).

There is a mismatch in the float8 scale's shape across iterations, so the dynamo guard assertion `torch._C._dynamo.guards.assert_size_stride(new_inputs[3], (), ())` fails:
* in the 1st iteration, we calculate the float8 scale through `cast_to_float8_e4m3_dynamic` ([code](https://github.com/pytorch-labs/float8_experimental/blob/main/float8_experimental/fsdp_utils.py#L172)); the scale is a scalar tensor, e.g. `tensor(4674.8633)`
* in the 2nd iteration, we calculate the float8 scale through `precompute_float8_dynamic_scale`, but the scale is NOT a scalar tensor, e.g. `tensor([[4674.8633]])`
* this PR calls `.squeeze` to make sure scales are always scalar tensors, so the dynamo guard assertion always holds (see the sketch below)
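
To make the shape mismatch concrete, here is a minimal standalone sketch (not part of this commit; the value is just the example above) of why a compiled block that guarded on a 0-dim scale rejects a `(1, 1)`-shaped scale, and how `.squeeze()` restores the expected shape:

```python
import torch

# 1st iteration: the dynamic cast produces a 0-dim scalar tensor.
scale_iter1 = torch.tensor(4674.8633)
print(scale_iter1.shape)  # torch.Size([])

# 2nd iteration: the precompute path produced a (1, 1) tensor instead.
scale_iter2 = torch.tensor([[4674.8633]])
print(scale_iter2.shape)  # torch.Size([1, 1])

# Dynamo guards specialize on size/stride, so a guard equivalent to
# assert_size_stride(scale, (), ()) fails for the (1, 1) tensor.
# Squeezing removes the size-1 dims and recovers the 0-dim shape.
print(scale_iter2.squeeze().shape)  # torch.Size([])
```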

Added a unit test so we can catch the issue at PR time.

TODO: add fp8 + torch.compile to CI in torchtitan

Pull Request resolved: #321

Reviewed By: vkuzo

Differential Revision: D59892261

Pulled By: weifengpy

fbshipit-source-id: 6f9f5a4e2de06c347403f4c7c82b3978f37ff9eb
weifengpy authored and facebook-github-bot committed Jul 18, 2024
1 parent ec8b46c commit 7f0d6bb
Showing 6 changed files with 24 additions and 11 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -137,7 +137,7 @@ pytest test/test_numerics_integration.py
 ./test/test_dtensor.sh

 # run integration tests on the FSDP2 integration
-python test/test_fsdp2/test_fsdp2_eager.py
+python test/test_fsdp2/test_fsdp2.py

 # run all of these tests
 ./test/test_everything.sh
2 changes: 0 additions & 2 deletions float8_experimental/float8_dynamic_utils.py
@@ -4,8 +4,6 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Optional, Tuple
-
 import torch

 from float8_experimental.float8_tensor import (
6 changes: 4 additions & 2 deletions float8_experimental/fsdp_utils.py
@@ -64,7 +64,9 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
     scales = torch.split(scale_tensor, 1)  # Replicate
     for scale, float8_linear in zip(scales, float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
+        float8_linear.weight._local_tensor._precomputed_scale = (
+            scale._local_tensor.squeeze()
+        )


 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
@@ -301,7 +303,7 @@ def __tensor_flatten__(self):
             ],
             {
                 "mm_config": self._mm_config,
-                "is_amax_initialized": is_amax_initialized,
+                "is_amax_initialized": self.is_amax_initialized,
             },
         )
2 changes: 1 addition & 1 deletion test/test_everything.sh
@@ -15,7 +15,7 @@ then
 ./test/test_fsdp.sh
 ./test/test_fsdp_compile.sh
 ./test/test_dtensor.sh
-pytest test/test_fsdp2/test_fsdp2_eager.py
+pytest test/test_fsdp2/test_fsdp2.py
 fi

 echo "all tests successful"
15 changes: 12 additions & 3 deletions test/test_fsdp2/test_fsdp2_eager.py → test/test_fsdp2/test_fsdp2.py
@@ -89,6 +89,7 @@ def test_transformer_parity(self):
                     TensorScalingType.DYNAMIC,
                     TensorScalingType.DELAYED,
                 ],
+                "compile_transformer_block": [False, True],
             },
             self._test_transformer_parity,
         )
@@ -98,6 +99,7 @@ def _test_transformer_parity(
         enable_fsdp_fp8_all_gather: bool,
         precompute: bool,
         scaling_type_w: TensorScalingType,
+        compile_transformer_block: bool,
     ):
         if not enable_fsdp_fp8_all_gather and precompute:
             return
@@ -112,11 +114,17 @@ def _test_transformer_parity(
         module = self.init_transformer(weight_tying=weight_tying).cuda()
         ref_module = copy.deepcopy(module)
         swap_linear_with_float8_linear(ref_module, scaling_type_w=scaling_type_w)
+        if compile_transformer_block:
+            for layer_id, transformer_block in ref_module.layers.named_children():
+                transformer_block = torch.compile(transformer_block, dynamic=False)
+                ref_module.layers.register_module(layer_id, transformer_block)
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
             swap_linear_with_float8_linear(module, scaling_type_w=scaling_type_w)
-        for submodule in module.modules():
-            if isinstance(submodule, TransformerBlock):
-                fully_shard(submodule)
+        for layer_id, transformer_block in module.layers.named_children():
+            if compile_transformer_block:
+                transformer_block = torch.compile(transformer_block, dynamic=False)
+            fully_shard(transformer_block)
+            module.layers.register_module(layer_id, transformer_block)
         fully_shard(module)
         ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
         optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
@@ -132,6 +140,7 @@ def _test_transformer_parity(
             local_inp,
             precompute,
             scaling_type_w=scaling_type_w,
+            compile_transformer_block=compile_transformer_block,
         )

     @skip_if_lt_x_gpu(2)
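
As a side note for readers of the test change above, here is a minimal, hedged sketch of the compile-then-shard pattern it exercises; it assumes a model that exposes a `.layers` container of transformer blocks and is an illustration, not code from this repository:

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # FSDP2 entry point


def compile_and_shard_transformer_blocks(model: nn.Module) -> None:
    # Compile each transformer block individually, shard it with FSDP2, and
    # re-register the compiled module under its original name, mirroring the
    # per-block handling in the test above.
    for layer_id, block in model.layers.named_children():
        block = torch.compile(block, dynamic=False)
        fully_shard(block)
        model.layers.register_module(layer_id, block)
    # Shard the root module last so the remaining parameters are covered too.
    fully_shard(model)
```
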
8 changes: 6 additions & 2 deletions test/test_fsdp2/test_fsdp2_common.py
@@ -6,7 +6,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     sync_float8_amax_and_scale_history,
@@ -23,6 +23,7 @@ def check_parity_no_mp(
     local_inp: torch.Tensor,
     precompute: bool = False,
     scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
+    compile_transformer_block: bool = False,
 ):
     for iter_idx in range(10):
         losses: List[torch.Tensor] = []
@@ -46,7 +47,10 @@
             ):
                 precompute_float8_dynamic_scale_for_fsdp(model)

-        test_cls.assertEqual(losses[0], losses[1])
+        if compile_transformer_block:
+            test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4)
+        else:
+            test_cls.assertEqual(losses[0], losses[1])


 def check_parity_bf16_mp(
