Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[wip] add option to do activation/grad cast from hooks #170

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions float8_experimental/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,8 @@
# this doesn't work with autocast + torch.compile + FSDP. Enabling this
# option is useful for safety, but not strictly necessary.
enable_pre_and_post_forward = True

# If True, dynamic linear uses hooks for activation casting
# TODO(before land): add test coverage for both cases
dynamic_use_activation_hooks = True
# dynamic_use_activation_hooks = False
39 changes: 37 additions & 2 deletions float8_experimental/float8_dynamic_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
import float8_experimental.config as config


@torch._dynamo.allow_in_graph
Expand Down Expand Up @@ -39,25 +40,54 @@ def backward(ctx, gradY):
None,
)

def cast_x_to_float8_e4m3fn_pre_hook(module, args):
"""
Hook to cast the incoming activation to `torch.float8_e4m3fn`
"""
return module.cast_to_float8(args[0])

def cast_dldy_to_float8_e5m2_backward_pre_hook(module, grad_output):
"""
Hook to cast the incoming gradient to `torch.float8_e5m2`
"""
gradY = grad_output[0]
gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
gradY_scaled = gradY * gradY_scale
bits_fp8 = to_fp8_saturated(gradY_scaled, torch.float8_e5m2)
tensor_fp8 = Float8Tensor(bits_fp8, gradY_scale, gradY.dtype, emulate=module.emulate)
return (tensor_fp8,)

class Float8DynamicLinear(torch.nn.Linear):
"""
A wrapper around a `torch.nn.Linear` module which does fp8 compute. By on the fly
conversion to fp8 of the input and weight tensors.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.use_activation_hooks = config.dynamic_use_activation_hooks

def forward(self, x):
x_fp8 = self.cast_to_float8(x)
# cast x to float8_e4m3fn
if self.use_activation_hooks:
x_fp8 = x
else:
x_fp8 = self.cast_to_float8(x)

# cast w to float8_e4m3fn
w_fp8 = self.cast_to_float8(self.weight)

y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)

# Cast gradY to float8_e5m2 during backward
y = self.cast_to_float8e5m2_bw(y)
if self.use_activation_hooks:
pass
else:
y = self.cast_to_float8e5m2_bw(y)

return y

def cast_to_float8(self, inpt_tensor):
# TODO rename this function to clarify e4m3
scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
return Float8Tensor.to_float8(
inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
Expand All @@ -80,4 +110,9 @@ def from_float(cls, mod, emulate: bool = False):
new_mod.weight = mod.weight
new_mod.bias = mod.bias
new_mod.emulate = emulate
new_mod.use_activation_hooks = config.dynamic_use_activation_hooks
if new_mod.use_activation_hooks:
# install the hooks
new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
new_mod.register_full_backward_pre_hook(cast_dldy_to_float8_e5m2_backward_pre_hook)
return new_mod
Loading