
Commit 0da894f

trichmo authored and pytorchmergebot committed
Log global state (pytorch#170070)
Fixes pytorch#166268

### Problem

The guard logging currently prints only the check function for the GLOBAL_STATE guard. From the logs it is unclear which global state properties the check function has captured to compare against.

### Solution

The current verbose_code contained only the string shown in the log. This change adds a call to the global state guard's `__getstate__` to return a JSON-structured string of the captured global state. The result is the logged check-function call signature (preserved) plus the additional global state used in all comparison checks.

### Testing

Added a logging check to verify that the global state attributes are logged with munged keys.

### Comparative Change

```python
import torch

@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3, 3))
```

Previously the above code produced:

> TREE_GUARD_MANAGER:
> +- RootGuardManager
> | +- LAMBDA_GUARD: torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None # _dynamo/output_graph.py:807 in init_ambient_guards
> | +- GLOBAL_STATE: ___check_global_state()
> | +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack()
> | +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:794 in init_ambient_guards
> | +- GuardManager: source=L['x'], accessed_by=FrameLocalsGuardAccessor(key='x', framelocals_idx=0), type=<class 'torch.Tensor'>, tag_safe=(True, False)
> | | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3, 3], stride=[3, 1]) # return x + 1 # repro_global_state.py:5 in fn
> | | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False # return x + 1 # repro_global_state.py:5 in fn

With the proposed change it produces:

> TREE_GUARD_MANAGER:
> +- RootGuardManager
> | +- LAMBDA_GUARD: torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None # _dynamo/output_graph.py:807 in init_ambient_guards
> | +- GLOBAL_STATE: ___check_global_state() against {"allow_bf16_reduce":0,"allow_fp16_reduce":0,"allow_tf32":false,"autocast_state":{"cached_enabled":true,"dtype":[15,5,5,15,5,5,15,15,5,5],"enabled":[false,false,false,false,false,false,false,false,false,false]},"default_dtype":6,"deterministic_algorithms":false,"deterministic_algorithms_warn_only":false,"grad_mode":true,"num_threads":16,"torch_function":true,"torch_function_all_disabled":false}
> | +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack()
> | +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:794 in init_ambient_guards
> | +- GuardManager: source=L['x'], accessed_by=FrameLocalsGuardAccessor(key='x', framelocals_idx=0), type=<class 'torch.Tensor'>, tag_safe=(True, False)
> | | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3, 3], stride=[3, 1]) # return x + 1 # repro_global_state.py:5 in fn
> | | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False # return x + 1 # repro_global_state.py:5 in fn

Pull Request resolved: pytorch#170070
Approved by: https://github.com/Lucaskabela, https://github.com/mlazos
1 parent d9f0090 commit 0da894f
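
To see the new GLOBAL_STATE line locally, guard logging has to be enabled. Below is a minimal repro sketch, assuming `torch._logging.set_logs(guards=True)` is available (the same effect as running with `TORCH_LOGS="guards"`); the exact JSON contents depend on the local configuration:

```python
# Hedged sketch: enable guard logging so the TREE_GUARD_MANAGER dump
# (including the GLOBAL_STATE line) is printed when the function is compiled.
import torch
import torch._logging

torch._logging.set_logs(guards=True)  # same effect as TORCH_LOGS="guards"


@torch.compile
def fn(x):
    return x + 1


fn(torch.ones(3, 3))
# Expected: a TREE_GUARD_MANAGER dump whose GLOBAL_STATE entry now reads
# ___check_global_state() against {"allow_tf32":false, ..., "grad_mode":true, ...}
```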

File tree

2 files changed: +34 −3 lines changed

test/dynamo/test_logging.py

Lines changed: 29 additions & 0 deletions
@@ -68,6 +68,21 @@ def munge(s):
     return "\n".join([line for line, nsubs in lines if nsubs > 0])
 
 
+def munge_global_state_json(text):
+    import re
+
+    match = re.search(r"\+- GLOBAL_STATE:.*", text)
+    if not match:
+        return ""
+
+    line = match.group(0)
+    while "[" in line:
+        line = re.sub(r"\[[^\[\]]*\]", '"#"', line)
+
+    line = re.sub(r':\s*(\d+|true|false|"[^"]*")', r': "#"', line)
+    return line
+
+
 LOG_PREFIX_PATTERNS = [
     re.compile(r"^\[rank\d+\]:\s*"),
     re.compile(r"^[A-Z]+:[^:]+:\s*"),
@@ -821,6 +836,20 @@ def f(x, y):
 +- __SHAPE_GUARD__: 3 <= L['y'].size()[0] <= 14 # torch._check(x.size(0) > 5) # #:# in # #:# in # and torch._check(x.size(0) < 30) # #:# in # #:# in #""",  # noqa: B950
         )
 
+    @make_logging_test(guards=True)
+    def test_global_state_guard_logging(self, records):
+        @torch.compile(backend="eager")
+        def f(x):
+            return x + 1
+
+        f(torch.randn(3))
+
+        record = self.getRecord(records, "TREE_GUARD_MANAGER")
+        self.assertExpectedInline(
+            munge_global_state_json(record.getMessage()),
+            """+- GLOBAL_STATE: ___check_global_state() against {"allow_bf16_reduce": "#","allow_fp16_reduce": "#","allow_tf32": "#","autocast_state":{"cached_enabled": "#","dtype": "#","enabled": "#"},"default_dtype": "#","deterministic_algorithms": "#","deterministic_algorithms_warn_only": "#","grad_mode": "#","num_threads": "#","torch_function": "#","torch_function_all_disabled": "#"}""",  # noqa: B950
+        )
+
     @make_logging_test(cudagraph_static_inputs=True)
     def test_cudagraph_static_inputs(self, records):
         @torch.compile(mode="reduce-overhead")
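
For reference, here is a small standalone sketch of what the new `munge_global_state_json` helper does to a logged GLOBAL_STATE line; the sample input below is illustrative and much shorter than a real guard log:

```python
import re


def munge_global_state_json(text):
    # Same logic as the helper added above: find the GLOBAL_STATE line and
    # replace all JSON values (numbers, booleans, strings, lists) with "#".
    match = re.search(r"\+- GLOBAL_STATE:.*", text)
    if not match:
        return ""

    line = match.group(0)
    while "[" in line:
        line = re.sub(r"\[[^\[\]]*\]", '"#"', line)

    line = re.sub(r':\s*(\d+|true|false|"[^"]*")', r': "#"', line)
    return line


# Illustrative input (trimmed); a real log line carries the full global state JSON.
sample = '| +- GLOBAL_STATE: ___check_global_state() against {"allow_tf32":false,"dtype":[15,5],"num_threads":16}'
print(munge_global_state_json(sample))
# +- GLOBAL_STATE: ___check_global_state() against {"allow_tf32": "#","dtype": "#","num_threads": "#"}
```

Munging the values keeps the expected-inline test stable across machines: key names are asserted, while machine-dependent values such as `num_threads` are reduced to placeholders.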

torch/_dynamo/guards.py

Lines changed: 5 additions & 3 deletions
@@ -2553,9 +2553,11 @@ def GLOBAL_STATE(self, guard: Guard) -> None:
         assert output_graph is not None
         global_state = output_graph.global_state_guard
         self.check_fn_manager.global_state = global_state
-        self.guard_manager.root.add_global_state_guard(
-            global_state, ["___check_global_state()"]
-        )
+        code = [
+            f"___check_global_state() against {self.check_fn_manager.global_state.__getstate__()}"
+        ]
+
+        self.guard_manager.root.add_global_state_guard(global_state, code)
 
     def TORCH_FUNCTION_STATE(self, guard: Guard) -> None:
         assert self.check_fn_manager.torch_function_mode_stack is not None
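
The verbose code above embeds whatever the global state guard's `__getstate__` returns. A minimal sketch of inspecting that string directly follows, assuming `GlobalStateGuard` is importable from `torch._C._dynamo.guards` with a default constructor and a JSON-string `__getstate__`, as the commit message output suggests:

```python
# Hedged sketch: peek at the captured global state the same way the new
# verbose guard code does. GlobalStateGuard's import path, constructor, and
# JSON-returning __getstate__ are assumptions based on the diff above.
import json

from torch._C._dynamo.guards import GlobalStateGuard

state = GlobalStateGuard()   # snapshots grad mode, autocast, TF32 flags, ...
raw = state.__getstate__()   # the JSON string embedded in the guard log
print(json.loads(raw)["grad_mode"])  # e.g. True when grad is enabled
```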
