Commit 0da894f
Log global state (pytorch#170070)
Fixes pytorch#166268
### Problem
Guard logging currently prints only the check function call for the GLOBAL_STATE guard. From the logs alone, it is unclear which global state properties the check function has captured to compare against.
### Solution
The existing `verbose_code` contained only the string shown in the log. Added a call to the guard manager's `__getstate__` to return a JSON-structured string of the captured global state. The result is the preserved check function signature in the log, followed by the global state used in all of the comparison checks.
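The shape of the idea, as a minimal standalone sketch (the helper name and the exact set of fields are illustrative assumptions, not the actual Dynamo internals):
```python
import json

import torch

def snapshot_global_state() -> str:
    # Illustrative only: collect a few of the global-state properties the
    # GLOBAL_STATE guard compares against, using public torch APIs.
    state = {
        "grad_mode": torch.is_grad_enabled(),
        "default_dtype": str(torch.get_default_dtype()),
        "num_threads": torch.get_num_threads(),
        "allow_tf32": torch.backends.cudnn.allow_tf32,
        "deterministic_algorithms": torch.are_deterministic_algorithms_enabled(),
    }
    # Sorted keys and compact separators give a stable, log-friendly string.
    return json.dumps(state, sort_keys=True, separators=(",", ":"))

# Mirrors the shape of the new log line:
print(f"GLOBAL_STATE: ___check_global_state() against {snapshot_global_state()}")
```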
### Testing
Added a logging check to verify that the global state attributes are logged with munged keys.
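As a sketch of what such a check verifies (the helper below is hypothetical; the real test asserts against the captured guard log):
```python
import json

def assert_global_state_logged(line: str) -> None:
    # Hypothetical helper: given a logged GLOBAL_STATE guard line of the form
    # "GLOBAL_STATE: ___check_global_state() against {...}", verify the payload
    # parses as JSON and carries the expected munged keys.
    _, sep, payload = line.partition(" against ")
    assert sep, "GLOBAL_STATE line is missing the global state payload"
    state = json.loads(payload)
    for key in ("grad_mode", "default_dtype", "num_threads"):
        assert key in state, f"missing munged key: {key}"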
### Comparison
```python
import torch

@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3, 3))
```
Previously the above code produced:
```
TREE_GUARD_MANAGER:
+- RootGuardManager
| +- LAMBDA_GUARD: torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None # _dynamo/output_graph.py:807 in init_ambient_guards
| +- GLOBAL_STATE: ___check_global_state()
| +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack()
| +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:794 in init_ambient_guards
| +- GuardManager: source=L['x'], accessed_by=FrameLocalsGuardAccessor(key='x', framelocals_idx=0), type=<class 'torch.Tensor'>, tag_safe=(True, False)
| | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3, 3], stride=[3, 1]) # return x + 1 # repro_global_state.py:5 in fn
| | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False # return x + 1 # repro_global_state.py:5 in fn
```
With the proposed change it produces:
```
TREE_GUARD_MANAGER:
+- RootGuardManager
| +- LAMBDA_GUARD: torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None # _dynamo/output_graph.py:807 in init_ambient_guards
| +- GLOBAL_STATE: ___check_global_state() against {"allow_bf16_reduce":0,"allow_fp16_reduce":0,"allow_tf32":false,"autocast_state":{"cached_enabled":true,"dtype":[15,5,5,15,5,5,15,15,5,5],"enabled":[false,false,false,false,false,false,false,false,false,false]},"default_dtype":6,"deterministic_algorithms":false,"deterministic_algorithms_warn_only":false,"grad_mode":true,"num_threads":16,"torch_function":true,"torch_function_all_disabled":false}
| +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack()
| +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:794 in init_ambient_guards
| +- GuardManager: source=L['x'], accessed_by=FrameLocalsGuardAccessor(key='x', framelocals_idx=0), type=<class 'torch.Tensor'>, tag_safe=(True, False)
| | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3, 3], stride=[3, 1]) # return x + 1 # repro_global_state.py:5 in fn
| | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False # return x + 1 # repro_global_state.py:5 in fn
```
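The appended JSON makes explicit which state every guard check compares against: grad mode, autocast state, default dtype, thread count, TF32 and reduced-precision reduction flags, deterministic-algorithm settings, and the torch-function state.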
Pull Request resolved: pytorch#170070
Approved by: https://github.com/Lucaskabela, https://github.com/mlazos
2 files changed (+34, −3 lines)