Skip to content

Commit c2280ca

Browse files
authored
【Hackathon 9th No.120】Convert torch._C._special.special_logit to torch.special.logit (#336)
* Convert torch._C._special.special_logit to torch.special.logit * update code
1 parent 526bca3 commit c2280ca

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,27 @@ def _impl_unstable_to_stable_fftn(self, gm):
126126

127127
return gm
128128

129+
def _impl_unstable_to_stable_special_logit(self, gm):
130+
"""
131+
Convert torch._C._special.special_logit to torch.special.logit
132+
"""
133+
issue_nodes = (
134+
node
135+
for node in gm.graph.nodes
136+
if node.op == "call_function"
137+
if hasattr(node.target, "__module__")
138+
if node.target.__module__ == "torch._C._special"
139+
if hasattr(node.target, "__name__")
140+
if node.target.__name__ == "special_logit"
141+
)
142+
for node in issue_nodes:
143+
node.target = torch.special.logit
144+
145+
# Recompile the graph
146+
gm.recompile()
147+
148+
return gm
149+
129150
def unstable_to_stable(self, gm):
130151
methods = (
131152
name

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
2424
(r"torch\._C\._fft\.fft_irfft\(", "torch.fft.irfft("),
2525
(r"torch\._C\._fft\.fft_rfft\(", "torch.fft.rfft("),
2626
(r"torch\._C\._fft\.fft_fftn\(", "torch.fft.fftn("),
27+
(r"torch\._C\._special\.special_logit\(", "torch.special.logit("),
2728
# Add new rules to this list as needed
2829
]
2930
for pattern, repl in replacements:

0 commit comments

Comments
 (0)