add PrepareFloat8ModuleInput for sequence parallel #275
@@ -1,11 +1,16 @@
import torch
import torch.nn as nn
from float8_experimental.float8_dynamic_linear import (
    cast_to_float8_e4m3fn,
    cast_to_float8_e5m2_bw,
)
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
)

# subclass the ColwiseParallel and RowwiseParallel classes
# to add the float8 support
@@ -109,3 +114,94 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        )

        return super()._apply(module, device_mesh)


class PrepareFloat8ModuleInput(PrepareModuleInput):
    # subclass the PrepareModuleInput classes, the only difference is that after we prepare
    # the input DTensor, we cast the input to DTensor(Float8Tensor)
    def _prepare_input_fn(self, inputs, device_mesh):
        if self.input_layouts is None:
            return inputs
        prepared_inputs = []
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
        if len(inputs) != len(self.input_layouts):
            raise ValueError("module inputs and input_layouts should have same length!")

        assert (
            self.desired_input_layouts is not None
        ), "desired module inputs should not be None!"
        for inp, input_layout, desired_layout in zip(
            inputs, self.input_layouts, self.desired_input_layouts
        ):
            if input_layout is not None:
                if isinstance(inp, DTensor):
                    # TODO: re-enable the check once we fix the compile path
                    # assert inp.placements[0] == input_layout
                    dt_inp = inp
                else:
                    dt_inp = DTensor.from_local(
                        inp, device_mesh, (input_layout,), run_check=False
                    )

                dt_inp = cast_to_float8_e4m3fn(
                    dt_inp, self.fwd_linear_config
                )  # DTensor(Float8Tensor)
                if desired_layout is not None and input_layout != desired_layout:
                    # i.e. Shard -> Replicate: allgather
                    dt_inp = dt_inp.redistribute(placements=(desired_layout,))
                prepared_inputs.append(
                    dt_inp.to_local() if self.use_local_output else dt_inp
                )
            else:
                prepared_inputs.append(inp)
        return tuple(prepared_inputs)

    def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh):
        prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh)
        prepared_kwarg_inputs = {}
        for kwarg_key in kwarg_inputs.keys():
            kwarg_val = kwarg_inputs[kwarg_key]
            input_layout = None
            if kwarg_key in self.input_kwarg_layouts:
                input_layout = self.input_kwarg_layouts[kwarg_key]
                assert isinstance(
                    kwarg_val, torch.Tensor
                ), f"input of key {kwarg_key} to the module should be a Tensor!"
                kwarg_val = DTensor.from_local(
                    kwarg_val, device_mesh, (input_layout,), run_check=False
                )

                kwarg_val = cast_to_float8_e4m3fn(
                    kwarg_val, self.fwd_linear_config
                )  # DTensor(Float8Tensor)
                if kwarg_key in self.desired_input_kwarg_layouts:
                    desired_layout = self.desired_input_kwarg_layouts[kwarg_key]
                    if desired_layout != input_layout:
                        kwarg_val = kwarg_val.redistribute(placements=(desired_layout,))

                prepared_kwarg_inputs[kwarg_key] = (
                    kwarg_val.to_local() if self.use_local_output else kwarg_val
                )
            else:
                prepared_kwarg_inputs[kwarg_key] = kwarg_val

        return (prepared_arg_inputs, prepared_kwarg_inputs)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

        # search for ScaledMM configs for all the submodules and make sure they are the same
        fwd_linear_config = None
        for mod in module.modules():
[inline review thread on the config search below]

Reviewer: WDYT of something like the following to avoid the logic below? […]

Author: I thought about this option too. My concern is that it would make the API diverge from the TP API offered in core, making it harder to switch between fp8 and bf16. Also, the user would need to know how to construct the ScaledMMConfig, which basically makes …

Reviewer: Yeah, good point, that's not intended to be a user-facing thing. How about something like requiring the name of a module to get the config from? I think the user API of the current code is great (no extra args), but the restriction that all configs in the module need to be the same is not ideal. If we are ok with changing that later, the current API sgtm.

Author: I think this makes sense! Let me draft up the changes to accept the module_fqn to get the scaled mm config from. My current thinking on how we could approach this: […]
            if isinstance(mod, Float8DynamicLinear):
                if fwd_linear_config is None:
                    fwd_linear_config = mod.forward_config
                else:
                    assert (
                        fwd_linear_config == mod.forward_config
                    ), "All the Float8DynamicLinear modules should have same forward config!"

        self.fwd_linear_config = fwd_linear_config
        super()._apply(module, device_mesh)
        return module
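For context on how the new style is meant to be used: the block input arrives sequence-sharded, PrepareFloat8ModuleInput casts it to float8 before redistributing it to Replicate, so the all-gather moves float8 data and all downstream float8 linears share a single cast. Below is a minimal usage sketch; everything outside the diff above is an assumption for illustration only (the float8_experimental.float8_tensor_parallel import path, the Float8ColwiseParallel / Float8RowwiseParallel class names, swap_linear_with_float8_linear, and the toy attention module). Run under torchrun with one process per GPU.

import torch
import torch.nn as nn
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_tensor_parallel import (  # assumed module path
    Float8ColwiseParallel,
    Float8RowwiseParallel,
    PrepareFloat8ModuleInput,
)


class ToyAttention(nn.Module):
    # stand-in for a real attention block; the wq/wk/wv/wo names are illustrative
    def __init__(self, dim: int = 256) -> None:
        super().__init__()
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.wo(self.wq(x) + self.wk(x) + self.wv(x))


mesh = init_device_mesh("cuda", (torch.cuda.device_count(),))
attn = ToyAttention().cuda()

# swap nn.Linear -> Float8DynamicLinear so the float8 TP styles (and the
# forward_config search in _apply above) have float8 modules to work with
swap_linear_with_float8_linear(attn, Float8DynamicLinear)

# usual TP plan: q/k/v column-wise, output projection row-wise
attn = parallelize_module(
    attn,
    mesh,
    {
        "wq": Float8ColwiseParallel(),
        "wk": Float8ColwiseParallel(),
        "wv": Float8ColwiseParallel(),
        "wo": Float8RowwiseParallel(output_layouts=Shard(1)),
    },
)

# sequence parallel: the block input arrives sharded on the sequence dim;
# PrepareFloat8ModuleInput casts to float8 first, then redistributes to
# Replicate, so the all-gather runs on float8 data and wq/wk/wv share one cast
attn = parallelize_module(
    attn,
    mesh,
    PrepareFloat8ModuleInput(
        input_layouts=Shard(1), desired_input_layouts=Replicate()
    ),
)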
[review comments on PrepareFloat8ModuleInput]

Reviewer: nit: maybe we can have e4m3 in the name, and maybe add a TODO to support the AMD version of e4m3 eventually?

Reviewer: maybe a quick docblock to explain that this is ensuring the float8 cast happens before the all-gather if there are multiple float8 users of the input activation?

Author: I wonder what your thoughts are on these two choices: 1. make e4m3 appear in the name of this class, or 2. make the class constructor take an additional fp8 dtype argument, i.e. float8_dtype=torch.float8_e4m3fn, defaulting to e4m3fn, so that later we can add the AMD version of e4m3 by passing a different float8_dtype arg?

Reviewer: sgtm
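To make these options concrete, here is a hedged sketch of how the constructor could evolve: a float8_dtype argument as proposed above, plus the fwd_config_submodule_fqn idea from the inline thread on the config search. Neither argument exists in this PR, both names are illustrative, and the _prepare_input_fn / _prepare_input_kwarg_fn overrides from the diff are assumed unchanged apart from selecting the cast based on the chosen dtype.

from typing import Optional

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import PrepareModuleInput


class PrepareFloat8ModuleInput(PrepareModuleInput):
    # sketch of the follow-up API discussed in review; not part of this PR
    def __init__(
        self,
        *,
        float8_dtype: torch.dtype = torch.float8_e4m3fn,  # later: AMD e4m3 variant
        fwd_config_submodule_fqn: Optional[str] = None,  # submodule to read the config from
        **kwargs,  # input_layouts, desired_input_layouts, ... forwarded unchanged
    ):
        super().__init__(**kwargs)
        self.float8_dtype = float8_dtype
        self.fwd_config_submodule_fqn = fwd_config_submodule_fqn

    # _prepare_input_fn / _prepare_input_kwarg_fn as in the diff above, except
    # the cast helper would be chosen based on self.float8_dtype.

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

        if self.fwd_config_submodule_fqn is not None:
            # take the ScaledMMConfig from one explicitly named submodule,
            # lifting the "all submodules must share a config" restriction
            linear = module.get_submodule(self.fwd_config_submodule_fqn)
            assert isinstance(linear, Float8DynamicLinear)
            self.fwd_linear_config = linear.forward_config
        else:
            # otherwise keep the search over module.modules() from the diff
            configs = [
                m.forward_config
                for m in module.modules()
                if isinstance(m, Float8DynamicLinear)
            ]
            assert configs and all(
                c == configs[0] for c in configs
            ), "All the Float8DynamicLinear modules should have same forward config!"
            self.fwd_linear_config = configs[0]

        super()._apply(module, device_mesh)
        return module

With the FQN variant, a user could point the style at a specific float8 linear (for example the q projection) instead of constructing a ScaledMMConfig themselves, which keeps the style's signature aligned with the core TP API while relaxing the shared-config requirement.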