diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py
index dd63165..48cdc8b 100644
--- a/float8_experimental/float8_tensor_parallel.py
+++ b/float8_experimental/float8_tensor_parallel.py
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 from float8_experimental.float8_dynamic_linear import (
     cast_to_float8_e4m3fn,
@@ -5,7 +6,11 @@
 )
 from torch.distributed._tensor import DTensor
 from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    PrepareModuleInput,
+    RowwiseParallel,
+)
 
 # subclass the ColwiseParallel and RowwiseParallel classes
 # to add the float8 support
@@ -109,3 +114,93 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
             )
 
         return super()._apply(module, device_mesh)
+
+
+class PrepareFloat8ModuleInput(PrepareModuleInput):
+    # Subclass of PrepareModuleInput that adds fp8-specific logic. The only difference is that
+    # after preparing the input DTensor, we cast it to DTensor(Float8Tensor).
+    # This ensures the float8 cast happens before the all-gather (i.e. Shard -> Replicate),
+    # so that if there are multiple float8 users of the input activation, we perform the fp8
+    # all-gather only once.
+    # FP8 Args:
+    #   float8_dtype (torch.dtype, optional): controls which float8 dtype to cast to when preparing
+    #       the module input; we currently only support torch.float8_e4m3fn. default: torch.float8_e4m3fn
+    #   fwd_config_submodule_fqn (str, optional): the FQN of the submodule that contains the forward
+    #       config used for the float8 cast. If not specified, we search the submodules for a
+    #       Float8DynamicLinear and use its forward config; in that case all modules' forward
+    #       configs must be the same.
+
+    def __init__(
+        self,
+        *,
+        input_layouts=None,
+        desired_input_layouts=None,
+        input_kwarg_layouts=None,
+        desired_input_kwarg_layouts=None,
+        use_local_output=False,
+        float8_dtype=torch.float8_e4m3fn,
+        fwd_config_submodule_fqn=None,
+    ):
+        super().__init__(
+            input_layouts=input_layouts,
+            desired_input_layouts=desired_input_layouts,
+            input_kwarg_layouts=input_kwarg_layouts,
+            desired_input_kwarg_layouts=desired_input_kwarg_layouts,
+            use_local_output=use_local_output,
+        )
+
+        # fp8 specific fields
+        self.float8_dtype = float8_dtype
+        self.fwd_config_submodule_fqn = fwd_config_submodule_fqn
+
+        if self.float8_dtype != torch.float8_e4m3fn:
+            raise NotImplementedError(
+                "PrepareFloat8ModuleInput only support casting to float8_e4m3fn for now"
+            )
+
+    def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
+        if input_layout is not None:
+            if isinstance(input, DTensor):
+                # TODO: re-enable the check once we fix the compile path
+                # assert inp.placements[0] == input_layout
+                dt_inp = input
+            else:
+                assert isinstance(
+                    input, torch.Tensor
+                ), "expecting input to be a torch.Tensor!"
+                dt_inp = DTensor.from_local(
+                    input, mesh, (input_layout,), run_check=False
+                )
+
+            dt_inp = cast_to_float8_e4m3fn(
+                dt_inp, self.fwd_linear_config
+            )  # DTensor(Float8Tensor)
+            if desired_layout is not None and input_layout != desired_layout:
+                dt_inp = dt_inp.redistribute(placements=(desired_layout,))
+
+            return dt_inp.to_local() if self.use_local_output else dt_inp
+        else:
+            return input
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+
+        fwd_linear_config = None
+        if self.fwd_config_submodule_fqn is not None:
+            fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn)
+            assert isinstance(fwd_linear, Float8DynamicLinear)
+            fwd_linear_config = fwd_linear.forward_config
+        else:
+            # search for ScaledMM configs for all the submodules and make sure they are the same
+            for mod in module.modules():
+                if isinstance(mod, Float8DynamicLinear):
+                    if fwd_linear_config is None:
+                        fwd_linear_config = mod.forward_config
+                    else:
+                        assert (
+                            fwd_linear_config == mod.forward_config
+                        ), "All the Float8DynamicLinear modules should have same forward config!"
+
+        self.fwd_linear_config = fwd_linear_config
+        super()._apply(module, device_mesh)
+        return module
diff --git a/test/test_dtensor.py b/test/test_dtensor.py
index bb8d3db..e319608 100644
--- a/test/test_dtensor.py
+++ b/test/test_dtensor.py
@@ -12,6 +12,7 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from float8_experimental.float8_dynamic_linear import (
     Float8DynamicLinear,
@@ -22,6 +23,7 @@ from float8_experimental.float8_tensor_parallel import (
     Float8ColwiseParallel,
     Float8RowwiseParallel,
+    PrepareFloat8ModuleInput,
 )
 from float8_experimental.float8_utils import tensor_to_scale
 from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
@@ -38,17 +40,26 @@ def setup_distributed():
     return device_mesh
 
 
-class ToyModel(nn.Module):
+class FeedForward(nn.Module):
     """MLP based model"""
 
+    def __init__(self):
+        super(FeedForward, self).__init__()
+        self.w1 = nn.Linear(16, 32, bias=False)
+        self.w2 = nn.Linear(16, 32, bias=False)
+        self.out_proj = nn.Linear(32, 16, bias=False)
+
+    def forward(self, x):
+        return self.out_proj(F.silu(self.w1(x)) * self.w2(x))
+
+
+class ToyModel(nn.Module):
     def __init__(self):
         super(ToyModel, self).__init__()
-        self.in_proj = nn.Linear(16, 32)
-        self.relu = nn.ReLU()
-        self.out_proj = nn.Linear(32, 16)
+        self.ffn = FeedForward()
 
     def forward(self, x):
-        return self.out_proj(self.relu(self.in_proj(x)))
+        return self.ffn(x)
 
 
 def test_scaled_mm(mesh: DeviceMesh, size=16):
@@ -182,8 +193,9 @@ def test_fp8_mlp_tensor_parallelism_base(
         tp_model,
         mesh,
         {
-            "in_proj": Float8ColwiseParallel(),
-            "out_proj": Float8RowwiseParallel(),
+            "ffn.w1": Float8ColwiseParallel(),
+            "ffn.w2": Float8ColwiseParallel(),
+            "ffn.out_proj": Float8RowwiseParallel(),
         },
     )
 
@@ -192,17 +204,46 @@ def test_fp8_mlp_tensor_parallelism_base(
         sp_model,
         mesh,
         {
-            "in_proj": Float8ColwiseParallel(input_layouts=Shard(0)),
-            "out_proj": Float8RowwiseParallel(
-                output_layouts=Shard(0), use_local_output=False
+            "ffn": PrepareFloat8ModuleInput(
+                input_layouts=Shard(1), desired_input_layouts=Replicate()
+            ),
+            "ffn.w1": Float8ColwiseParallel(),
+            "ffn.w2": Float8ColwiseParallel(),
+            "ffn.out_proj": Float8RowwiseParallel(
+                output_layouts=Shard(1), use_local_output=False
+            ),
+        },
+    )
+
+    # PrepareFloat8ModuleInput with specific submodule fqn
+    sp_model2 = copy.deepcopy(toy_model)
+    sp_model2 = swap_linear_with_float8_linear(
+        sp_model2, Float8DynamicLinear, emulate=True
+    )
+
+    sp_model2 = parallelize_module(
+        sp_model2,
+        mesh,
+        {
+            "ffn": PrepareFloat8ModuleInput(
+                input_layouts=Shard(1),
+                desired_input_layouts=Replicate(),
+                fwd_config_submodule_fqn="w2",
+            ),
+            "ffn.w1": Float8ColwiseParallel(),
+            "ffn.w2": Float8ColwiseParallel(),
+            "ffn.out_proj": Float8RowwiseParallel(
+                output_layouts=Shard(1), use_local_output=False
+            ),
+        },
+    )
 
     if compile:
         tp_model = torch.compile(tp_model)
+        sp_model = torch.compile(sp_model)
+        sp_model2 = torch.compile(sp_model2)
 
-    x_fp32 = torch.rand(size * 2, size, device=device, requires_grad=False)
+    x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
     x_fp32_tp_input = x_fp32.clone()
     x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
 
@@ -214,11 +255,19 @@ def test_fp8_mlp_tensor_parallelism_base(
     global_out.sum().backward()
     torch.testing.assert_close(tp_out, global_out)
     torch.testing.assert_close(sp_out.full_tensor(), global_out)
+    torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
+    torch.testing.assert_close(
+        tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad
+    )
+
+    sp_out2 = sp_model2(x_fp32_sp_input)
+    sp_out2.sum().backward()
+    torch.testing.assert_close(sp_out2.full_tensor(), global_out)
     torch.testing.assert_close(
-        tp_model.in_proj.weight.grad, sp_model.in_proj.weight.grad
+        tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
     )
     torch.testing.assert_close(
-        tp_model.out_proj.weight.grad, sp_model.out_proj.weight.grad
+        tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad
     )
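For reference, a minimal usage sketch of the new PrepareFloat8ModuleInput in a sequence-parallel plan, mirroring the test above. The helper name apply_fp8_sequence_parallel and the assumption that the model has an "ffn" FeedForward submodule whose linears were already swapped to Float8DynamicLinear are illustrative, not part of this diff.

# Sketch (assumptions noted above): float8 cast is applied at the "ffn" boundary so the
# Shard(1) -> Replicate all-gather happens in fp8 and is shared by w1 and w2.
import torch
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import parallelize_module

from float8_experimental.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
    PrepareFloat8ModuleInput,
)


def apply_fp8_sequence_parallel(model: torch.nn.Module, mesh) -> torch.nn.Module:
    # "ffn" prepares (and fp8-casts) the sharded activation once; the colwise/rowwise
    # styles then consume the replicated DTensor(Float8Tensor) input.
    return parallelize_module(
        model,
        mesh,
        {
            "ffn": PrepareFloat8ModuleInput(
                input_layouts=Shard(1),
                desired_input_layouts=Replicate(),
            ),
            "ffn.w1": Float8ColwiseParallel(),
            "ffn.w2": Float8ColwiseParallel(),
            "ffn.out_proj": Float8RowwiseParallel(
                output_layouts=Shard(1), use_local_output=False
            ),
        },
    )

Passing fwd_config_submodule_fqn="w2" (as in the second test case) pins the forward config to one submodule instead of requiring all Float8DynamicLinear submodules under "ffn" to share the same config.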