
Add torch_xla_graph_execution_check_level flag (default: disabled) that emits a warning (1) or throws an error (2) during tensor sync and outputs the Python frame #9057


Open · wants to merge 4 commits into master

Conversation


@aws-yyjau commented Apr 29, 2025

Revision from #9050

This PR introduces a new configuration flag, torch_xla_graph_execution_check_level, to provide better visibility into tensor synchronization operations during XLA graph execution. The AWS Neuron team will use this flag during HLO conversion to help developers catch cases where an input tensor's value is evaluated during compilation.

Key changes:

  • Added new configuration flag torch_xla_graph_execution_check_level (default: disabled).
  • Implemented warning/error logging during tensor synchronization events
  • Added Python stack trace output for debugging tensor sync operations
  • Log messages include relevant context, such as the shape of the tensor being synced
  • Throw an error when the level is set to 2 (ERROR)

The checking levels supported are:

  • DISABLED (default): No logging
  • WARNING (value: 1): Check and log tensor sync operations as warnings
  • ERROR (value: 2): Check and log tensor sync operations as warnings, then throw an XLA error

Example usage:

import torch_xla
torch_xla._XLAC._set_torch_xla_graph_execution_check_level(1)
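
The snippet below extends this usage into a minimal, hedged sketch of what actually triggers the check. It assumes only the binding shown above, plus the standard torch_xla behavior that reading a tensor's value (e.g. via .item()) forces the pending lazy graph to execute:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# WARNING level (1): tensor syncs are logged together with the Python frame.
torch_xla._XLAC._set_torch_xla_graph_execution_check_level(1)

t = torch.ones(4, device=xm.xla_device()) * 2
# Reading the value forces the pending lazy graph to execute (a tensor sync),
# which is the event this flag is meant to surface.
print(t[0].item())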

This enhancement helps developers:

The team has seen issues when a tensor's value is used to drive an if-else statement. For example,

def forward(self, tensor):
    if tensor[0] == 1:
        return tensor
    else:
        return tensor * 2

The example above can compile and run. However, it may lead developers to believe that tensors can be evaluated on the fly during tracing, resulting in unexpected behavior.
With this change, and further changes to graph tracing in the future, we can:

  1. Identify potential code-path issues in XLA graph execution and prevent users from relying on tensor values inside the graph (see the sketch after this list)
  2. Debug and trace tensor synchronization issues more effectively
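
The sketch below applies the check to the if-else pattern above. It is illustrative only: the exact Python exception type raised at level 2 is an assumption, since the PR only states that an XLA error is thrown.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# ERROR level (2): a data-dependent branch should fail loudly during tracing
# instead of silently forcing a graph execution.
torch_xla._XLAC._set_torch_xla_graph_execution_check_level(2)

def forward(tensor):
    # Evaluating tensor[0] == 1 requires the tensor's value, i.e. a tensor sync.
    if tensor[0] == 1:
        return tensor
    else:
        return tensor * 2

try:
    forward(torch.ones(4, device=xm.xla_device()))
except Exception as e:  # exception type assumed; the PR describes it as an XLA error
    print("tensor sync during tracing rejected:", e)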

Testing:

  • Added unit tests for different check levels (a sketch of what such a test might look like is shown after this list)
  • Verified the log output format and stack trace information
  • Tested tensor sync scenarios
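
A hedged sketch of what such a unit test could look like; the test class, the use of 0 to mean DISABLED, and the expected exception type are assumptions, not the PR's actual test code:

import unittest
import torch
import torch_xla
import torch_xla.core.xla_model as xm


class GraphExecutionCheckLevelTest(unittest.TestCase):

    def tearDown(self):
        # Reset the flag (0 assumed to be DISABLED) so other tests are unaffected.
        torch_xla._XLAC._set_torch_xla_graph_execution_check_level(0)

    def test_error_level_raises_on_tensor_sync(self):
        torch_xla._XLAC._set_torch_xla_graph_execution_check_level(2)
        t = torch.ones(2, device=xm.xla_device()) + 1
        # Reading the value forces a graph execution, which should be rejected at level 2.
        with self.assertRaises(Exception):
            t[0].item()


if __name__ == "__main__":
    unittest.main()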

Documentation (To do):

  • Add usage examples and checking-level descriptions
  • Include a troubleshooting guide for common tensor sync issues

root and others added 2 commits April 29, 2025 19:20
…t emits warning(1) or throw error(2) during tensor sync and output the python frame
@tengyifei (Collaborator) left a comment


Hi, thanks for the contribution!

What is the relationship of this feature with PT_XLA_DEBUG_LEVEL c.f. https://github.com/pytorch/xla/blob/master/docs/source/learn/troubleshoot.md#pytorchxla-debugging-tool ?

Would it be possible to improve the existing feature and/or maybe introduce a Pythonic API for it?

@aws-yyjau (Author) commented

Hi, thanks for the contribution!

What is the relationship of this feature with PT_XLA_DEBUG_LEVEL c.f. https://github.com/pytorch/xla/blob/master/docs/source/learn/troubleshoot.md#pytorchxla-debugging-tool ?

It is different from PT_XLA_DEBUG_LEVEL: this new flag is meant to be enabled during model tracing to catch unexpected code paths, whereas PT_XLA_DEBUG_LEVEL is an XLA-level debugging setting intended for general debugging purposes.

Would it be possible to improve the existing feature and/or maybe introduce a Pythonic API for it?

If there are suggestions on where this feature could be merged into existing flags/features, please let me know. Thanks.
