
early return by check tensor already casted or not #233

Closed · wants to merge 8 commits

Conversation

@wanchaol (Contributor) commented Mar 4, 2024:

As titled: it turns out we don't need to install additional hooks, based on our TP + FP8 design. The only thing we need here is the ability to turn off activation casting, so that we can put the activation casting in the TP hooks.

So just check whether the tensor has already been cast to fp8 or not, as in TP we would cast the activation into the DTensor Float8Colwise/Rowwise instead.

From an earlier commit in this PR: rename the flag to cast_activation instead and delete the relevant tests.
@facebook-github-bot added the CLA Signed label Mar 4, 2024
@wanchaol requested review from awgu and drisspg March 4, 2024 18:32
@facebook-github-bot

@wanchaol has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


 def forward(self, x):
     # cast x to float8_e4m3fn if not using activation hooks
-    x_fp8 = x if self.use_activation_hooks else self.cast_to_float8_e4m3fn(x)
+    x_fp8 = self.cast_to_float8_e4m3fn(x) if self.cast_activation else x
Contributor commented:

If self.cast_activation is False, then what type is x? Would it be a DTensor whose local tensor is Float8Tensor?

I wonder if we need the bool at all. For example, the cast_to_float8_e4m3fn function could be idempotent in that if it is passed a Float8Tensor, then we return it early. The wrinkle here is that it could be a DTensor wrapping a Float8Tensor.
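A minimal sketch of that idempotent-cast idea, assuming Float8Tensor and DTensor are importable as below; the helper names and the do_cast callback are illustrative stand-ins, not the repo's actual API:

import torch
from torch.distributed._tensor import DTensor
from float8_experimental.float8_tensor import Float8Tensor

def _already_fp8(tensor: torch.Tensor) -> bool:
    # plain Float8Tensor: already cast
    if isinstance(tensor, Float8Tensor):
        return True
    # the wrinkle: a DTensor whose local shard is already a Float8Tensor
    if isinstance(tensor, DTensor):
        return isinstance(tensor._local_tensor, Float8Tensor)
    return False

def cast_if_needed(tensor, do_cast):
    # early return makes the cast safe to call more than once
    if _already_fp8(tensor):
        return tensor
    return do_cast(tensor)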

Contributor commented:
If it's DTensor(Float8Tensor), this would still be a no-op though, right, since the cast has already been done? I think the check could be: if no Float8Tensor subclass appears anywhere in the subclass hierarchy, pass the tensor through; otherwise do the cast.

@wanchaol (author) commented:
If cast_activation is False, it would be DTensor(torch.Tensor[fp32/16]). We want to turn off activation casting because we want the casting to happen inside the TP pre-forward/forward hooks (it needs to happen in a certain order, i.e. after from_local and before redistribute); see the changes in PR #234.
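A rough sketch of that ordering. The real hooks live in PR #234; the hook signature, placements, and cast_fn below are illustrative assumptions, not the actual code:

from torch.distributed._tensor import DTensor, Replicate, Shard

def tp_pre_forward_sketch(x, mesh, cast_fn):
    # 1. wrap the plain local shard into a DTensor first ...
    dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
    # 2. ... then cast the activation, so the cast operates on the DTensor ...
    dt_fp8 = cast_fn(dt)
    # 3. ... and only then redistribute to the layout the TP plan expects
    return dt_fp8.redistribute(mesh, [Replicate()])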

Contributor commented:
So if you are going to be using the TP strategies, then you turn off activation casting: https://github.com/pytorch-labs/float8_experimental/pull/234/files#diff-0c0c016522783d31fb102a4088caaa2a64f783f7b5449559be4670a11fd5ed31R172.

Doesn't this mean that when line 76 gets used, x will be a DTensor(Float8Tensor), because x will be this input tensor: https://github.com/pytorch-labs/float8_experimental/pull/234/files#diff-ed26a8770aa85661ab521607c70b62b60123890c2518687d9f58f27e31200955R34 ?

@wanchaol (author) commented:
updated according to our discussions :)

@wanchaol changed the title from "rename use_activation_hooks to cast_activation" to "early return by check tensor already casted or not" Mar 12, 2024
@facebook-github-bot

@wanchaol has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@wanchaol requested review from awgu and drisspg March 12, 2024 05:00
@@ -30,63 +34,43 @@ def forward(

     @staticmethod
     def backward(ctx, gradY):
+        if tensor_already_casted_to_fp8(gradY):
@wanchaol (author) commented Mar 12, 2024:

For cast_to_float8_e5m2_bw, unfortunately I can't do the check in forward only; I have to check the backward gradient to see if it's already cast, since the forward only has y but not grad_y.
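A hedged sketch of what that looks like. The autograd.Function skeleton and the cast_to_float8_e5m2 stand-in below are illustrative; only the early-return check mirrors the diff above:

import torch

# `tensor_already_casted_to_fp8` is the helper introduced in this PR (shown just
# below); `cast_to_float8_e5m2` here is a crude stand-in for the repo's real
# scaled e5m2 cast.
def cast_to_float8_e5m2(t):
    return t.to(torch.float8_e5m2)

class NoopFwCastBwSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, y):
        # forward only ever sees y, never grad_y, so nothing can be checked here
        return y

    @staticmethod
    def backward(ctx, gradY):
        # grad_y may already be fp8 (e.g. a DTensor wrapping a Float8Tensor
        # produced by the TP hooks); skip the cast in that case
        if tensor_already_casted_to_fp8(gradY):
            return gradY
        return cast_to_float8_e5m2(gradY)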

def tensor_already_casted_to_fp8(tensor) -> bool:
    if isinstance(tensor, Float8Tensor):
        return True
    elif isinstance(tensor, DTensor) and isinstance(tensor._local_tensor, Float8Tensor):
        # TODO: shall we stick to the public API and directly use tensor.to_local() here?
        return True
    return False
Contributor commented:
I wonder if in general, subclasses composing with other subclasses should have a generic way to determine the nested types

@wanchaol (author) commented:
Yeah, I'll also need to think about how this should behave. In general this is about how to get the subclass hierarchy from some type info; simply trying type.mro or inspect.getmro doesn't seem to work for me. Maybe I need to check with @ezyang @bdhirsh to see if there are better suggestions.
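One generic-ish sketch of that idea, illustrative only: type.mro/inspect.getmro describe the class hierarchy of a type, not which tensors a given instance wraps, which is why they don't help here. It assumes the wrappers involved implement the traceable __tensor_flatten__ protocol (as DTensor does) and that Float8Tensor is importable as below:

from float8_experimental.float8_tensor import Float8Tensor

def contains_float8(t) -> bool:
    # direct hit: the tensor itself is a Float8Tensor
    if isinstance(t, Float8Tensor):
        return True
    # traceable wrapper subclasses expose their inner tensors via __tensor_flatten__
    flatten = getattr(type(t), "__tensor_flatten__", None)
    if flatten is None:
        return False
    inner_attr_names, _ctx = t.__tensor_flatten__()
    return any(contains_float8(getattr(t, name)) for name in inner_attr_names)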

@drisspg (Contributor) left a comment:

Awesome! I love the ratio of red lines to green!

I think the fixture in conftest.py for xfailing the grad hooks can also be removed

Thanks!

@facebook-github-bot

@wanchaol has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot

@wanchaol merged this pull request in bfc60fb.

"""
Create an nn.Linear with fp8 compute from a regular nn.Linear

Args:
mod (torch.nn.Linear): nn.Linear to convert
emulate (bool): whether to emulate fp8 matmul logic in float32
use_activation_hooks (bool): whether to use activation hooks instead of inlining the casting logic
cast_activation (bool): whether to use activation hooks instead of inlining the casting logic
Contributor commented:
nit: This looks like it should be removed.

Contributor commented:
I think this was removed in the next PR.
