This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[wip] add option to do activation/grad cast from hooks #170

Open
vkuzo wants to merge 1 commit into main from 20231226_add_hooks

Conversation

@vkuzo (Contributor) commented Dec 26, 2023

Summary:

Testing moving the activation casting logic into hooks, so that we can start building towards composability of Float8 with DTensor.

Note: needs pytorch/pytorch#116454 to land to enable the backward pre-hook in all cases.

Current status:

  1. test_base.py works for eager mode
  2. test_compile.py fails with https://gist.github.com/vkuzo/3bfa1fcbb7b3c0ee0186eae2944eb75e; we need help from the dynamo team to debug this
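
For context, a rough sketch of what "casting activations/grads from hooks" means here. The hook names and bodies below are illustrative placeholders under assumed signatures, not this PR's code; the only grounded detail is that the forward hook path relies on `NoopFwToFloat8E5M2Bw` (see the review thread below).

```python
import torch
import torch.nn as nn

def cast_activation_pre_hook(module, args):
    # Forward pre-hook: this is where the incoming activation would be cast
    # before forward runs. A real implementation would produce the repo's
    # scaled Float8Tensor; the bfloat16 cast below is only a placeholder.
    (x,) = args
    return (x.to(torch.bfloat16),)

def cast_grad_fwd_hook(module, args, output):
    # Forward hook: this is where the output would be wrapped so that the
    # incoming gradient is cast to float8_e5m2 during backward (the PR uses
    # NoopFwToFloat8E5M2Bw for that). The sketch returns output unchanged.
    return output

m = nn.Linear(32, 32, dtype=torch.bfloat16)
m.register_forward_pre_hook(cast_activation_pre_hook)
m.register_forward_hook(cast_grad_fwd_hook)
```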


@facebook-github-bot added the CLA Signed label Dec 26, 2023
test/test_bw_hook.py (outdated review thread, resolved)
@vkuzo force-pushed the 20231226_add_hooks branch from 3fe1055 to 855795c, December 26, 2023 21:46
@vkuzo changed the title from "[wip] hooks" to "[wip] add option to do activation/grad cast from hooks" Dec 27, 2023
"""
Hook to cast the incoming gradient to `torch.float8_e5m2`
"""
new_output = NoopFwToFloat8E5M2Bw.apply(output, module.emulate)
A reviewer (Contributor) commented:

I'm a bit surprised that torch.compile does not support this case with an autograd function, as both the DTensor pre-forward hook and forward hook use autograd functions; that probably works because DTensor uses something like allow_in_graph in dynamo. Maybe one way to work around this is to use a function that calls NoopFwToFloat8E5M2Bw.apply(output, module.emulate) inside, and then mark that function allow_in_graph in dynamo?
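
A minimal sketch of that suggested workaround, assuming `NoopFwToFloat8E5M2Bw` is importable from this repo's float8 modules; the wrapper name and import path are hypothetical:

```python
import torch._dynamo as dynamo
# from float8_experimental.float8_linear import NoopFwToFloat8E5M2Bw  # assumed import path

def _cast_grad_to_e5m2(output, emulate):
    # Plain-function wrapper around the autograd.Function call.
    return NoopFwToFloat8E5M2Bw.apply(output, emulate)

# Mark the wrapper allow_in_graph so dynamo keeps it as a single call
# instead of tracing into the autograd.Function.
_cast_grad_to_e5m2 = dynamo.allow_in_graph(_cast_grad_to_e5m2)

# The forward hook would then call:
#   new_output = _cast_grad_to_e5m2(output, module.emulate)
```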

@vkuzo (Author) commented Dec 27, 2023

Note: the backward pre-hook works with pytorch/pytorch#116454, but we still see the same dynamo error (and we are no longer using torch.autograd.Function).
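
A hedged sketch of what the backward pre-hook variant looks like once pytorch/pytorch#116454 is available; this is not this PR's actual hook body:

```python
import torch.nn as nn

def grad_cast_backward_pre_hook(module: nn.Module, grad_output):
    # Runs before the module's backward. A real implementation would convert
    # grad_output to the scaled float8_e5m2 representation here; this sketch
    # passes the gradients through unchanged.
    return grad_output

m = nn.Linear(8, 8)
m.register_full_backward_pre_hook(grad_cast_backward_pre_hook)
```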

@vkuzo force-pushed the 20231226_add_hooks branch 2 times, most recently from 270ce64 to f00cac9, December 27, 2023 17:53
@vkuzo (Author) commented Dec 27, 2023

@wanchaol, it looks like the current dynamo issue is a lack of support for Float8Tensor at a subgraph boundary. Logs with evidence supporting this: https://gist.github.com/vkuzo/3cf42b6a54a68be1a45e8e1ca07eeb26. #166 has a more minimal repro.

Would you expect allow_in_graph to be related?
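
For readers unfamiliar with the term, a generic illustration of a value crossing a subgraph boundary under torch.compile; this is not the repro from #166, and the reported failure is specifically when the crossing value is a Float8Tensor rather than a plain tensor:

```python
import torch
import torch._dynamo

@torch.compile(backend="eager")
def f(x):
    y = x * 2                    # traced into the first subgraph
    torch._dynamo.graph_break()  # forces a subgraph boundary
    return y + 1                 # y crosses the boundary into the second subgraph

f(torch.randn(4))
```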

@vkuzo force-pushed the 20231226_add_hooks branch from f00cac9 to f83bf21, January 13, 2024 00:47
@vkuzo mentioned this pull request Jan 16, 2024
facebook-github-bot pushed a commit that referenced this pull request Jan 31, 2024
Summary:
This is a duplicate of #170.
With more testing, ideally we would not have to choose between hooks and modified forwards and would just use hooks; however, compile does not appear to support this yet.

Pull Request resolved: #198

Reviewed By: wanchaol

Differential Revision: D53287660

Pulled By: drisspg

fbshipit-source-id: 727e43e8850f3a480ba87df80c0710516ef45f28