From d015c2e1ed489ae763eb90a9174d93cffd8eb869 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Sun, 9 Jun 2024 10:57:11 -0700 Subject: [PATCH 1/6] add PrepareFloat8ModuleInput for sequence parallel when applying Sequence Parallel to a module with more than 2 linear layers for input proj, we often want to transform from Shard to Replicate once (allgather once) and then reuse the allgathered result, for fp8 we would need to do the casting before the shard -> replicate so that we can perform the fp8 allgather. This PR subclasses the PrepareModuleInput to add the fp8 casting logic to make sure we run the fp8 allgather instead of bf16 allgather then do the casting for computation. Also adjust the test cases to test the real ffn case for sequence parallel --- float8_experimental/float8_tensor_parallel.py | 76 ++++++++++++++++++- test/test_dtensor.py | 39 ++++++---- 2 files changed, 101 insertions(+), 14 deletions(-) diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index dd63165..32694b7 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -5,7 +5,7 @@ ) from torch.distributed._tensor import DTensor from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel +from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, PrepareModuleInput # subclass the ColwiseParallel and RowwiseParallel classes # to add the float8 support @@ -109,3 +109,77 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ) return super()._apply(module, device_mesh) + + +class PrepareFloat8ModuleInput(PrepareModuleInput): + # subclass the PrepareModuleInput classes, the only difference is that after we prepare + # the input DTensor, we cast the input to DTensor(Float8Tensor) + def _prepare_input_fn(self, inputs, device_mesh): + if self.input_layouts is None: + return inputs + prepared_inputs = [] + if not isinstance(inputs, tuple): + inputs = (inputs,) + if len(inputs) != len(self.input_layouts): + raise ValueError("module inputs and input_layouts should have same length!") + + assert self.desired_input_layouts is not None, "desired module inputs should not be None!" + for inp, input_layout, desired_layout in zip(inputs, self.input_layouts, self.desired_input_layouts): + if input_layout is not None: + if isinstance(inp, DTensor): + # TODO: re-enable the check once we fix the compile path + # assert inp.placements[0] == input_layout + dt_inp = inp + else: + dt_inp = DTensor.from_local(inp, device_mesh, (input_layout,), run_check=False) + + dt_inp = cast_to_float8_e4m3fn( + dt_inp, self.fwd_linear_config + ) # DTensor(Float8Tensor) + if desired_layout is not None and input_layout != desired_layout: + # i.e. Shard -> Replicate: allgather + dt_inp = dt_inp.redistribute(placements=(desired_layout,)) + prepared_inputs.append(dt_inp.to_local() if self.use_local_output else dt_inp) + else: + prepared_inputs.append(inp) + return tuple(prepared_inputs) + + def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): + prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh) + prepared_kwarg_inputs = {} + for kwarg_key in kwarg_inputs.keys(): + kwarg_val = kwarg_inputs[kwarg_key] + input_layout = None + if kwarg_key in self.input_kwarg_layouts: + input_layout = self.input_kwarg_layouts[kwarg_key] + assert isinstance(kwarg_val, torch.Tensor), f"input of key {kwarg_key} to the module should be a Tensor!" + kwarg_val = DTensor.from_local(kwarg_val, device_mesh, (input_layout,), run_check=False) + + kwarg_val = cast_to_float8_e4m3fn( + kwarg_val, self.fwd_linear_config + ) # DTensor(Float8Tensor) + if kwarg_key in self.desired_input_kwarg_layouts: + desired_layout = self.desired_input_kwarg_layouts[kwarg_key] + if desired_layout != input_layout: + kwarg_val = kwarg_val.redistribute(placements=(desired_layout,)) + + prepared_kwarg_inputs[kwarg_key] = kwarg_val.to_local() if self.use_local_output else kwarg_val + else: + prepared_kwarg_inputs[kwarg_key] = kwarg_val + + return (prepared_arg_inputs, prepared_kwarg_inputs) + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + from float8_experimental.float8_dynamic_linear import Float8DynamicLinear + # search for ScaledMM configs for all the submodules and make sure they are the same + fwd_linear_config = None + for mod in module.modules(): + if isinstance(mod, Float8DynamicLinear): + if fwd_linear_config is None: + fwd_linear_config = mod.forward_config + else: + assert fwd_linear_config == mod.forward_config, "All the Float8DynamicLinear modules should have same forward config!" + + self.fwd_linear_config = fwd_linear_config + super()._apply(module, device_mesh) + return module diff --git a/test/test_dtensor.py b/test/test_dtensor.py index bb8d3db..5401b33 100644 --- a/test/test_dtensor.py +++ b/test/test_dtensor.py @@ -12,6 +12,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from float8_experimental.float8_dynamic_linear import ( Float8DynamicLinear, @@ -22,6 +23,7 @@ from float8_experimental.float8_tensor_parallel import ( Float8ColwiseParallel, Float8RowwiseParallel, + PrepareFloat8ModuleInput ) from float8_experimental.float8_utils import tensor_to_scale from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard @@ -38,17 +40,25 @@ def setup_distributed(): return device_mesh -class ToyModel(nn.Module): +class FeedForward(nn.Module): """MLP based model""" + def __init__(self): + super(FeedForward, self).__init__() + self.w1 = nn.Linear(16, 32, bias=False) + self.w2 = nn.Linear(16, 32, bias=False) + self.out_proj = nn.Linear(32, 16, bias=False) + + def forward(self, x): + return self.out_proj(F.silu(self.w1(x)) * self.w2(x)) + +class ToyModel(nn.Module): def __init__(self): super(ToyModel, self).__init__() - self.in_proj = nn.Linear(16, 32) - self.relu = nn.ReLU() - self.out_proj = nn.Linear(32, 16) + self.ffn = FeedForward() def forward(self, x): - return self.out_proj(self.relu(self.in_proj(x))) + return self.ffn(x) def test_scaled_mm(mesh: DeviceMesh, size=16): @@ -182,8 +192,9 @@ def test_fp8_mlp_tensor_parallelism_base( tp_model, mesh, { - "in_proj": Float8ColwiseParallel(), - "out_proj": Float8RowwiseParallel(), + "ffn.w1": Float8ColwiseParallel(), + "ffn.w2": Float8ColwiseParallel(), + "ffn.out_proj": Float8RowwiseParallel(), }, ) @@ -192,9 +203,11 @@ def test_fp8_mlp_tensor_parallelism_base( sp_model, mesh, { - "in_proj": Float8ColwiseParallel(input_layouts=Shard(0)), - "out_proj": Float8RowwiseParallel( - output_layouts=Shard(0), use_local_output=False + "ffn": PrepareFloat8ModuleInput(input_layouts=Shard(1), desired_input_layouts=Replicate()), + "ffn.w1": Float8ColwiseParallel(), + "ffn.w2": Float8ColwiseParallel(), + "ffn.out_proj": Float8RowwiseParallel( + output_layouts=Shard(1), use_local_output=False ), }, ) @@ -202,7 +215,7 @@ def test_fp8_mlp_tensor_parallelism_base( if compile: tp_model = torch.compile(tp_model) - x_fp32 = torch.rand(size * 2, size, device=device, requires_grad=False) + x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False) x_fp32_tp_input = x_fp32.clone() x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)]) @@ -215,10 +228,10 @@ def test_fp8_mlp_tensor_parallelism_base( torch.testing.assert_close(tp_out, global_out) torch.testing.assert_close(sp_out.full_tensor(), global_out) torch.testing.assert_close( - tp_model.in_proj.weight.grad, sp_model.in_proj.weight.grad + tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad ) torch.testing.assert_close( - tp_model.out_proj.weight.grad, sp_model.out_proj.weight.grad + tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad ) From 3edc3ec45e9acfca626ead0d2c7b5912007ea8b4 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Sun, 9 Jun 2024 12:52:47 -0700 Subject: [PATCH 2/6] lint fixes --- float8_experimental/float8_tensor_parallel.py | 40 ++++++++++++++----- test/test_dtensor.py | 11 ++--- 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index 32694b7..7d45d8c 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn from float8_experimental.float8_dynamic_linear import ( cast_to_float8_e4m3fn, @@ -5,7 +6,11 @@ ) from torch.distributed._tensor import DTensor from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, PrepareModuleInput +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, +) # subclass the ColwiseParallel and RowwiseParallel classes # to add the float8 support @@ -123,15 +128,21 @@ def _prepare_input_fn(self, inputs, device_mesh): if len(inputs) != len(self.input_layouts): raise ValueError("module inputs and input_layouts should have same length!") - assert self.desired_input_layouts is not None, "desired module inputs should not be None!" - for inp, input_layout, desired_layout in zip(inputs, self.input_layouts, self.desired_input_layouts): + assert ( + self.desired_input_layouts is not None + ), "desired module inputs should not be None!" + for inp, input_layout, desired_layout in zip( + inputs, self.input_layouts, self.desired_input_layouts + ): if input_layout is not None: if isinstance(inp, DTensor): # TODO: re-enable the check once we fix the compile path # assert inp.placements[0] == input_layout dt_inp = inp else: - dt_inp = DTensor.from_local(inp, device_mesh, (input_layout,), run_check=False) + dt_inp = DTensor.from_local( + inp, device_mesh, (input_layout,), run_check=False + ) dt_inp = cast_to_float8_e4m3fn( dt_inp, self.fwd_linear_config @@ -139,7 +150,9 @@ def _prepare_input_fn(self, inputs, device_mesh): if desired_layout is not None and input_layout != desired_layout: # i.e. Shard -> Replicate: allgather dt_inp = dt_inp.redistribute(placements=(desired_layout,)) - prepared_inputs.append(dt_inp.to_local() if self.use_local_output else dt_inp) + prepared_inputs.append( + dt_inp.to_local() if self.use_local_output else dt_inp + ) else: prepared_inputs.append(inp) return tuple(prepared_inputs) @@ -152,8 +165,12 @@ def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): input_layout = None if kwarg_key in self.input_kwarg_layouts: input_layout = self.input_kwarg_layouts[kwarg_key] - assert isinstance(kwarg_val, torch.Tensor), f"input of key {kwarg_key} to the module should be a Tensor!" - kwarg_val = DTensor.from_local(kwarg_val, device_mesh, (input_layout,), run_check=False) + assert isinstance( + kwarg_val, torch.Tensor + ), f"input of key {kwarg_key} to the module should be a Tensor!" + kwarg_val = DTensor.from_local( + kwarg_val, device_mesh, (input_layout,), run_check=False + ) kwarg_val = cast_to_float8_e4m3fn( kwarg_val, self.fwd_linear_config @@ -163,7 +180,9 @@ def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): if desired_layout != input_layout: kwarg_val = kwarg_val.redistribute(placements=(desired_layout,)) - prepared_kwarg_inputs[kwarg_key] = kwarg_val.to_local() if self.use_local_output else kwarg_val + prepared_kwarg_inputs[kwarg_key] = ( + kwarg_val.to_local() if self.use_local_output else kwarg_val + ) else: prepared_kwarg_inputs[kwarg_key] = kwarg_val @@ -171,6 +190,7 @@ def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: from float8_experimental.float8_dynamic_linear import Float8DynamicLinear + # search for ScaledMM configs for all the submodules and make sure they are the same fwd_linear_config = None for mod in module.modules(): @@ -178,7 +198,9 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: if fwd_linear_config is None: fwd_linear_config = mod.forward_config else: - assert fwd_linear_config == mod.forward_config, "All the Float8DynamicLinear modules should have same forward config!" + assert ( + fwd_linear_config == mod.forward_config + ), "All the Float8DynamicLinear modules should have same forward config!" self.fwd_linear_config = fwd_linear_config super()._apply(module, device_mesh) diff --git a/test/test_dtensor.py b/test/test_dtensor.py index 5401b33..5064c09 100644 --- a/test/test_dtensor.py +++ b/test/test_dtensor.py @@ -23,7 +23,7 @@ from float8_experimental.float8_tensor_parallel import ( Float8ColwiseParallel, Float8RowwiseParallel, - PrepareFloat8ModuleInput + PrepareFloat8ModuleInput, ) from float8_experimental.float8_utils import tensor_to_scale from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard @@ -52,6 +52,7 @@ def __init__(self): def forward(self, x): return self.out_proj(F.silu(self.w1(x)) * self.w2(x)) + class ToyModel(nn.Module): def __init__(self): super(ToyModel, self).__init__() @@ -203,7 +204,9 @@ def test_fp8_mlp_tensor_parallelism_base( sp_model, mesh, { - "ffn": PrepareFloat8ModuleInput(input_layouts=Shard(1), desired_input_layouts=Replicate()), + "ffn": PrepareFloat8ModuleInput( + input_layouts=Shard(1), desired_input_layouts=Replicate() + ), "ffn.w1": Float8ColwiseParallel(), "ffn.w2": Float8ColwiseParallel(), "ffn.out_proj": Float8RowwiseParallel( @@ -227,9 +230,7 @@ def test_fp8_mlp_tensor_parallelism_base( global_out.sum().backward() torch.testing.assert_close(tp_out, global_out) torch.testing.assert_close(sp_out.full_tensor(), global_out) - torch.testing.assert_close( - tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad - ) + torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad) torch.testing.assert_close( tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad ) From 757be5283cf63b65f96f234a8a8798f8f58968af Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Tue, 11 Jun 2024 22:39:10 -0700 Subject: [PATCH 3/6] address comments from vkuzo --- float8_experimental/float8_tensor_parallel.py | 136 ++++++++---------- 1 file changed, 61 insertions(+), 75 deletions(-) diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index 7d45d8c..46e786b 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -119,88 +119,74 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: class PrepareFloat8ModuleInput(PrepareModuleInput): # subclass the PrepareModuleInput classes, the only difference is that after we prepare # the input DTensor, we cast the input to DTensor(Float8Tensor) - def _prepare_input_fn(self, inputs, device_mesh): - if self.input_layouts is None: - return inputs - prepared_inputs = [] - if not isinstance(inputs, tuple): - inputs = (inputs,) - if len(inputs) != len(self.input_layouts): - raise ValueError("module inputs and input_layouts should have same length!") - - assert ( - self.desired_input_layouts is not None - ), "desired module inputs should not be None!" - for inp, input_layout, desired_layout in zip( - inputs, self.input_layouts, self.desired_input_layouts - ): - if input_layout is not None: - if isinstance(inp, DTensor): - # TODO: re-enable the check once we fix the compile path - # assert inp.placements[0] == input_layout - dt_inp = inp - else: - dt_inp = DTensor.from_local( - inp, device_mesh, (input_layout,), run_check=False - ) - - dt_inp = cast_to_float8_e4m3fn( - dt_inp, self.fwd_linear_config - ) # DTensor(Float8Tensor) - if desired_layout is not None and input_layout != desired_layout: - # i.e. Shard -> Replicate: allgather - dt_inp = dt_inp.redistribute(placements=(desired_layout,)) - prepared_inputs.append( - dt_inp.to_local() if self.use_local_output else dt_inp - ) - else: - prepared_inputs.append(inp) - return tuple(prepared_inputs) - - def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): - prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh) - prepared_kwarg_inputs = {} - for kwarg_key in kwarg_inputs.keys(): - kwarg_val = kwarg_inputs[kwarg_key] - input_layout = None - if kwarg_key in self.input_kwarg_layouts: - input_layout = self.input_kwarg_layouts[kwarg_key] - assert isinstance( - kwarg_val, torch.Tensor - ), f"input of key {kwarg_key} to the module should be a Tensor!" - kwarg_val = DTensor.from_local( - kwarg_val, device_mesh, (input_layout,), run_check=False - ) - - kwarg_val = cast_to_float8_e4m3fn( - kwarg_val, self.fwd_linear_config - ) # DTensor(Float8Tensor) - if kwarg_key in self.desired_input_kwarg_layouts: - desired_layout = self.desired_input_kwarg_layouts[kwarg_key] - if desired_layout != input_layout: - kwarg_val = kwarg_val.redistribute(placements=(desired_layout,)) - - prepared_kwarg_inputs[kwarg_key] = ( - kwarg_val.to_local() if self.use_local_output else kwarg_val - ) + # This is to ensure the float8 cast happens before the all-gather (i.e. Shard -> Replicate) + # so that if there are multiple float8 users of the input activation, we perform fp8 allgather + # only once. + + def __init__( + self, + *, + input_layouts = None, + desired_input_layouts = None, + input_kwarg_layouts = None, + desired_input_kwarg_layouts = None, + use_local_output = False, + float8_dtype = torch.float8_e4m3fn, + fwd_config_submodule_fqn = None, + ): + super().__init__( + input_layouts=input_layouts, + desired_input_layouts=desired_input_layouts, + input_kwarg_layouts=input_kwarg_layouts, + desired_input_kwarg_layouts=desired_input_kwarg_layouts, + use_local_output=use_local_output, + ) + + # fp8 specific fields + self.float8_dtype = float8_dtype + self.fwd_config_submodule_fqn = fwd_config_submodule_fqn + + if self.float8_dtype != torch.float8_e4m3fn: + raise NotImplementedError("PrepareFloat8ModuleInput only support casting to float8_e4m3fn for now") + + def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): + if input_layout is not None: + if isinstance(input, DTensor): + # TODO: re-enable the check once we fix the compile path + # assert inp.placements[0] == input_layout + dt_inp = input else: - prepared_kwarg_inputs[kwarg_key] = kwarg_val + assert isinstance(input, torch.Tensor), "expecting input to be a torch.Tensor!" + dt_inp = DTensor.from_local(input, mesh, (input_layout,), run_check=False) + + dt_inp = cast_to_float8_e4m3fn( + dt_inp, self.fwd_linear_config + ) # DTensor(Float8Tensor) + if desired_layout is not None and input_layout != desired_layout: + dt_inp = dt_inp.redistribute(placements=(desired_layout,)) - return (prepared_arg_inputs, prepared_kwarg_inputs) + return dt_inp.to_local() if self.use_local_output else dt_inp + else: + return input def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: from float8_experimental.float8_dynamic_linear import Float8DynamicLinear - # search for ScaledMM configs for all the submodules and make sure they are the same fwd_linear_config = None - for mod in module.modules(): - if isinstance(mod, Float8DynamicLinear): - if fwd_linear_config is None: - fwd_linear_config = mod.forward_config - else: - assert ( - fwd_linear_config == mod.forward_config - ), "All the Float8DynamicLinear modules should have same forward config!" + if self.fwd_config_submodule_fqn is not None: + fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn) + assert isinstance(fwd_linear, Float8DynamicLinear) + fwd_linear_config = fwd_linear.forward_config + else: + # search for ScaledMM configs for all the submodules and make sure they are the same + for mod in module.modules(): + if isinstance(mod, Float8DynamicLinear): + if fwd_linear_config is None: + fwd_linear_config = mod.forward_config + else: + assert ( + fwd_linear_config == mod.forward_config + ), "All the Float8DynamicLinear modules should have same forward config!" self.fwd_linear_config = fwd_linear_config super()._apply(module, device_mesh) From 77c2353b9ddf28c186f7607af7e5de899c058cf7 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Tue, 11 Jun 2024 22:41:48 -0700 Subject: [PATCH 4/6] lint --- float8_experimental/float8_tensor_parallel.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index 46e786b..ea671c9 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -126,13 +126,13 @@ class PrepareFloat8ModuleInput(PrepareModuleInput): def __init__( self, *, - input_layouts = None, - desired_input_layouts = None, - input_kwarg_layouts = None, - desired_input_kwarg_layouts = None, - use_local_output = False, - float8_dtype = torch.float8_e4m3fn, - fwd_config_submodule_fqn = None, + input_layouts=None, + desired_input_layouts=None, + input_kwarg_layouts=None, + desired_input_kwarg_layouts=None, + use_local_output=False, + float8_dtype=torch.float8_e4m3fn, + fwd_config_submodule_fqn=None, ): super().__init__( input_layouts=input_layouts, @@ -147,7 +147,9 @@ def __init__( self.fwd_config_submodule_fqn = fwd_config_submodule_fqn if self.float8_dtype != torch.float8_e4m3fn: - raise NotImplementedError("PrepareFloat8ModuleInput only support casting to float8_e4m3fn for now") + raise NotImplementedError( + "PrepareFloat8ModuleInput only support casting to float8_e4m3fn for now" + ) def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): if input_layout is not None: @@ -156,8 +158,12 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): # assert inp.placements[0] == input_layout dt_inp = input else: - assert isinstance(input, torch.Tensor), "expecting input to be a torch.Tensor!" - dt_inp = DTensor.from_local(input, mesh, (input_layout,), run_check=False) + assert isinstance( + input, torch.Tensor + ), "expecting input to be a torch.Tensor!" + dt_inp = DTensor.from_local( + input, mesh, (input_layout,), run_check=False + ) dt_inp = cast_to_float8_e4m3fn( dt_inp, self.fwd_linear_config From 6d04ad91cfeb2e4d742d81cfcec5c99aab0d26c0 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 12 Jun 2024 12:39:07 -0700 Subject: [PATCH 5/6] add more docs/comment and tests --- float8_experimental/float8_tensor_parallel.py | 11 +++++-- test/test_dtensor.py | 33 +++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index ea671c9..48cdc8b 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -117,11 +117,18 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: class PrepareFloat8ModuleInput(PrepareModuleInput): - # subclass the PrepareModuleInput classes, the only difference is that after we prepare - # the input DTensor, we cast the input to DTensor(Float8Tensor) + # subclass the PrepareModuleInput classes to implement fp8 specific logic, the only difference is that + # after we prepare the input DTensor, we cast the input to DTensor(Float8Tensor) # This is to ensure the float8 cast happens before the all-gather (i.e. Shard -> Replicate) # so that if there are multiple float8 users of the input activation, we perform fp8 allgather # only once. + # FP8 Args: + # float8_dtype (torch.dtype, optional): control what float8 dtype to cast to when prepare the module input, + # we currently only support torch.float8_e4m3fn. default: torch.float8_e4m3fn + # fwd_config_submodule_fqn (str, optional): the fqn of the submodule that contains the forward config used + # for the float8 cast. If not specified, we will search for the Float8DynamicLinear in the submodules + # and use the forward config from that module, in this case all module's forward config must be + # the same. def __init__( self, diff --git a/test/test_dtensor.py b/test/test_dtensor.py index 5064c09..a0758c9 100644 --- a/test/test_dtensor.py +++ b/test/test_dtensor.py @@ -215,8 +215,34 @@ def test_fp8_mlp_tensor_parallelism_base( }, ) + # PrepareFloat8ModuleInput with specific submodule fqn + sp_model2 = copy.deepcopy(toy_model) + sp_model2 = swap_linear_with_float8_linear( + sp_model2, Float8DynamicLinear, emulate=True + ) + + sp_model2 = parallelize_module( + sp_model2, + mesh, + { + "ffn": PrepareFloat8ModuleInput( + input_layouts=Shard(1), + desired_input_layouts=Replicate(), + fwd_config_submodule_fqn="w2", + ), + "ffn.w1": Float8ColwiseParallel(), + "ffn.w2": Float8ColwiseParallel(), + "ffn.out_proj": Float8RowwiseParallel( + output_layouts=Shard(1), use_local_output=False + ), + }, + ) + + if compile: tp_model = torch.compile(tp_model) + sp_model = torch.compile(sp_model) + sp_model2 = torch.compile(sp_model2) x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False) x_fp32_tp_input = x_fp32.clone() @@ -235,6 +261,13 @@ def test_fp8_mlp_tensor_parallelism_base( tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad ) + sp_out2 = sp_model2(x_fp32_sp_input) + sp_out2.sum().backward() + torch.testing.assert_close(sp_out2.full_tensor(), global_out) + torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad) + torch.testing.assert_close( + tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad + ) def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True) From edec7776309713c867826f693153de6f2a0fbeea Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 12 Jun 2024 12:40:07 -0700 Subject: [PATCH 6/6] lint --- test/test_dtensor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_dtensor.py b/test/test_dtensor.py index a0758c9..e319608 100644 --- a/test/test_dtensor.py +++ b/test/test_dtensor.py @@ -238,7 +238,6 @@ def test_fp8_mlp_tensor_parallelism_base( }, ) - if compile: tp_model = torch.compile(tp_model) sp_model = torch.compile(sp_model) @@ -264,11 +263,14 @@ def test_fp8_mlp_tensor_parallelism_base( sp_out2 = sp_model2(x_fp32_sp_input) sp_out2.sum().backward() torch.testing.assert_close(sp_out2.full_tensor(), global_out) - torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad) + torch.testing.assert_close( + tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad + ) torch.testing.assert_close( tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad ) + def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)