-
Notifications
You must be signed in to change notification settings - Fork 137
Change tolerance used to decide whether a constant is one in rewrite functions #1526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
to allow rewrites that would otherwise fail when the new and old dtype differ. Example: `np.array(1., "float64") - sigmoid(x)` cannot be rewritten as `sigmoid(-x)` (where x is an fmatrix) because the type would change. This commit allows an automatic cast to be added so the expression is rewritten as `cast(sigmoid(-x), "float64")`. Relevant tests added.
…tain dtype like MyType in the tests
…ion isclose, which uses 10 ULPs by default
pytensor/graph/rewriting/basic.py
Outdated
if self.allow_cast and ret.owner.outputs[0].type.dtype != out_dtype: | ||
ret = pytensor.tensor.basic.cast(ret, out_dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not all types have a dtype, we should check it's a TensorType before even trying to access dtype
and doing stuff with it. I would perhaps write like this:
The whole logic is weird though with the if ret.owner
, why do we care about the type of outputs we're not replacing. It's actually dangerous to try to replace only one of them without the user consent. Since this is WIP I would change to if len(node.outputs) != 1: return False
, before we try to unify.
Then here we just have to worry about the final else branch below:
[old_out] = node.outputs
if not old_out.type.is_super(ret.type):
if not (
self.allow_cast
and isinstance(old_out.type, TensorType)
and isinstance(ret.type, TensorType)
):
return False
# Try to cast
ret = ret.astype(old_out.type.dtype)
if not old_out.type.is_super(ret.type):
return False
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am happy to replace as you suggest but I am not sure how to fit it within the rest. This is the current code:
if ret.owner:
if not (
len(node.outputs) == len(ret.owner.outputs)
and all(
o.type.is_super(new_o.type)
for o, new_o in zip(node.outputs, ret.owner.outputs, strict=True)
)
):
return False
else:
# ret is just an input variable
assert len(node.outputs) == 1
if not node.outputs[0].type.is_super(ret.type):
return False
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you only need what I wrote, above, template something like this
def transform(...):
...
if node.op != self.op:
return False
if len(node.outputs) != 1:
# PatternNodeRewriter doesn't support replacing multi-output nodes
return False
...
if not self.allow_multiple_clients:
...
# New logic
[old_out] = node.outputs
if not old_out.type.is_super(ret.type):
# Type doesn't match
if not (
self.allow_cast
and isinstance(old_out.type, TensorType)
and isinstance(ret.type, TensorType)
):
return False
# Try to cast tensors
ret = ret.astype(old_out.type.dtype)
if not old_out.type.is_super(ret.type):
# Still doesn't match
return False
return [ret]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you sure PatternNodeRewriter
is supposed to only work with single inputs? I get the following error:
def test_patternsub_different_output_lengths():
# Test that PatternNodeRewriter won't replace nodes with different numbers of outputs
ps = PatternNodeRewriter(
(op1, "x"),
("x"),
name="ps",
)
rewriter = in2out(ps)
x = MyVariable("x")
e1, e2 = op_multiple_outputs(x)
o = op1(e1)
fgraph = FunctionGraph(inputs=[x], outputs=[o])
rewriter.rewrite(fgraph)
> assert fgraph.outputs[0].owner.op == op1
E assert OpMultipleOutputs == op1
E + where OpMultipleOutputs = OpMultipleOutputs(x).op
E + where OpMultipleOutputs(x) = OpMultipleOutputs.0.owner
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think that test makes sense. It's like saying you don't want to replace log(exp(x)
, if x
comes from a multi-output node. We usually don't care about the provenance of a root variable in a rewrite. Nothing in that rewrite cares about op_multiple_outputs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added that test, so let me see if I can dig up the rationale
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It was here: https://github.com/aesara-devs/aesara/pull/803/files
The problem was before the zip would be shorter if node.outputs and replacement didn't match in length. But the whole thing goes away if you just say it doesn't support replacing multiple outputs nodes, which it doesn't really.
That test can be removed in favor of one where it refuses to replace OpMultipleOutputs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. It sorts of makes sense to me but I know too little of the PyTensor internals to fully understand.
Can you propose a quick way to modify/replace the test with one where it refuses to replace OpMultipleOutputs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you push your changes (if you haven't already), I can push the new test on top of it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have pushed all my changes.
tests/tensor/rewriting/test_math.py
Outdated
(np.array(1.0, "float32") - sigmoid(xd), sigmoid(-xd)), | ||
(np.array([[1.0]], "float64") - sigmoid(xd), sigmoid(-xd)), | ||
]: | ||
f = pytensor.function([x, xd], out, m, on_unused_input="ignore") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you are not evaluating f, just rewrite it with rewrite_graph
, possibly including ("canonicalize", "stabilize", "specialize")
, or whatever is needed
tests/tensor/rewriting/test_math.py
Outdated
), "Expression:\n{}rewritten as:\n{}expected:\n{}".format( | ||
*( | ||
pytensor.dprint(expr, print_type=True, file="str") | ||
for expr in (out, f_outs, expected) | ||
) | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can do pytensor.dprint(tuple[Variable])
. If you want the rewritten, expected, which many times I do while writing these sort of tests we could add an assert_equal_computations
helper that does that. That way it's reusable and doesn't make each test very verbose like this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree assert_equal_computations
might be useful (would that go into tests/unittest_tools.py ?). I think it would be useful to know the original expression too, what do yo think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Something like:
def assert_equal_computations(
rewritten,
expected,
*args,
original=None,
**kwargs
):
"""
Assert that `rewritten` computes the same as `expected`.
Parameters
----------
rewritten
The expression after the rewrite pass.
expected
The reference expression to compare against.
*args, **kwargs
Extra arguments forwarded to equal_computations.
original : optional
If given, will be printed in the error message.
"""
ok = equal_computations(rewritten, expected, *args, **kwargs)
if not ok:
parts = []
def _dprint(expr):
return pytensor.dprint(expr, print_type=True, file="str")
if original is not None:
parts.append(f"Original:\n{_dprint(original)}")
parts.append(f"Rewritten:\n{_dprint(rewritten)}")
parts.append(f"Expected:\n{_dprint(expected)}")
raise AssertionError("\n\n".join(parts))
return True
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks lovely
pytensor/graph/rewriting/basic.py
Outdated
out_dtype = node.outputs[0].type.dtype | ||
if self.allow_cast and ret.owner.outputs[0].type.dtype != out_dtype: | ||
ret = pytensor.tensor.basic.cast(ret, out_dtype) | ||
if self.allow_cast: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was reviewing commit by commit, see you changed this after. Anyway my original comment still stands.
Generally, feel free to squash commits and force-push when iterating on a PR so the git changes stay clean
Description
The previous tolerance used within a rewrite to decide whether a constant is one (or minus one) is too large.$c=1 − p$ where p is 1 in 10000.
For example
c - sigmoid(x)
is rewritten assigmoid(-x)
even whenMany rewrites currently use np.isclose and np.allclose with the default tolerances (rtol=1e-05, atol=1e-08), which are unnecessarily large (and independent on the data type of the constant computed).
This PR implements a function
isclose
used within all rewrites in place ofnp.isclose
andnp.allclose
. This new function uses a much smaller tolerance by default, i.e. 10 unit in the last place (ULPs). This tolerance is dtype dependent, so it's stricter for a float64 than a float32. See #1497 for a back of the envelope justification for choosing 10 ULPs.This PR also implements
allow_cast
in PatternNodeRewriter to allow rewrites that would otherwise fail when the new and old dtype differ. For example, a rewrite attempt fornp.array(1., "float64") - sigmoid(x)
(where x isfmatrix
) currently fails because in the rewritesigmoid(-x)
the type would change. This PR allows an automatic cast to be added so the expression is rewritten ascast(sigmoid(-x), "float64")
.Relevant tests added.
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1526.org.readthedocs.build/en/1526/