
address comments from vkuzo
wanchaol committed Jun 12, 2024
1 parent 3edc3ec commit 757be52
Showing 1 changed file with 61 additions and 75 deletions.
float8_experimental/float8_tensor_parallel.py (136 changes: 61 additions & 75 deletions)
@@ -119,88 +119,74 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
 class PrepareFloat8ModuleInput(PrepareModuleInput):
     # subclass the PrepareModuleInput classes, the only difference is that after we prepare
     # the input DTensor, we cast the input to DTensor(Float8Tensor)
-    def _prepare_input_fn(self, inputs, device_mesh):
-        if self.input_layouts is None:
-            return inputs
-        prepared_inputs = []
-        if not isinstance(inputs, tuple):
-            inputs = (inputs,)
-        if len(inputs) != len(self.input_layouts):
-            raise ValueError("module inputs and input_layouts should have same length!")
-
-        assert (
-            self.desired_input_layouts is not None
-        ), "desired module inputs should not be None!"
-        for inp, input_layout, desired_layout in zip(
-            inputs, self.input_layouts, self.desired_input_layouts
-        ):
-            if input_layout is not None:
-                if isinstance(inp, DTensor):
-                    # TODO: re-enable the check once we fix the compile path
-                    # assert inp.placements[0] == input_layout
-                    dt_inp = inp
-                else:
-                    dt_inp = DTensor.from_local(
-                        inp, device_mesh, (input_layout,), run_check=False
-                    )
-
-                dt_inp = cast_to_float8_e4m3fn(
-                    dt_inp, self.fwd_linear_config
-                )  # DTensor(Float8Tensor)
-                if desired_layout is not None and input_layout != desired_layout:
-                    # i.e. Shard -> Replicate: allgather
-                    dt_inp = dt_inp.redistribute(placements=(desired_layout,))
-                prepared_inputs.append(
-                    dt_inp.to_local() if self.use_local_output else dt_inp
-                )
-            else:
-                prepared_inputs.append(inp)
-        return tuple(prepared_inputs)
-
-    def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh):
-        prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh)
-        prepared_kwarg_inputs = {}
-        for kwarg_key in kwarg_inputs.keys():
-            kwarg_val = kwarg_inputs[kwarg_key]
-            input_layout = None
-            if kwarg_key in self.input_kwarg_layouts:
-                input_layout = self.input_kwarg_layouts[kwarg_key]
-                assert isinstance(
-                    kwarg_val, torch.Tensor
-                ), f"input of key {kwarg_key} to the module should be a Tensor!"
-                kwarg_val = DTensor.from_local(
-                    kwarg_val, device_mesh, (input_layout,), run_check=False
-                )
-
-                kwarg_val = cast_to_float8_e4m3fn(
-                    kwarg_val, self.fwd_linear_config
-                )  # DTensor(Float8Tensor)
-                if kwarg_key in self.desired_input_kwarg_layouts:
-                    desired_layout = self.desired_input_kwarg_layouts[kwarg_key]
-                    if desired_layout != input_layout:
-                        kwarg_val = kwarg_val.redistribute(placements=(desired_layout,))
-
-                prepared_kwarg_inputs[kwarg_key] = (
-                    kwarg_val.to_local() if self.use_local_output else kwarg_val
-                )
-            else:
-                prepared_kwarg_inputs[kwarg_key] = kwarg_val
-
-        return (prepared_arg_inputs, prepared_kwarg_inputs)
+    # This is to ensure the float8 cast happens before the all-gather (i.e. Shard -> Replicate)
+    # so that if there are multiple float8 users of the input activation, we perform fp8 allgather
+    # only once.
+
+    def __init__(
+        self,
+        *,
+        input_layouts=None,
+        desired_input_layouts=None,
+        input_kwarg_layouts=None,
+        desired_input_kwarg_layouts=None,
+        use_local_output=False,
+        float8_dtype=torch.float8_e4m3fn,
+        fwd_config_submodule_fqn=None,
+    ):
+        super().__init__(
+            input_layouts=input_layouts,
+            desired_input_layouts=desired_input_layouts,
+            input_kwarg_layouts=input_kwarg_layouts,
+            desired_input_kwarg_layouts=desired_input_kwarg_layouts,
+            use_local_output=use_local_output,
+        )
+
+        # fp8 specific fields
+        self.float8_dtype = float8_dtype
+        self.fwd_config_submodule_fqn = fwd_config_submodule_fqn
+
+        if self.float8_dtype != torch.float8_e4m3fn:
+            raise NotImplementedError(
+                "PrepareFloat8ModuleInput only support casting to float8_e4m3fn for now"
+            )
+
+    def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
+        if input_layout is not None:
+            if isinstance(input, DTensor):
+                # TODO: re-enable the check once we fix the compile path
+                # assert inp.placements[0] == input_layout
+                dt_inp = input
+            else:
+                assert isinstance(
+                    input, torch.Tensor
+                ), "expecting input to be a torch.Tensor!"
+                dt_inp = DTensor.from_local(
+                    input, mesh, (input_layout,), run_check=False
+                )
+
+            dt_inp = cast_to_float8_e4m3fn(
+                dt_inp, self.fwd_linear_config
+            )  # DTensor(Float8Tensor)
+            if desired_layout is not None and input_layout != desired_layout:
+                dt_inp = dt_inp.redistribute(placements=(desired_layout,))
+
+            return dt_inp.to_local() if self.use_local_output else dt_inp
+        else:
+            return input
 
     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 
-        # search for ScaledMM configs for all the submodules and make sure they are the same
         fwd_linear_config = None
-        for mod in module.modules():
-            if isinstance(mod, Float8DynamicLinear):
-                if fwd_linear_config is None:
-                    fwd_linear_config = mod.forward_config
-                else:
-                    assert (
-                        fwd_linear_config == mod.forward_config
-                    ), "All the Float8DynamicLinear modules should have same forward config!"
+        if self.fwd_config_submodule_fqn is not None:
+            fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn)
+            assert isinstance(fwd_linear, Float8DynamicLinear)
+            fwd_linear_config = fwd_linear.forward_config
+        else:
+            # search for ScaledMM configs for all the submodules and make sure they are the same
+            for mod in module.modules():
+                if isinstance(mod, Float8DynamicLinear):
+                    if fwd_linear_config is None:
+                        fwd_linear_config = mod.forward_config
+                    else:
+                        assert (
+                            fwd_linear_config == mod.forward_config
+                        ), "All the Float8DynamicLinear modules should have same forward config!"
 
         self.fwd_linear_config = fwd_linear_config
         super()._apply(module, device_mesh)
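
For context on how this class is wired into a tensor-parallel plan, here is a minimal sketch (not part of this commit). It assumes a torchrun launch with 8 GPUs, the Float8ColwiseParallel / Float8RowwiseParallel styles defined earlier in float8_tensor_parallel.py, and the swap_linear_with_float8_linear helper from float8_experimental.float8_linear_utils; the toy Attention module and the submodule names ("attention", "wq", "wk", "wv", "wo") are hypothetical.

    import torch.nn as nn
    from torch.distributed._tensor import Replicate, Shard
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import parallelize_module

    from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
    from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
    from float8_experimental.float8_tensor_parallel import (
        Float8ColwiseParallel,
        Float8RowwiseParallel,
        PrepareFloat8ModuleInput,
    )


    class Attention(nn.Module):
        # toy stand-in for a transformer attention block; only the projections matter here
        def __init__(self, dim: int = 256):
            super().__init__()
            self.wq = nn.Linear(dim, dim, bias=False)
            self.wk = nn.Linear(dim, dim, bias=False)
            self.wv = nn.Linear(dim, dim, bias=False)
            self.wo = nn.Linear(dim, dim, bias=False)

        def forward(self, x):
            return self.wo(self.wq(x) + self.wk(x) + self.wv(x))


    model = nn.ModuleDict({"attention": Attention()}).cuda()  # per-rank device setup omitted
    swap_linear_with_float8_linear(model, Float8DynamicLinear)  # linears -> dynamic-scaling fp8

    tp_mesh = init_device_mesh("cuda", (8,))
    plan = {
        # cast the block input to float8 once, before the Shard(1) -> Replicate() all-gather,
        # so wq/wk/wv all consume the same fp8 all-gathered activation
        "attention": PrepareFloat8ModuleInput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
        ),
        "attention.wq": Float8ColwiseParallel(),
        "attention.wk": Float8ColwiseParallel(),
        "attention.wv": Float8ColwiseParallel(),
        "attention.wo": Float8RowwiseParallel(output_layouts=Shard(1)),
    }
    model = parallelize_module(model, tp_mesh, plan)

Because the parent module owns the float8 cast, the Shard(1) -> Replicate() all-gather moves fp8 bytes and runs once, rather than once per projection.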
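The new fwd_config_submodule_fqn argument is an escape hatch for the case where the wrapped module's Float8DynamicLinear children do not all share one forward_config: instead of asserting that they match, _apply() reads the config from the named submodule. A hedged sketch, where the FQN "wq" is hypothetical:

    prepare_input = PrepareFloat8ModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
        # _apply() will use module.get_submodule("wq").forward_config for the input cast,
        # instead of requiring every Float8DynamicLinear child to share the same forward_config
        fwd_config_submodule_fqn="wq",
    )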
