[PyTorch] Debug CUDA graph support with operation-based API #1117
Signed-off-by: Tim Moon <[email protected]>
for more information, see https://pre-commit.ci
/te-ci pytorch
transformer_engine/pytorch/ops/op.py
if fp8_recipe is None:
    fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_recipe is None:
    fp8_recipe = get_default_fp8_recipe()
Hmmmm, this second if looks like logic that should be inside get_fp8_recipe in the FP8GlobalStateManager.
Also, since this is an internal function, couldn't we just always ask for a valid recipe here and just deal with getting it in the caller?
This case shouldn't happen in any of our current use-cases (FP8GlobalStateManager.get_fp8_recipe() is set within fp8_autocast, fp8_recipe is provided within make_graphed_callables), but it seems delicate to rely on that assumption.
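The fallback order discussed in this thread can be sketched as follows. This is a minimal, self-contained illustration and not Transformer Engine code: `Recipe`, `GlobalState`, `default_recipe`, and `resolve_recipe` are hypothetical stand-ins, and only the precedence (explicit argument, then global state, then default) is taken from the snippet under review.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Recipe:
    """Hypothetical stand-in for an FP8 recipe object."""

    name: str


class GlobalState:
    """Hypothetical stand-in for a global FP8 state manager."""

    _recipe: Optional[Recipe] = None

    @classmethod
    def set_recipe(cls, recipe: Optional[Recipe]) -> None:
        cls._recipe = recipe

    @classmethod
    def get_recipe(cls) -> Optional[Recipe]:
        return cls._recipe


def default_recipe() -> Recipe:
    """Hypothetical default used when nothing else is configured."""
    return Recipe(name="default")


def resolve_recipe(recipe: Optional[Recipe] = None) -> Recipe:
    """Pick a recipe: explicit argument, then global state, then default."""
    if recipe is None:
        recipe = GlobalState.get_recipe()
    if recipe is None:
        recipe = default_recipe()
    return recipe
```

With this shape, the second check could equally live inside the state manager's getter, which is what the reviewer suggests; the caller then always receives a valid recipe.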
if curr_len == amax_history_len:
    continue

# Reallocate amax history
Could this be its own function?
I've tried to keep this logic similar to how it's handled in the modules:
def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
I think it would be nice to consolidate this logic in fp8.py and reuse it for both modules and operations, but that's probably best done in a pure refactor PR.
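As a rough illustration of what a shared helper might look like, here is a simplified sketch that resizes a history buffer by truncating or zero-padding at the end. It is an assumption-laden stand-in: it uses plain Python lists instead of tensors, and `resize_amax_history` is a hypothetical name, not an existing function in fp8.py. The zero-padding at the end mirrors the `pad=(0, 0, 0, amax_history_len - curr_len)` call visible in the diff.

```python
from typing import List

# Rows are history steps, columns are per-tensor amax values.
History = List[List[float]]


def resize_amax_history(history: History, length: int) -> History:
    """Return `history` unchanged if it already has `length` rows,
    otherwise return a truncated or zero-padded copy."""
    if length < 0:
        raise ValueError("length must be non-negative")
    curr_len = len(history)
    if curr_len == length:
        # Nothing to do, so keep the same object.
        return history
    if curr_len > length:
        # Keep the first `length` rows.
        return [row[:] for row in history[:length]]
    # Append zero rows so the buffer reaches the requested length.
    num_cols = len(history[0]) if history else 0
    padding = [[0.0] * num_cols for _ in range(length - curr_len)]
    return [row[:] for row in history] + padding
```

Both the module-based and operation-based code paths could call one such function, which is the consolidation suggested above.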
@@ -260,6 +275,21 @@ def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None:
        pad=(0, 0, 0, amax_history_len - curr_len),
    )

    # Update global buffers for amax reductions
This does not look like a graph-specific thing - was the lack of this in the previous code a bug?
Yep, if the amax history length changes then I don't expect amax reductions to be handled correctly.
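One plausible way to picture the failure mode, assuming the global buffers hold references to each history buffer: reallocating produces a new object, so a registry that is not refreshed keeps pointing at the old one. The sketch below uses plain lists, and `GLOBAL_AMAX_BUFFERS`, `register`, and `reallocate` are hypothetical names chosen for illustration only.

```python
from typing import Dict, List

# Hypothetical global registry: buffer key -> amax history used in reductions.
GLOBAL_AMAX_BUFFERS: Dict[str, List[float]] = {}


def register(key: str, history: List[float]) -> None:
    """Record the buffer that reductions should operate on."""
    GLOBAL_AMAX_BUFFERS[key] = history


def reallocate(
    key: str, history: List[float], length: int, update_registry: bool
) -> List[float]:
    """Build a new buffer of `length` entries (truncate or zero-pad).

    The result is always a new list, so the registry entry goes stale
    unless it is explicitly refreshed.
    """
    resized = (history + [0.0] * length)[:length]
    if update_registry:
        GLOBAL_AMAX_BUFFERS[key] = resized
    return resized


old = [0.5, 0.25]
register("fwd", old)

# Without the refresh, the registry still refers to the old buffer.
stale = reallocate("fwd", old, 4, update_registry=False)
print(GLOBAL_AMAX_BUFFERS["fwd"] is old)

# With the refresh, reductions would see the resized buffer.
fresh = reallocate("fwd", old, 4, update_registry=True)
print(GLOBAL_AMAX_BUFFERS["fwd"] is fresh)
```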
Return default recipe from FP8GlobalStateManager.get_fp8_recipe if needed. Expand error message when failing to load FP8 state after capturing CUDA graph. Signed-off-by: Tim Moon <[email protected]>
* Debug CUDA graph support with operation-based API
  Signed-off-by: Tim Moon <[email protected]>
* Refactoring CUDA graph tests
  Signed-off-by: Tim Moon <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Review suggestions from @ptrendx
  Return default recipe from FP8GlobalStateManager.get_fp8_recipe if needed. Expand error message when failing to load FP8 state after capturing CUDA graph.
  Signed-off-by: Tim Moon <[email protected]>
* Avoid unnecessary recursion when saving/loading FP8 state
  Signed-off-by: Tim Moon <[email protected]>
* Fix circular import
  Signed-off-by: Tim Moon <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
---------
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Description
This PR debugs CUDA graph support with the operation-based API (see #707). The CUDA graph logic is similar to that of the module-based API.