early return by check tensor already casted or not #233
@@ -304,16 +304,15 @@ def forward(self, x):
         return y

     @classmethod
-    def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = False):
+    def from_float(cls, mod, emulate: bool = False):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear

         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             emulate (bool): whether to emulate fp8 matmul logic in float32
-            use_activation_hooks (bool): whether to use activation hooks instead of inlining the casting logic
+            cast_activation (bool): whether to use activation hooks instead of inlining the casting logic
         """
-        assert not use_activation_hooks, "use_activation_hooks is not supported yet!"
         # TODO Follow up! This is a great idea but we need the mixin base to create real
         # Tensors and the Linear base to create empty params
         # with torch.device("meta"):

Review comment on the cast_activation docstring line: nit: This looks like it should be removed.
Reply: I think this was removed in the next PR.
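As a usage sketch of the changed API (the enclosing class name Float8DynamicLinear is an assumption; the class is not shown in this hunk), a caller would now convert a module without the removed flag:

    import torch.nn as nn

    m = nn.Linear(16, 32)
    # use_activation_hooks is gone from the signature; only emulate remains
    fp8_m = Float8DynamicLinear.from_float(m, emulate=False)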
@@ -7,13 +7,29 @@

 import torch

-from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated
+import torch.distributed._functional_collectives as funcol
+
+from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated
+from torch.distributed._tensor import DTensor

 aten = torch.ops.aten


+def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
+    """
+    Check if the tensor is already casted to fp8
+    """
+    if isinstance(tensor, Float8Tensor):
+        return True
+    elif isinstance(tensor, DTensor):
+        # TODO: shall we stick to public API and directly use tensor.to_local() here?
+        return tensor_already_casted_to_fp8(tensor._local_tensor)
+    elif isinstance(tensor, funcol.AsyncCollectiveTensor):
+        return tensor_already_casted_to_fp8(tensor.elem)
+
+    return False
+
+
 def to_fp8_no_autograd(
     x: torch.Tensor, x_scale: torch.Tensor, float8_dtype: torch.dtype, emulate: bool
 ) -> "Float8Tensor":

Review comment on the DTensor branch: I wonder if, in general, subclasses composing with other subclasses should have a generic way to determine the nested types.
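To illustrate the intended early return, here is a minimal sketch of a dynamic cast helper built on the new check. The helper name cast_to_float8 and the inline scale computation are assumptions for illustration; only tensor_already_casted_to_fp8, tensor_to_amax, and to_fp8_no_autograd come from this diff:

    def cast_to_float8(x: torch.Tensor, float8_dtype: torch.dtype, emulate: bool = False):
        if tensor_already_casted_to_fp8(x):
            # x is already a Float8Tensor (possibly wrapped in a DTensor or
            # AsyncCollectiveTensor), so skip re-casting and return it unchanged.
            return x
        # Illustrative per-tensor scale from the absolute max; real code would use
        # the library's own scaling helpers.
        amax = tensor_to_amax(x)
        scale = torch.finfo(float8_dtype).max / torch.clamp(amax, min=1e-12)
        return to_fp8_no_autograd(x, scale, float8_dtype, emulate)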
One of the changed files was deleted in this PR; its diff failed to load and is not shown here.
Review comment: for cast_to_float8_e5m2_bw, unfortunately I can't do a forward check only and have to check the backward gradients to see if they are already casted, as the forward only has y but not grad_y.
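A minimal sketch of the pattern described in that comment, assuming a no-op forward and an fp8 cast of the incoming gradient in the backward pass (the class name and structure are illustrative, not the repository's actual implementation):

    class NoopFwCastE5M2Bw(torch.autograd.Function):
        @staticmethod
        def forward(ctx, y: torch.Tensor, emulate: bool):
            ctx.emulate = emulate
            return y  # forward only sees y; grad_y does not exist yet

        @staticmethod
        def backward(ctx, grad_y: torch.Tensor):
            # The early-return check has to happen here, on the gradient itself.
            if tensor_already_casted_to_fp8(grad_y):
                return grad_y, None
            amax = tensor_to_amax(grad_y)
            scale = torch.finfo(torch.float8_e5m2).max / torch.clamp(amax, min=1e-12)
            return to_fp8_no_autograd(grad_y, scale, torch.float8_e5m2, ctx.emulate), None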