
Change tolerance used to decide whether a constant is one in rewrite functions #1526


Open · lciti wants to merge 6 commits into main

Conversation

@lciti (Contributor) commented Jul 7, 2025

Description

The tolerance previously used within rewrites to decide whether a constant is one (or minus one) is too large.
For example, c - sigmoid(x) is rewritten as sigmoid(-x) even when $c = 1 - p$ where p is 1 in 10000 (i.e. c differs from 1 by 1e-4).
Many rewrites currently use np.isclose and np.allclose with the default tolerances (rtol=1e-05, atol=1e-08), which are unnecessarily large (and independent of the data type of the constant being compared).

This PR implements a function isclose, used within all rewrites in place of np.isclose and np.allclose. This new function uses a much smaller tolerance by default, namely 10 units in the last place (ULPs). This tolerance is dtype-dependent, so it is stricter for a float64 than for a float32. See #1497 for a back-of-the-envelope justification for choosing 10 ULPs.
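
As an illustration of the idea, here is a minimal ULP-based check (a sketch only; the helper's actual signature and edge-case handling in this PR may differ):

import numpy as np

def isclose(x, ref, num_ulps=10):
    # One ULP is the spacing between adjacent representable floats at the
    # magnitude of `ref`, measured in x's dtype, so the tolerance is
    # automatically stricter for float64 than for float32.
    x = np.asarray(x)
    ulp = np.spacing(np.asarray(ref, dtype=x.dtype))
    return bool(np.all(np.abs(x - ref) <= num_ulps * np.abs(ulp)))

isclose(np.float32(1) + np.float32(1e-7), 1.0)  # True: about 1 ULP away in float32
isclose(np.float32(1 - 1e-4), 1.0)              # False: 1e-4 is far more than 10 ULPs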

This PR also implements allow_cast in PatternNodeRewriter to allow rewrites that would otherwise fail when the new and old dtypes differ. For example, a rewrite attempt for np.array(1., "float64") - sigmoid(x) (where x is an fmatrix) currently fails because the rewritten expression sigmoid(-x) would change the output type. With allow_cast, an automatic cast is added so the expression is rewritten as cast(sigmoid(-x), "float64").
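
A quick sketch of the scenario above (assuming the standard rewrite pipeline runs when compiling; variable names are illustrative):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.fmatrix("x")
out = np.array(1.0, "float64") - pt.sigmoid(x)  # output dtype is float64

# With allow_cast, the rewritten graph is expected to be equivalent to
# pt.cast(pt.sigmoid(-x), "float64"), preserving the original output dtype.
f = pytensor.function([x], out)
pytensor.dprint(f)  # inspect the graph after rewriting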

Relevant tests added.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1526.org.readthedocs.build/en/1526/

Luca Citi added 4 commits July 7, 2025 15:25
to allow rewrites that would otherwise fail when the new and old dtype differ.
Example:
`np.array(1., "float64") - sigmoid(x)` cannot be rewritten as
`sigmoid(-x)` (where x is an fmatrix) because the type would change.
This commit allows an automatic cast to be added so the expression
is rewritten as `cast(sigmoid(-x), "float64")`.
Relevant tests added.
…ion isclose, which uses 10 ULPs by default
@lciti changed the title from "Fix 1497" to "Fix 1497 - Change tolerance used to decide whether a constant is one in rewrite functions" on Jul 7, 2025
Comment on lines 1672 to 1673
if self.allow_cast and ret.owner.outputs[0].type.dtype != out_dtype:
ret = pytensor.tensor.basic.cast(ret, out_dtype)
Member:

Not all types have a dtype; we should check that it's a TensorType before even trying to access dtype and doing anything with it. I would perhaps write it like this:

The whole logic is weird though, with the if ret.owner: why do we care about the type of outputs we're not replacing? It's actually dangerous to try to replace only one of them without the user's consent. Since this is WIP, I would change it to if len(node.outputs) != 1: return False before we try to unify.

Then here we just have to worry about the final else branch below:

[old_out] = node.outputs
if not old_out.type.is_super(ret.type):
    if not (
        self.allow_cast
        and isinstance(old_out.type, TensorType)
        and isinstance(ret.type, TensorType)
    ):
        return False

    # Try to cast
    ret = ret.astype(old_out.type.dtype)
    if not old_out.type.is_super(ret.type):
        return False

Contributor (Author):

I am happy to replace it as you suggest, but I am not sure how to fit it in with the rest. This is the current code:

        if ret.owner:
            if not (
                len(node.outputs) == len(ret.owner.outputs)
                and all(
                    o.type.is_super(new_o.type)
                    for o, new_o in zip(node.outputs, ret.owner.outputs, strict=True)
                )
            ):
                return False
        else:
            # ret is just an input variable
            assert len(node.outputs) == 1
            if not node.outputs[0].type.is_super(ret.type):
                return False

@ricardoV94 (Member) commented Jul 8, 2025:

You only need what I wrote above; the template would look something like this:

def transform(...):

    ...

    if node.op != self.op:
        return False

    if len(node.outputs) != 1:
        # PatternNodeRewriter doesn't support replacing multi-output nodes
        return False

    ...

    if not self.allow_multiple_clients:
        ...

    # New logic
    [old_out] = node.outputs
    if not old_out.type.is_super(ret.type):
        # Type doesn't match
        if not (
            self.allow_cast
            and isinstance(old_out.type, TensorType)
            and isinstance(ret.type, TensorType)
        ):
            return False

        # Try to cast tensors
        ret = ret.astype(old_out.type.dtype)
        if not old_out.type.is_super(ret.type):
            # Still doesn't match
            return False

    return [ret]

Contributor (Author):

Are you sure PatternNodeRewriter is supposed to only work with single-output nodes? I get the following error:

    def test_patternsub_different_output_lengths():
        # Test that PatternNodeRewriter won't replace nodes with different numbers of outputs
        ps = PatternNodeRewriter(
            (op1, "x"),
            ("x"),
            name="ps",
        )
        rewriter = in2out(ps)
    
        x = MyVariable("x")
        e1, e2 = op_multiple_outputs(x)
        o = op1(e1)
    
        fgraph = FunctionGraph(inputs=[x], outputs=[o])
        rewriter.rewrite(fgraph)
>       assert fgraph.outputs[0].owner.op == op1
E       assert OpMultipleOutputs == op1
E        +  where OpMultipleOutputs = OpMultipleOutputs(x).op
E        +    where OpMultipleOutputs(x) = OpMultipleOutputs.0.owner

@ricardoV94 (Member) commented Jul 9, 2025:

I don't think that test makes sense. It's like saying you don't want to rewrite log(exp(x)) if x comes from a multi-output node. We usually don't care about the provenance of a root variable in a rewrite. Nothing in that rewrite cares about op_multiple_outputs.

Member:

I added that test, so let me see if I can dig up the rationale.

@ricardoV94 (Member) commented Jul 9, 2025:

It was here: https://github.com/aesara-devs/aesara/pull/803/files

The problem was that, previously, the zip would be shorter if node.outputs and the replacement didn't match in length. But the whole issue goes away if you just say it doesn't support replacing multi-output nodes, which it doesn't, really.

That test can be removed in favor of one where it refuses to replace OpMultipleOutputs.

Contributor (Author):

Thanks. It sort of makes sense to me, but I know too little of the PyTensor internals to fully understand.
Can you propose a quick way to modify/replace the test with one where it refuses to replace OpMultipleOutputs?

Member:

If you push your changes (if you haven't already), I can push the new test on top of them.

Contributor (Author):

I have pushed all my changes.

(np.array(1.0, "float32") - sigmoid(xd), sigmoid(-xd)),
(np.array([[1.0]], "float64") - sigmoid(xd), sigmoid(-xd)),
]:
f = pytensor.function([x, xd], out, m, on_unused_input="ignore")
Member:

If you are not evaluating f, just rewrite it with rewrite_graph, possibly including ("canonicalize", "stabilize", "specialize"), or whatever is needed.
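
For instance, something along these lines (a sketch; the exact import path and the set of passes to include may need adjusting):

from pytensor.graph.rewriting.utils import rewrite_graph

# Apply the rewrite passes directly to the expression instead of compiling a function
rewritten = rewrite_graph(out, include=("canonicalize", "stabilize", "specialize"))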

Comment on lines 4144 to 4149
), "Expression:\n{}rewritten as:\n{}expected:\n{}".format(
*(
pytensor.dprint(expr, print_type=True, file="str")
for expr in (out, f_outs, expected)
)
)
Member:

You can do pytensor.dprint on a tuple of Variables. If you want to print the rewritten and the expected graphs, which I often do while writing these sorts of tests, we could add an assert_equal_computations helper that does that. That way it's reusable and doesn't make each test as verbose as this.
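
For instance (using the names from the test above):

pytensor.dprint((out, f_outs, expected), print_type=True)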

Contributor (Author):

I agree assert_equal_computations might be useful (would that go into tests/unittest_tools.py?). I think it would be useful to show the original expression too; what do you think?

Contributor (Author):

Something like:

import pytensor
from pytensor.graph.basic import equal_computations


def assert_equal_computations(
    rewritten,
    expected,
    *args,
    original=None,
    **kwargs
):
    """
    Assert that `rewritten` computes the same as `expected`.
    
    Parameters
    ----------
    rewritten
        The expression after the rewrite pass.
    expected
        The reference expression to compare against.
    *args, **kwargs
        Extra arguments forwarded to equal_computations.
    original : optional
        If given, will be printed in the error message.
    """
    ok = equal_computations(rewritten, expected, *args, **kwargs)

    if not ok:
        parts = []

        def _dprint(expr):
            return pytensor.dprint(expr, print_type=True, file="str")

        if original is not None:
            parts.append(f"Original:\n{_dprint(original)}")
        parts.append(f"Rewritten:\n{_dprint(rewritten)}")
        parts.append(f"Expected:\n{_dprint(expected)}")

        raise AssertionError("\n\n".join(parts))

    return True
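
A hypothetical call in one of the rewrite tests (variable names are illustrative; equal_computations expects sequences of variables):

# `rewritten_outs` and `expected_outs` would be lists of output Variables,
# `out` the expression before rewriting.
assert_equal_computations(rewritten_outs, expected_outs, original=out)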

Member:

Looks lovely

out_dtype = node.outputs[0].type.dtype
if self.allow_cast and ret.owner.outputs[0].type.dtype != out_dtype:
ret = pytensor.tensor.basic.cast(ret, out_dtype)
if self.allow_cast:
@ricardoV94 (Member) commented Jul 8, 2025:

I was reviewing commit by commit and see you changed this later. Anyway, my original comment still stands.

Generally, feel free to squash commits and force-push when iterating on a PR so the git history stays clean.

@ricardoV94 changed the title from "Fix 1497 - Change tolerance used to decide whether a constant is one in rewrite functions" to "Change tolerance used to decide whether a constant is one in rewrite functions" on Jul 12, 2025
Development

Successfully merging this pull request may close these issues.

BUG: the tolerance used to decide whether a constant is one (or minus one) in rewrite functions may be too large