This repository was archived by the owner on Aug 7, 2024. It is now read-only.
early return by check tensor already casted or not #233
Closed
Changes from 3 commits
Commits (8)
a104f02  rename use_activation_hooks to cast_activation (wanchaol)
be02472  format (wanchaol)
0136d1f  switch to have tensor_casted_to_fp8 util function instead (wanchaol)
eea4595  remove conftest (wanchaol)
020642a  recursive check already casted (wanchaol)
e5ec3bf  Merge branch 'main' into cast_activation (wanchaol)
a1cb6ef  lint (wanchaol)
f036d59  Merge branch 'main' into cast_activation (wanchaol)
The first hunk drops `use_activation_hooks` from `from_float`:

```diff
@@ -304,16 +304,15 @@ def forward(self, x):
         return y

     @classmethod
-    def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = False):
+    def from_float(cls, mod, emulate: bool = False):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear

         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             emulate (bool): whether to emulate fp8 matmul logic in float32
-            use_activation_hooks (bool): whether to use activation hooks instead of inlining the casting logic
+            cast_activation (bool): whether to use activation hooks instead of inlining the casting logic
         """
-        assert not use_activation_hooks, "use_activation_hooks is not supported yet!"
         # TODO Follow up! This is a great idea but we need the mixin base to create real
         # Tensors and the Linear base to create empty params
         # with torch.device("meta"):
```

Review thread on the `cast_activation (bool)` docstring line:
- nit: This looks like it should be removed.
- I think this was removed in the next PR.
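For orientation, a minimal sketch of what a call site looks like after this change. `Fp8LinearSketch` is a hypothetical stand-in, since the hunk above does not show the class name or the real implementation; it only illustrates that `from_float` now takes just the module and the `emulate` flag, so callers that previously passed `use_activation_hooks=False` simply drop that argument.

```python
import torch.nn as nn


class Fp8LinearSketch(nn.Linear):
    """Hypothetical stand-in for the fp8 linear subclass whose from_float is
    shown in the hunk above; only the signature change is illustrated."""

    @classmethod
    def from_float(cls, mod: nn.Linear, emulate: bool = False) -> "Fp8LinearSketch":
        # Reuse the original shape and parameters; the real implementation also
        # sets up fp8 scales/buffers and the (optionally emulated) fp8 matmul.
        new_mod = cls(mod.in_features, mod.out_features, bias=mod.bias is not None)
        new_mod.weight = mod.weight
        if mod.bias is not None:
            new_mod.bias = mod.bias
        new_mod.emulate = emulate
        return new_mod


# After this PR: only the module and the emulate flag are passed.
fp8_mod = Fp8LinearSketch.from_float(nn.Linear(32, 16), emulate=True)
```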
The second hunk adds the `tensor_already_casted_to_fp8` helper:

```diff
@@ -14,6 +14,19 @@
 aten = torch.ops.aten


+def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
+    """
+    Check if the tensor is already casted to fp8
+    """
+    if isinstance(tensor, Float8Tensor):
+        return True
+    elif isinstance(tensor, DTensor) and isinstance(tensor._local_tensor, Float8Tensor):
+        # TODO: shall we stick to public API and directly use tensor.to_local() here?
+        return True
+
+    return False
+
+
 def to_fp8_no_autograd(
     x: torch.Tensor, x_scale: torch.Tensor, float8_dtype: torch.dtype, emulate: bool
 ) -> "Float8Tensor":
```

Review comment on the `DTensor` branch: I wonder if, in general, subclasses composing with other subclasses should have a generic way to determine the nested types.
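To connect the helper to the PR title, here is a rough sketch of the early-return pattern it enables in a cast path. Every name ending in `_sketch` is hypothetical: the real helper is `tensor_already_casted_to_fp8` above, which checks for the `Float8Tensor` subclass (including one wrapped inside a `DTensor`) rather than a raw fp8 dtype, and the real cast goes through `to_fp8_no_autograd`. Assumes a PyTorch build that has the float8 dtypes.

```python
import torch


class Float8TensorSketch(torch.Tensor):
    """Hypothetical stand-in for the repo's Float8Tensor subclass."""
    pass


def already_casted_sketch(tensor: torch.Tensor) -> bool:
    # Simplified mirror of tensor_already_casted_to_fp8: peek through one
    # wrapper level (DTensor stores its shard in _local_tensor), then check
    # for the fp8 tensor subclass.
    inner = getattr(tensor, "_local_tensor", tensor)
    return isinstance(inner, Float8TensorSketch)


def cast_to_fp8_sketch(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # The early return this PR adds: if the activation was already cast
    # (e.g. by an activation hook or a preceding fp8 layer), skip re-casting.
    if already_casted_sketch(x):
        return x
    # Otherwise scale and convert; the real path wraps the result in a
    # Float8Tensor via to_fp8_no_autograd instead of returning a raw tensor.
    return (x * scale).to(torch.float8_e4m3fn)


y = cast_to_fp8_sketch(torch.randn(4, 4), torch.tensor(1.0))
print(y.dtype)  # torch.float8_e4m3fn
```

The `DTensor` branch presumably matters because, under tensor parallelism, the outer object is a `DTensor` whose local shard may already be a `Float8Tensor`, so checking only the outer type would miss the already-casted case.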
For `cast_to_float8_e5m2_bw`, unfortunately I can't do a forward-only check and have to check the backward gradients to see if they are already casted, since the forward only has `y` but not `grad_y`.
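A rough sketch of what that backward-side check could look like. Names ending in `_sketch` are hypothetical; in the actual code the check would live in the backward of the autograd function behind `cast_to_float8_e5m2_bw` (per the comment above, the forward only sees `y`), and it would use `tensor_already_casted_to_fp8` on the `Float8Tensor` subclass rather than a plain dtype test. Assumes a PyTorch build that has the float8 dtypes.

```python
import torch


def grad_already_fp8_sketch(g: torch.Tensor) -> bool:
    # Simplified stand-in for tensor_already_casted_to_fp8 applied to a
    # gradient; here we only inspect the dtype.
    return g.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)


def cast_grad_to_e5m2_sketch(grad_y: torch.Tensor) -> torch.Tensor:
    """Backward-side cast: grad_y only exists here, never in forward."""
    if grad_already_fp8_sketch(grad_y):
        # Early return: the incoming gradient was already cast upstream.
        return grad_y
    return grad_y.to(torch.float8_e5m2)


g = torch.randn(8)
g_fp8 = cast_grad_to_e5m2_sketch(g)
assert g_fp8.dtype == torch.float8_e5m2
assert cast_grad_to_e5m2_sketch(g_fp8) is g_fp8  # second call returns early
```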