Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hacked up #88

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions torchfix/visitors/deprecated_symbols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .chain_matmul import call_replacement_chain_matmul
from .cholesky import call_replacement_cholesky
from .qr import call_replacement_qr
from .size_average import call_replacement_loss

from .range import call_replacement_range

Expand Down Expand Up @@ -54,6 +55,7 @@ def _call_replacement(
"torch.qr": call_replacement_qr,
"torch.cuda.amp.autocast": call_replacement_cuda_amp_autocast,
"torch.cpu.amp.autocast": call_replacement_cpu_amp_autocast,
"torch.nn.functional.soft_margin_loss": call_replacement_loss
}
replacement = None

Expand Down Expand Up @@ -103,7 +105,8 @@ def visit_Call(self, node) -> None:
qualified_name = self.get_qualified_name_for_call(node)
if qualified_name is None:
return

self.deprecated_config["torch.nn.functional.soft_margin_loss"] = {}
self.deprecated_config["torch.nn.functional.soft_margin_loss"]["remove_pr"] = None
if qualified_name in self.deprecated_config:
if self.deprecated_config[qualified_name]["remove_pr"] is None:
error_code = self.ERRORS[1].error_code
Expand All @@ -112,7 +115,6 @@ def visit_Call(self, node) -> None:
error_code = self.ERRORS[0].error_code
message = self.ERRORS[0].message(old_name=qualified_name)
replacement = self._call_replacement(node, qualified_name)

reference = self.deprecated_config[qualified_name].get("reference")
if reference is not None:
message = f"{message}: {reference}"
Expand Down
60 changes: 60 additions & 0 deletions torchfix/visitors/deprecated_symbols/size_average.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""size_average and reduce are deprecated, please use reduction='mean' instead."""

import libcst as cst
from ...common import TorchVisitor, get_module_name
from torch.nn._reduction import legacy_get_string
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now TorchFix doesn't depend on or use PyTorch.
Maybe we can add PyTorch to CI, but then which version?
Is it hard to copy or re-implement legacy_get_string?


def call_replacement_loss(node: cst.Call) -> cst.CSTNode:
"""
Replace loss function that contains size_average / reduce with a new loss function
that uses reduction='mean' instead. Uses the logic from torch.nn._reduction to
determine the correct reduction value.

Args:
node: The CST Call node representing the loss function call

Returns:
A new CST node with updated reduction parameter
"""
# Extract existing arguments
input_arg = TorchVisitor.get_specific_arg(node, "input", 0)
target_arg = TorchVisitor.get_specific_arg(node, "target", 1)

size_average_arg = TorchVisitor.get_specific_arg(node, "size_average", 2)
reduce_arg = TorchVisitor.get_specific_arg(node, "reduce", 3)

# Ensure input and target args maintain their commas
input_arg = cst.ensure_type(input_arg, cst.Arg).with_changes(
comma=cst.MaybeSentinel.DEFAULT
)

target_arg = cst.ensure_type(target_arg, cst.Arg).with_changes(
comma=cst.MaybeSentinel.DEFAULT
)

# Extract size_average and reduce values
size_average_value = None
reduce_value = None

if size_average_arg:
size_average_value = getattr(size_average_arg.value, "value", True)
if reduce_arg:
reduce_value = getattr(reduce_arg.value, "value", True)

if size_average_value is None and reduce_value is None:
# We want to return the original call as is
return node
# Use legacy_get_string to determine the correct reduction value
reduction = legacy_get_string(size_average_value, reduce_value, emit_warning=False)

# Create new reduction argument
reduction_arg = cst.Arg(
value=cst.SimpleString(f"'{reduction}'"),
keyword=cst.Name("reduction"),
comma=cst.MaybeSentinel.DEFAULT,
)

# Build new arguments list
new_args = [input_arg, target_arg, reduction_arg]
replacement = node.with_changes(args=new_args)
return replacement
Loading