diff --git a/graph_net/torch/backend/unstable_to_stable_backend.py b/graph_net/torch/backend/unstable_to_stable_backend.py index 49cadfb3d..7bbf588fd 100644 --- a/graph_net/torch/backend/unstable_to_stable_backend.py +++ b/graph_net/torch/backend/unstable_to_stable_backend.py @@ -126,6 +126,27 @@ def _impl_unstable_to_stable_fftn(self, gm): return gm + def _impl_unstable_to_stable_special_logit(self, gm): + """ + Convert torch._C._special.special_logit to torch.special.logit + """ + issue_nodes = ( + node + for node in gm.graph.nodes + if node.op == "call_function" + if hasattr(node.target, "__module__") + if node.target.__module__ == "torch._C._special" + if hasattr(node.target, "__name__") + if node.target.__name__ == "special_logit" + ) + for node in issue_nodes: + node.target = torch.special.logit + + # Recompile the graph + gm.recompile() + + return gm + def unstable_to_stable(self, gm): methods = ( name diff --git a/graph_net/torch/fx_graph_serialize_util.py b/graph_net/torch/fx_graph_serialize_util.py index f89716d53..0814eef93 100644 --- a/graph_net/torch/fx_graph_serialize_util.py +++ b/graph_net/torch/fx_graph_serialize_util.py @@ -24,6 +24,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str: (r"torch\._C\._fft\.fft_irfft\(", "torch.fft.irfft("), (r"torch\._C\._fft\.fft_rfft\(", "torch.fft.rfft("), (r"torch\._C\._fft\.fft_fftn\(", "torch.fft.fftn("), + (r"torch\._C\._special\.special_logit\(", "torch.special.logit("), # Add new rules to this list as needed ] for pattern, repl in replacements: