This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

add PrepareFloat8ModuleInput for sequence parallel #275

Closed
wants to merge 6 commits
Changes from 2 commits
98 changes: 97 additions & 1 deletion float8_experimental/float8_tensor_parallel.py
@@ -1,11 +1,16 @@
import torch
import torch.nn as nn
from float8_experimental.float8_dynamic_linear import (
cast_to_float8_e4m3fn,
cast_to_float8_e5m2_bw,
)
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import (
ColwiseParallel,
PrepareModuleInput,
RowwiseParallel,
)

# subclass the ColwiseParallel and RowwiseParallel classes
# to add the float8 support
@@ -109,3 +114,94 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
)

return super()._apply(module, device_mesh)


class PrepareFloat8ModuleInput(PrepareModuleInput):
Contributor:

nit: maybe we can have e4m3 in the name, and maybe add a TODO to support the AMD version of e4m3 eventually?

Contributor:

maybe a quick docblock to explain that this is ensuring the float8 cast happens before the all-gather if there are multiple float8 users of the input activation?

Contributor Author:

> maybe we can have e4m3 in the name, and maybe add a TODO to support the AMD version of e4m3 eventually

I wonder what your thoughts are on these two choices: 1. make e4m3 appear in the name of this class, or 2. make this class constructor take an additional fp8 dtype argument, i.e. float8_dtype=torch.float8_e4m3fn, defaulting to e4m3fn, so that later we can add the AMD version of e4m3 by passing a different float8_dtype arg?

Contributor:

> make this class constructor take an additional fp8 dtype argument

sgtm
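
Putting the docblock and fp8-dtype suggestions together, the class header could end up looking roughly like the sketch below. This is hypothetical, assembled from the thread above rather than taken from the PR: the revision in this diff hardcodes cast_to_float8_e4m3fn and has no float8_dtype argument, and the name PrepareFloat8ModuleInputSketch is made up here to keep it distinct from the real class.

import torch
from torch.distributed.tensor.parallel import PrepareModuleInput


class PrepareFloat8ModuleInputSketch(PrepareModuleInput):
    """Like PrepareModuleInput, but after preparing each input as a DTensor, cast it
    to a DTensor wrapping a Float8Tensor before any redistribution, so the float8
    cast happens once, ahead of the all-gather in sequence parallel, even when
    multiple float8 linears consume the same input activation.
    """

    def __init__(self, *, float8_dtype: torch.dtype = torch.float8_e4m3fn, **kwargs):
        # Hypothetical knob discussed above: default to e4m3fn, leaving room for the
        # AMD e4m3 variant (torch.float8_e4m3fnuz) to be passed in later.
        super().__init__(**kwargs)
        self.float8_dtype = float8_dtype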

# subclass the PrepareModuleInput class; the only difference is that after we prepare
# the input DTensor, we cast it to DTensor(Float8Tensor)
def _prepare_input_fn(self, inputs, device_mesh):
if self.input_layouts is None:
return inputs
prepared_inputs = []
if not isinstance(inputs, tuple):
inputs = (inputs,)
if len(inputs) != len(self.input_layouts):
raise ValueError("module inputs and input_layouts should have same length!")

assert (
self.desired_input_layouts is not None
), "desired module inputs should not be None!"
for inp, input_layout, desired_layout in zip(
inputs, self.input_layouts, self.desired_input_layouts
):
if input_layout is not None:
if isinstance(inp, DTensor):
# TODO: re-enable the check once we fix the compile path
# assert inp.placements[0] == input_layout
dt_inp = inp
else:
dt_inp = DTensor.from_local(
inp, device_mesh, (input_layout,), run_check=False
)

dt_inp = cast_to_float8_e4m3fn(
dt_inp, self.fwd_linear_config
) # DTensor(Float8Tensor)
if desired_layout is not None and input_layout != desired_layout:
# i.e. Shard -> Replicate: allgather
dt_inp = dt_inp.redistribute(placements=(desired_layout,))
prepared_inputs.append(
dt_inp.to_local() if self.use_local_output else dt_inp
)
else:
prepared_inputs.append(inp)
return tuple(prepared_inputs)

def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh):
prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh)
prepared_kwarg_inputs = {}
for kwarg_key in kwarg_inputs.keys():
kwarg_val = kwarg_inputs[kwarg_key]
input_layout = None
if kwarg_key in self.input_kwarg_layouts:
input_layout = self.input_kwarg_layouts[kwarg_key]
assert isinstance(
kwarg_val, torch.Tensor
), f"input of key {kwarg_key} to the module should be a Tensor!"
kwarg_val = DTensor.from_local(
kwarg_val, device_mesh, (input_layout,), run_check=False
)

kwarg_val = cast_to_float8_e4m3fn(
kwarg_val, self.fwd_linear_config
) # DTensor(Float8Tensor)
if kwarg_key in self.desired_input_kwarg_layouts:
desired_layout = self.desired_input_kwarg_layouts[kwarg_key]
if desired_layout != input_layout:
kwarg_val = kwarg_val.redistribute(placements=(desired_layout,))

prepared_kwarg_inputs[kwarg_key] = (
kwarg_val.to_local() if self.use_local_output else kwarg_val
)
else:
prepared_kwarg_inputs[kwarg_key] = kwarg_val

return (prepared_arg_inputs, prepared_kwarg_inputs)

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

# search for ScaledMM configs for all the submodules and make sure they are the same
fwd_linear_config = None
for mod in module.modules():
Contributor:

WDYT something like the following to avoid the logic below?

  1. PrepareFloat8ModuleInput takes a ScaledMMConfig constructor argument
  2. Float8DynamicLinear has logic where if the input is already a Float8Tensor, there is a check to verify the config matches

Contributor Author:

I thought about this option too. The concern I have is that it would make the API diverge from the TP API offered in core, making the switch between fp8 and bf16 harder.

Also, I think the user would need to know how to construct the ScaledMMConfig, which basically makes ScaledMMConfig a public-facing API. I wasn't sure that is something we want?

Contributor:

> this basically makes ScaledMMConfig a public-facing API

Yeah, good point, that's not intended to be a user-facing thing. How about something like requiring the name of a module to get the config from?

I think the user API of the current code is great (no extra args), but the restriction that all the configs in the module need to be the same is not ideal. If we are ok with changing that later, the current API sgtm.

Contributor Author:

I think this makes sense! Let me draft up the changes to accept the module_fqn to get the ScaledMM config from. My current thinking on how we could approach this:

  1. We add a fwd_config_module_fqn arg to the constructor so that the user can specify which module's config to take.
  2. This arg could be optional: if the user doesn't pass it in, we still do the search and require that all configs in this specific module are the same.

(A rough sketch of this idea appears after this file's diff below.)

if isinstance(mod, Float8DynamicLinear):
if fwd_linear_config is None:
fwd_linear_config = mod.forward_config
else:
assert (
fwd_linear_config == mod.forward_config
), "All the Float8DynamicLinear modules should have same forward config!"

self.fwd_linear_config = fwd_linear_config
super()._apply(module, device_mesh)
return module
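
For reference, the optional fwd_config_module_fqn idea from the review thread above could look roughly like the helper below. This is a sketch of the proposal, not code from this PR: the helper name _pick_fwd_linear_config and the fwd_config_module_fqn parameter are hypothetical, while Float8DynamicLinear.forward_config is the same attribute the diff above already reads.

from typing import Optional

import torch.nn as nn


def _pick_fwd_linear_config(
    module: nn.Module, fwd_config_module_fqn: Optional[str] = None
):
    from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

    if fwd_config_module_fqn is not None:
        # Take the ScaledMM forward config from the user-specified submodule.
        submod = module.get_submodule(fwd_config_module_fqn)
        assert isinstance(
            submod, Float8DynamicLinear
        ), f"{fwd_config_module_fqn} is not a Float8DynamicLinear!"
        return submod.forward_config

    # Fallback (current behavior): scan all Float8DynamicLinear submodules and
    # require that they share a single forward config.
    fwd_linear_config = None
    for mod in module.modules():
        if isinstance(mod, Float8DynamicLinear):
            if fwd_linear_config is None:
                fwd_linear_config = mod.forward_config
            else:
                assert (
                    fwd_linear_config == mod.forward_config
                ), "All the Float8DynamicLinear modules should have the same forward config!"
    return fwd_linear_config

PrepareFloat8ModuleInput._apply could then call such a helper with an FQN stored by the constructor instead of inlining the scan.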
44 changes: 29 additions & 15 deletions test/test_dtensor.py
@@ -12,6 +12,7 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from float8_experimental.float8_dynamic_linear import (
Float8DynamicLinear,
@@ -22,6 +23,7 @@
from float8_experimental.float8_tensor_parallel import (
Float8ColwiseParallel,
Float8RowwiseParallel,
PrepareFloat8ModuleInput,
)
from float8_experimental.float8_utils import tensor_to_scale
from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
@@ -38,17 +40,26 @@ def setup_distributed():
return device_mesh


class ToyModel(nn.Module):
class FeedForward(nn.Module):
"""MLP based model"""

def __init__(self):
super(FeedForward, self).__init__()
self.w1 = nn.Linear(16, 32, bias=False)
self.w2 = nn.Linear(16, 32, bias=False)
self.out_proj = nn.Linear(32, 16, bias=False)

def forward(self, x):
return self.out_proj(F.silu(self.w1(x)) * self.w2(x))


class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.in_proj = nn.Linear(16, 32)
self.relu = nn.ReLU()
self.out_proj = nn.Linear(32, 16)
self.ffn = FeedForward()

def forward(self, x):
return self.out_proj(self.relu(self.in_proj(x)))
return self.ffn(x)


def test_scaled_mm(mesh: DeviceMesh, size=16):
@@ -182,8 +193,9 @@ def test_fp8_mlp_tensor_parallelism_base(
tp_model,
mesh,
{
"in_proj": Float8ColwiseParallel(),
"out_proj": Float8RowwiseParallel(),
"ffn.w1": Float8ColwiseParallel(),
"ffn.w2": Float8ColwiseParallel(),
"ffn.out_proj": Float8RowwiseParallel(),
},
)

@@ -192,17 +204,21 @@
sp_model,
mesh,
{
"in_proj": Float8ColwiseParallel(input_layouts=Shard(0)),
"out_proj": Float8RowwiseParallel(
output_layouts=Shard(0), use_local_output=False
"ffn": PrepareFloat8ModuleInput(
input_layouts=Shard(1), desired_input_layouts=Replicate()
),
"ffn.w1": Float8ColwiseParallel(),
"ffn.w2": Float8ColwiseParallel(),
"ffn.out_proj": Float8RowwiseParallel(
output_layouts=Shard(1), use_local_output=False
),
},
)

if compile:
tp_model = torch.compile(tp_model)

x_fp32 = torch.rand(size * 2, size, device=device, requires_grad=False)
x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
x_fp32_tp_input = x_fp32.clone()
x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])

@@ -214,11 +230,9 @@
global_out.sum().backward()
torch.testing.assert_close(tp_out, global_out)
torch.testing.assert_close(sp_out.full_tensor(), global_out)
torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
torch.testing.assert_close(
tp_model.in_proj.weight.grad, sp_model.in_proj.weight.grad
)
torch.testing.assert_close(
tp_model.out_proj.weight.grad, sp_model.out_proj.weight.grad
tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad
)

