early return by check tensor already casted or not #233
Conversation
As titled, it turns out we don't need to install additional hooks based on our TP + FP8 design. The only thing we need to do here is to be able to turn off activation casting, so that we can put the activation casting in the TP hooks. So renaming the flag to cast_activation instead and deleting the relevant tests.
@wanchaol has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
def forward(self, x):
-    # cast x to float8_e4m3fn if not using activation hooks
-    x_fp8 = x if self.use_activation_hooks else self.cast_to_float8_e4m3fn(x)
+    x_fp8 = self.cast_to_float8_e4m3fn(x) if self.cast_activation else x
If self.cast_activation is False, then what type is x? Would it be a DTensor whose local tensor is Float8Tensor?

I wonder if we need the bool at all. For example, the cast_to_float8_e4m3fn function could be idempotent in that if it is passed a Float8Tensor, then we return it early. The wrinkle here is that it could be a DTensor wrapping a Float8Tensor.
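A minimal sketch of what such an idempotent cast could look like (the helper name, the cast_fn parameter, and the direct _local_tensor access are illustrative assumptions, not the PR's final code):

```python
import torch
from torch.distributed._tensor import DTensor
from float8_experimental.float8_tensor import Float8Tensor

def idempotent_cast_to_float8_e4m3fn(x, cast_fn):
    # hypothetical helper: skip the cast if x is already a Float8Tensor,
    # or a DTensor whose local shard is already a Float8Tensor
    if isinstance(x, Float8Tensor):
        return x
    if isinstance(x, DTensor) and isinstance(x._local_tensor, Float8Tensor):
        return x
    return cast_fn(x)
```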
If it's DTensor(Float8Tensor), this would still be a no-op though, right, since the cast has already been done? I think it could be, if it were something like: if no Float8Tensor subclasses exist in the subclass hierarchy then do the cast, otherwise pass through.
If cast_activation is False, it would be DTensor(torch.Tensor[fp32/16]). We want to turn off activation casting because we want this casting to happen inside the TP pre-forward/forward hooks (as it needs to happen in a certain order, i.e. after from_local and before redistribute); see the changes in PR #234.
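A minimal sketch of that ordering constraint (this is not the actual PR #234 hook; the function name, cast_fn, and the placements are assumptions):

```python
import torch
from torch.distributed._tensor import DTensor, Replicate, Shard

def _float8_prepare_input(x, device_mesh, cast_fn, placements=(Shard(0),)):
    # 1) wrap the plain local tensor into a DTensor first (no cast yet)
    dt = DTensor.from_local(x, device_mesh, [Replicate()], run_check=False)
    # 2) cast the activation to float8 while it is a DTensor of high-precision data
    #    (cast_fn stands in for the fp8 e4m3 cast)
    dt_fp8 = cast_fn(dt)
    # 3) only then redistribute to the placements the TP plan expects,
    #    so any collective moves float8 data rather than fp32/bf16
    return dt_fp8.redistribute(device_mesh, list(placements))
```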
So if you are going to be using the TP strategies then you turn off activation casting: https://github.com/pytorch-labs/float8_experimental/pull/234/files#diff-0c0c016522783d31fb102a4088caaa2a64f783f7b5449559be4670a11fd5ed31R172.

Doesn't this mean that when line 76 gets used, x will be a DTensor(Float8Tensor), because x will be this input tensor: https://github.com/pytorch-labs/float8_experimental/pull/234/files#diff-ed26a8770aa85661ab521607c70b62b60123890c2518687d9f58f27e31200955R34?
updated according to our discussions :)
@wanchaol has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@@ -30,63 +34,43 @@ def forward(

    @staticmethod
    def backward(ctx, gradY):
        if tensor_already_casted_to_fp8(gradY):
For cast_to_float8_e5m2_bw, unfortunately I can't do a forward-only check and have to check the backward gradients to see if they're already casted, as the forward only has y but not grad_y.
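A minimal sketch of that pattern (the class name and the two helpers below are simplified stand-ins, not the repo's exact implementations): the forward is an identity, so the only place an early return on an already-casted tensor can happen is on the incoming gradient in backward.

```python
import torch

def tensor_already_casted_to_fp8(t: torch.Tensor) -> bool:
    # simplified stand-in for the predicate added in this PR
    # (the real one also handles a DTensor wrapping a Float8Tensor)
    return t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)

def cast_to_float8_e5m2(t: torch.Tensor) -> torch.Tensor:
    # hypothetical stand-in for the repo's e5m2 cast helper (no scaling here)
    return t.to(torch.float8_e5m2)

class NoopFwCastBw(torch.autograd.Function):
    # identity in forward; the fp8 cast (and the "already casted?" check)
    # can only happen in backward, where grad_y is available
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, gradY):
        if tensor_already_casted_to_fp8(gradY):
            # TP hooks already produced an fp8 gradient, don't cast twice
            return gradY
        return cast_to_float8_e5m2(gradY)
```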
if isinstance(tensor, Float8Tensor):
    return True
elif isinstance(tensor, DTensor) and isinstance(tensor._local_tensor, Float8Tensor):
    # TODO: shall we stick to public API and directly use tensor.to_local() here?
    return True
I wonder if in general, subclasses composing with other subclasses should have a generic way to determine the nested types
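One generic possibility (just a sketch; the helper name is made up, and this assumes the nested subclasses implement the __tensor_flatten__ protocol, as DTensor does):

```python
import torch

def contains_subclass(t: torch.Tensor, cls: type) -> bool:
    # hypothetical generic check: recursively walk nested tensor subclasses
    # via the __tensor_flatten__ protocol used by traceable subclasses
    if isinstance(t, cls):
        return True
    if hasattr(t, "__tensor_flatten__"):
        inner_attr_names, _ = t.__tensor_flatten__()
        return any(contains_subclass(getattr(t, name), cls) for name in inner_attr_names)
    return False
```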
Awesome! I love the ratio of red lines to green!
I think the fixture in conftest.py for xfailing the grad hooks can also be removed
Thanks!
@wanchaol has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
""" | ||
Create an nn.Linear with fp8 compute from a regular nn.Linear | ||
|
||
Args: | ||
mod (torch.nn.Linear): nn.Linear to convert | ||
emulate (bool): whether to emulate fp8 matmul logic in float32 | ||
use_activation_hooks (bool): whether to use activation hooks instead of inlining the casting logic | ||
cast_activation (bool): whether to use activation hooks instead of inlining the casting logic |
nit: This looks like it should be removed.
I think this was removed in the next PR.
As titled, it turns out we don't need to install additional hooks based on our TP + FP8 design. The only thing we need to do here is to be able to turn off activation casting, so that we can put the activation casting in the TP hooks.

So just check whether the tensor has already been casted to fp8 or not, since in TP we would cast the activation inside the DTensor's Float8Colwise/Rowwise styles instead.
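For context, a hedged sketch of how that would be wired up on the TP side (the import paths, style names, and mesh setup here are assumptions based on the PR #234 discussion, not verified against the repo; run under torchrun with a distributed process group):

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

# assumed import path for the Float8 colwise/rowwise styles referenced above
from float8_experimental.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)

class FFN(nn.Module):
    def __init__(self, dim=1024, hidden=4096):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))

mesh = init_device_mesh("cuda", (8,))
# the fp8 linears no longer cast their inputs themselves; the cast happens
# inside the TP pre-forward/forward hooks installed by these styles
model = parallelize_module(
    FFN().cuda(),
    mesh,
    {"w1": Float8ColwiseParallel(), "w2": Float8RowwiseParallel()},
)
```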