This repository has been archived by the owner on Aug 7, 2024. It is now read-only.
[wip] hooks
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo committed Dec 27, 2023
1 parent 31fba04 commit f00cac9
Showing 2 changed files with 45 additions and 2 deletions.
9 changes: 9 additions & 0 deletions float8_experimental/config.py
@@ -16,3 +16,12 @@
 # according to their microbatching/pipeline parallel setup.
 # Note: this is currently a global flag for simplicity and dynamo performance.
 weight_cache_enabled = False
+
+#
+# Other
+#
+
+# If True, dynamic linear uses hooks for activation casting
+# TODO(before land): add test coverage for both cases
+dynamic_use_activation_hooks = True
+# dynamic_use_activation_hooks = False
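
The flag is read by `Float8DynamicLinear` both in `__init__` and in `from_float`, so it has to be set before modules are constructed or converted. A minimal sketch of toggling it from user code (illustrative, not part of this commit):

import float8_experimental.config as config

# must run before any Float8DynamicLinear is constructed or converted,
# since the flag is read in __init__ and in from_float
config.dynamic_use_activation_hooks = True    # cast via module hooks
# config.dynamic_use_activation_hooks = False # cast inline in forward()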
38 changes: 36 additions & 2 deletions float8_experimental/dynamic_linear/dynamic_float8_linear.py
@@ -11,6 +11,7 @@

 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
+import float8_experimental.config as config
 
 
 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@@ -38,6 +39,22 @@ def backward(ctx, gradY):
             None,
         )

+def cast_x_to_float8_e4m3fn_pre_hook(module, args):
+    """
+    Hook to cast the incoming activation to `torch.float8_e4m3fn`
+    """
+    return module.cast_to_float8(args[0])
+
+def cast_dldy_to_float8_e5m2_backward_pre_hook(module, grad_output):
+    """
+    Hook to cast the incoming gradient to `torch.float8_e5m2`
+    """
+    gradY = grad_output[0]
+    gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
+    gradY_scaled = gradY * gradY_scale
+    bits_fp8 = to_fp8_saturated(gradY_scaled, torch.float8_e5m2)
+    tensor_fp8 = Float8Tensor(bits_fp8, gradY_scale, gradY.dtype, emulate=module.emulate)
+    return (tensor_fp8,)
 
 class Float8DynamicLinear(torch.nn.Linear):
     """
@@ -48,9 +65,16 @@ class Float8DynamicLinear(torch.nn.Linear):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.add_weight_tag()
+        self.use_activation_hooks = config.dynamic_use_activation_hooks
 
     def forward(self, x):
-        x_fp8 = self.cast_to_float8(x)
+        # cast x to float8_e4m3fn
+        if self.use_activation_hooks:
+            x_fp8 = x
+        else:
+            x_fp8 = self.cast_to_float8(x)
+
+        # cast w to float8_e4m3fn
         if getattr(self, "_w_fp8", None) is not None:  # FSDP handled the cast
             w_fp8 = self._w_fp8
         else:
@@ -59,7 +83,10 @@ def forward(self, x):
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
 
         # Cast gradY to float8_e5m2 during backward
-        y = self.cast_to_float8e5m2_bw(y)
+        if self.use_activation_hooks:
+            pass
+        else:
+            y = self.cast_to_float8e5m2_bw(y)
 
         return y

@@ -69,6 +96,7 @@ def add_weight_tag(self):
         self.weight._is_fp8_weight = True
 
     def cast_to_float8(self, inpt_tensor):
+        # TODO rename this function to clarify e4m3
        scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
         return Float8Tensor.to_float8(
             inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
@@ -92,4 +120,10 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
         new_mod.add_weight_tag()
+
+        new_mod.use_activation_hooks = config.dynamic_use_activation_hooks
+        if new_mod.use_activation_hooks:
+            # install the hooks
+            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
+            new_mod.register_full_backward_pre_hook(cast_dldy_to_float8_e5m2_backward_pre_hook)
         return new_mod
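
The config TODO above calls for test coverage of both settings. A rough sketch of what such a check might look like, assuming illustrative shapes and `emulate=True` so it can run without native float8 hardware; since this is a WIP commit, the hook path may not pass a check like this yet:

import copy

import torch

import float8_experimental.config as config
from float8_experimental.dynamic_linear.dynamic_float8_linear import Float8DynamicLinear


def _run(m_ref: torch.nn.Linear, x: torch.Tensor, use_hooks: bool):
    # the flag is read at construction/conversion time, so set it before from_float
    config.dynamic_use_activation_hooks = use_hooks
    m_fp8 = Float8DynamicLinear.from_float(copy.deepcopy(m_ref), emulate=True)
    x = x.clone().requires_grad_()
    y = m_fp8(x)
    y.sum().backward()
    return y, x.grad


m_ref = torch.nn.Linear(16, 32)
x = torch.randn(4, 16)
y_hooks, g_hooks = _run(m_ref, x, use_hooks=True)
y_inline, g_inline = _run(m_ref, x, use_hooks=False)

# both code paths are expected to produce the same float8-rounded results
torch.testing.assert_close(y_hooks, y_inline)
torch.testing.assert_close(g_hooks, g_inline)

Either way, the flag has to be set before `from_float` runs, since the hooks are only installed at conversion time.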
