Commit 8031d33

ricardoV94 and aseyboldt committed Feb 14, 2025
Compute pushforward via double application of pullback
Also fixes bugs in Scan L_op and Max R_op.

Co-authored-by: Adrian Seyboldt <aseyboldt@users.noreply.github.com>
1 parent 1c88f72 commit 8031d33
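
As background for the commit title: for outputs ``f`` with Jacobian ``J`` at the inputs, the pullback (``Lop``) evaluated at a cotangent ``u`` is ``J^T u``, which is linear in ``u``; pulling that expression back with respect to ``u`` at a tangent ``v`` therefore recovers the pushforward (``Rop``). In the notation of the Towns reference cited in the diff:

.. math::

   \operatorname{Lop}(f, x, u) = J^\top u,
   \qquad
   \operatorname{Lop}\bigl(u \mapsto J^\top u,\; u,\; v\bigr) = \bigl(J^\top\bigr)^\top v = J v = \operatorname{Rop}(f, x, v).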

File tree

11 files changed, +297 -162 lines
 

‎doc/extending/op.rst‎

Lines changed: 0 additions & 1 deletion
@@ -506,4 +506,3 @@ These are the function required to work with :func:`pytensor.gradient.grad`.
506506
the outputs) back to their corresponding shapes and return them as the
507507
output of the :meth:`Op.R_op` method.
508508

509-
:ref:`List of op with r op support <R_op_list>`.

‎doc/library/gradient.rst‎

Lines changed: 0 additions & 76 deletions
This file was deleted.

‎doc/library/tensor/basic.rst‎

Lines changed: 0 additions & 2 deletions
@@ -1791,5 +1791,3 @@ Gradient / Differentiation
17911791
:members: grad
17921792
:noindex:
17931793

1794-
See the :ref:`gradient <libdoc_gradient>` page for complete documentation
1795-
of the gradient module.

‎doc/tutorial/gradients.rst‎

Lines changed: 16 additions & 5 deletions
@@ -86,9 +86,7 @@ of symbolic differentiation).
8686
``i`` of the output list is the gradient of the first argument of
8787
`pt.grad` with respect to the ``i``-th element of the list given as second argument.
8888
The first argument of `pt.grad` has to be a scalar (a tensor
89-
of size 1). For more information on the semantics of the arguments of
90-
`pt.grad` and details about the implementation, see
91-
:ref:`this<libdoc_gradient>` section of the library.
89+
of size 1).
9290

9391
Additional information on the inner workings of differentiation may also be
9492
found in the more advanced tutorial :ref:`Extending PyTensor<extending>`.
@@ -204,7 +202,21 @@ you need to do something similar to this:
204202
>>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1])
205203
array([ 2., 2.])
206204

207-
:ref:`List <R_op_list>` of Op that implement Rop.
205+
By default, the R-operator is implemented as a double application of the L_operator
206+
(see `reference <https://j-towns.github.io/2017/06/12/A-new-trick.html>`_).
207+
In most cases this should be as performant as a specialized implementation of the R-operator.
208+
However, PyTensor may sometimes fail to prune dead branches or fuse common expressions within composite operators,
209+
such as Scan and OpFromGraph, that a direct implementation of the R-operator would avoid more easily.
210+
211+
When this is a concern, it is possible to force `Rop` to use the specialized `Op.R_op` methods by passing
212+
`use_op_rop_implementation=True`. Note that this will fail if the graph contains `Op`s that don't implement this method.
213+
214+
215+
>>> JV = pytensor.gradient.Rop(y, W, V, use_op_rop_implementation=True)
216+
>>> f = pytensor.function([W, V, x], JV)
217+
>>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1])
218+
array([ 2., 2.])
219+
208220

209221
L-operator
210222
----------
@@ -234,7 +246,6 @@ array([[ 0., 0.],
234246
as the input parameter, while the result of the R-operator has a shape similar
235247
to that of the output.
236248

237-
:ref:`List of op with r op support <R_op_list>`.
238249

239250
Hessian times a Vector
240251
======================
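
A self-contained sketch of the comparison described in the tutorial text above, reusing the doctest's inputs (both paths are expected to agree numerically)::

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    W = pt.matrix("W")
    V = pt.matrix("V")
    x = pt.vector("x")
    y = pt.dot(x, W)

    # Default: the pushforward is computed via a double application of Lop
    JV_default = pytensor.gradient.Rop(y, W, V)
    # Forcing the specialized Op.R_op implementations instead
    JV_legacy = pytensor.gradient.Rop(y, W, V, use_op_rop_implementation=True)

    f = pytensor.function([W, V, x], [JV_default, JV_legacy])
    out_default, out_legacy = f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0, 1])
    np.testing.assert_allclose(out_default, out_legacy)  # both are [2., 2.] here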

‎pytensor/compile/builders.py‎

Lines changed: 12 additions & 1 deletion
@@ -340,6 +340,12 @@ def __init__(
340340
``None``, this will be used as the connection_pattern for this
341341
:class:`Op`.
342342
343+
.. warning::
344+
345+
``rop_overrides`` is ignored when `pytensor.gradient.Rop` is called with
346+
`use_op_rop_implementation=False` (default). In this case the Lop
347+
is used twice to obtain a mathematically equivalent Rop.
348+
343349
strict: bool, default False
344350
If true, it raises when any variables needed to compute the inner graph
345351
are not provided as explicit inputs. This can only happen for graphs with
@@ -641,7 +647,12 @@ def _build_and_cache_rop_op(self):
641647
return rop_overrides
642648

643649
eval_points = [inp_t() for inp_t in self.input_types]
644-
fn_rop = partial(Rop, wrt=inner_inputs, eval_points=eval_points)
650+
fn_rop = partial(
651+
Rop,
652+
wrt=inner_inputs,
653+
eval_points=eval_points,
654+
use_op_rop_implementation=True,
655+
)
645656

646657
callable_args = (inner_inputs, eval_points)
647658
if rop_overrides is None:
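
A hedged sketch of the behaviour described in the warning above, assuming the ``rop_overrides`` keyword and the ``(inputs, eval_points)`` override signature used in this commit's tests; the deliberately wrong override makes it visible which path is taken::

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.compile.builders import OpFromGraph
    from pytensor.gradient import Rop

    x = pt.vector("x")
    # Forward graph is x ** 2, but the R_op override (wrongly) returns zeros on purpose
    op = OpFromGraph([x], [x**2], rop_overrides=lambda inps, epts: [pt.zeros_like(epts[0])])

    xx, vv = pt.vector("xx"), pt.vector("vv")
    y = op(xx)

    jv_default = Rop(y, xx, vv)  # override ignored: double Lop gives 2 * xx * vv
    jv_forced = Rop(y, xx, vv, use_op_rop_implementation=True)  # override used: zeros

    f = pytensor.function([xx, vv], [jv_default, jv_forced])
    xval = np.array([1.0, 2.0], dtype=pytensor.config.floatX)
    vval = np.array([1.0, 1.0], dtype=pytensor.config.floatX)
    print(f(xval, vval))  # approximately [2., 4.] and [0., 0.]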

‎pytensor/gradient.py‎

Lines changed: 160 additions & 43 deletions
@@ -142,13 +142,50 @@ def __str__(self):
142142
disconnected_type = DisconnectedType()
143143

144144

145-
def Rop(
146-
f: Variable | Sequence[Variable],
147-
wrt: Variable | Sequence[Variable],
148-
eval_points: Variable | Sequence[Variable],
145+
def pushforward_through_pullback(
146+
outputs: Sequence[Variable],
147+
inputs: Sequence[Variable],
148+
tangents: Sequence[Variable],
149149
disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise",
150150
return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
151-
) -> Variable | None | Sequence[Variable | None]:
151+
) -> Sequence[Variable | None]:
152+
"""Compute the pushforward (Rop) through two applications of a pullback (Lop) operation.
153+
154+
References
155+
----------
156+
.. [1] J. Towns, "A new trick for calculating Jacobian vector products", 2017.
157+
Available: https://j-towns.github.io/2017/06/12/A-new-trick.html
158+
159+
"""
160+
# Cotangents are just auxiliary variables that should be pruned from the final graph,
161+
# but that would require a graph rewrite before the user tries to compile a pytensor function.
162+
# To avoid trouble we use .zeros_like() instead of .type(), which does not create a new root variable.
163+
cotangents = [out.zeros_like(dtype=config.floatX) for out in outputs] # type: ignore
164+
165+
input_cotangents = Lop(
166+
f=outputs,
167+
wrt=inputs,
168+
eval_points=cotangents,
169+
disconnected_inputs=disconnected_outputs,
170+
return_disconnected="zero",
171+
)
172+
173+
return Lop(
174+
f=input_cotangents, # type: ignore
175+
wrt=cotangents,
176+
eval_points=tangents,
177+
disconnected_inputs="ignore",
178+
return_disconnected=return_disconnected,
179+
)
180+
181+
182+
def _rop_legacy(
183+
f: Sequence[Variable],
184+
wrt: Sequence[Variable],
185+
eval_points: Sequence[Variable],
186+
disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise",
187+
return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
188+
) -> Sequence[Variable | None]:
152189
"""Computes the R-operator applied to `f` with respect to `wrt` at `eval_points`.
153190
154191
Mathematically this stands for the Jacobian of `f` right multiplied by the
@@ -190,38 +227,6 @@ def Rop(
190227
If `f` is a list/tuple, then return a list/tuple with the results.
191228
"""
192229

193-
if not isinstance(wrt, list | tuple):
194-
_wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
195-
else:
196-
_wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt]
197-
198-
if not isinstance(eval_points, list | tuple):
199-
_eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
200-
else:
201-
_eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points]
202-
203-
if not isinstance(f, list | tuple):
204-
_f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)]
205-
else:
206-
_f = [pytensor.tensor.as_tensor_variable(x) for x in f]
207-
208-
if len(_wrt) != len(_eval_points):
209-
raise ValueError("`wrt` must be the same length as `eval_points`.")
210-
211-
# Check that each element of wrt corresponds to an element
212-
# of eval_points with the same dimensionality.
213-
for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points, strict=True)):
214-
try:
215-
if wrt_elem.type.ndim != eval_point.type.ndim:
216-
raise ValueError(
217-
f"Elements {i} of `wrt` and `eval_point` have mismatched dimensionalities: "
218-
f"{wrt_elem.type.ndim} and {eval_point.type.ndim}"
219-
)
220-
except AttributeError:
221-
# wrt_elem and eval_point don't always have ndim like random type
222-
# Tensor, Sparse have the ndim attribute
223-
pass
224-
225230
seen_nodes: dict[Apply, Sequence[Variable]] = {}
226231

227232
def _traverse(node):
@@ -237,8 +242,8 @@ def _traverse(node):
237242
# inputs of the node
238243
local_eval_points = []
239244
for inp in inputs:
240-
if inp in _wrt:
241-
local_eval_points.append(_eval_points[_wrt.index(inp)])
245+
if inp in wrt:
246+
local_eval_points.append(eval_points[wrt.index(inp)])
242247
elif inp.owner is None:
243248
try:
244249
local_eval_points.append(inp.zeros_like())
@@ -292,13 +297,13 @@ def _traverse(node):
292297
# end _traverse
293298

294299
# Populate the dictionary
295-
for out in _f:
300+
for out in f:
296301
_traverse(out.owner)
297302

298303
rval: list[Variable | None] = []
299-
for out in _f:
300-
if out in _wrt:
301-
rval.append(_eval_points[_wrt.index(out)])
304+
for out in f:
305+
if out in wrt:
306+
rval.append(eval_points[wrt.index(out)])
302307
elif (
303308
seen_nodes.get(out.owner, None) is None
304309
or seen_nodes[out.owner][out.owner.outputs.index(out)] is None
@@ -337,6 +342,116 @@ def _traverse(node):
337342
else:
338343
rval.append(seen_nodes[out.owner][out.owner.outputs.index(out)])
339344

345+
return rval
346+
347+
348+
def Rop(
349+
f: Variable | Sequence[Variable],
350+
wrt: Variable | Sequence[Variable],
351+
eval_points: Variable | Sequence[Variable],
352+
disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise",
353+
return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
354+
use_op_rop_implementation: bool = False,
355+
) -> Variable | None | Sequence[Variable | None]:
356+
"""Computes the R-operator applied to `f` with respect to `wrt` at `eval_points`.
357+
358+
Mathematically this stands for the Jacobian of `f` right multiplied by the
359+
`eval_points`.
360+
361+
By default, the R-operator is implemented as a double application of the L_operator [1]_.
362+
In most cases this should be as performant as a specialized implementation of the R-operator.
363+
However, PyTensor may sometimes fail to prune dead branches or fuse common expressions within composite operators,
364+
such as Scan and OpFromGraph, that a direct implementation of the R-operator would avoid more easily.
365+
366+
When this is a concern, it is possible to force `Rop` to use the specialized `Op.R_op` methods by passing
367+
`use_op_rop_implementation=True`. Note that this will fail if the graph contains `Op`s that don't implement this method.
368+
369+
Parameters
370+
----------
371+
f
372+
The outputs of the computational graph to which the R-operator is
373+
applied.
374+
wrt
375+
Variables for which the R-operator of `f` is computed.
376+
eval_points
377+
Points at which to evaluate each of the variables in `wrt`.
378+
disconnected_outputs
379+
Defines the behaviour if some of the variables in `f`
380+
have no dependency on any of the variables in `wrt` (or if
381+
all links are non-differentiable). The possible values are:
382+
383+
- ``'ignore'``: considers that the gradient on these parameters is zero.
384+
- ``'warn'``: consider the gradient zero, and print a warning.
385+
- ``'raise'``: raise `DisconnectedInputError`.
386+
387+
return_disconnected
388+
- ``'zero'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
389+
``wrt[i].zeros_like()``.
390+
- ``'none'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
391+
``None``
392+
- ``'disconnected'`` : returns variables of type `DisconnectedType`
393+
use_op_rop_implementation: bool, default=False
394+
If `False` (default), the R-operator is obtained via a double application of the L-operator.
395+
If `True`, the specialized `Op.R_op` implementations are used instead. The number of graphs that support
396+
this form is much more restricted, although the resulting graphs may sometimes be better optimized.
397+
398+
Returns
399+
-------
400+
:class:`~pytensor.graph.basic.Variable` or list/tuple of Variables
401+
A symbolic expression obeying
402+
``R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]``,
403+
where the indices in that expression are magic multidimensional
404+
indices that specify both the position within a list and all
405+
coordinates of the tensor elements.
406+
If `f` is a list/tuple, then return a list/tuple with the results.
407+
408+
References
409+
----------
410+
.. [1] J. Towns, "A new trick for calculating Jacobian vector products", 2017.
411+
Available: https://j-towns.github.io/2017/06/12/A-new-trick.html
412+
"""
413+
414+
if not isinstance(wrt, list | tuple):
415+
_wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
416+
else:
417+
_wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt]
418+
419+
if not isinstance(eval_points, list | tuple):
420+
_eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
421+
else:
422+
_eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points]
423+
424+
if not isinstance(f, list | tuple):
425+
_f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)]
426+
else:
427+
_f = [pytensor.tensor.as_tensor_variable(x) for x in f]
428+
429+
if len(_wrt) != len(_eval_points):
430+
raise ValueError("`wrt` must be the same length as `eval_points`.")
431+
432+
# Check that each element of wrt corresponds to an element
433+
# of eval_points with the same dimensionality.
434+
for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points, strict=True)):
435+
try:
436+
if wrt_elem.type.ndim != eval_point.type.ndim:
437+
raise ValueError(
438+
f"Elements {i} of `wrt` and `eval_point` have mismatched dimensionalities: "
439+
f"{wrt_elem.type.ndim} and {eval_point.type.ndim}"
440+
)
441+
except AttributeError:
442+
# wrt_elem and eval_point don't always have ndim like random type
443+
# Tensor, Sparse have the ndim attribute
444+
pass
445+
446+
if use_op_rop_implementation:
447+
rval = _rop_legacy(
448+
_f, _wrt, _eval_points, disconnected_outputs, return_disconnected
449+
)
450+
else:
451+
rval = pushforward_through_pullback(
452+
_f, _wrt, _eval_points, disconnected_outputs, return_disconnected
453+
)
454+
340455
using_list = isinstance(f, list)
341456
using_tuple = isinstance(f, tuple)
342457
return as_list_or_tuple(using_list, using_tuple, rval)
@@ -348,6 +463,7 @@ def Lop(
348463
eval_points: Variable | Sequence[Variable],
349464
consider_constant: Sequence[Variable] | None = None,
350465
disconnected_inputs: Literal["ignore", "warn", "raise"] = "raise",
466+
return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
351467
) -> Variable | None | Sequence[Variable | None]:
352468
"""Computes the L-operator applied to `f` with respect to `wrt` at `eval_points`.
353469
@@ -404,6 +520,7 @@ def Lop(
404520
consider_constant=consider_constant,
405521
wrt=_wrt,
406522
disconnected_inputs=disconnected_inputs,
523+
return_disconnected=return_disconnected,
407524
)
408525

409526
using_list = isinstance(wrt, list)
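
A condensed sketch of what ``pushforward_through_pullback`` does, written directly against ``Lop`` for a simple elementwise graph (names are illustrative)::

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.gradient import Lop

    x = pt.vector("x")
    v = pt.vector("v")          # tangent for x
    y = pt.sin(x) * x           # some differentiable graph

    # 1) Pullback of y w.r.t. x at an auxiliary cotangent u: builds u^T J
    u = y.zeros_like()          # auxiliary cotangent, created as in pushforward_through_pullback
    x_cotangent = Lop(f=y, wrt=x, eval_points=u)

    # 2) Pullback of (u -> u^T J) w.r.t. u at the tangent v: builds J v
    jv = Lop(f=x_cotangent, wrt=u, eval_points=v, disconnected_inputs="ignore")

    fn = pytensor.function([x, v], jv)
    xval = np.array([0.1, 0.2, 0.3], dtype=pytensor.config.floatX)
    vval = np.array([1.0, 0.0, 2.0], dtype=pytensor.config.floatX)
    expected = (np.cos(xval) * xval + np.sin(xval)) * vval  # elementwise Jacobian times tangent
    np.testing.assert_allclose(fn(xval, vval), expected, rtol=1e-6)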

‎pytensor/scan/op.py‎

Lines changed: 6 additions & 1 deletion
@@ -3165,7 +3165,12 @@ def R_op(self, inputs, eval_points):
31653165
rop_self_outputs = self_outputs
31663166
if info.n_shared_outs > 0:
31673167
rop_self_outputs = rop_self_outputs[: -info.n_shared_outs]
3168-
rop_outs = Rop(rop_self_outputs, rop_of_inputs, inner_eval_points)
3168+
rop_outs = Rop(
3169+
rop_self_outputs,
3170+
rop_of_inputs,
3171+
inner_eval_points,
3172+
use_op_rop_implementation=True,
3173+
)
31693174
if not isinstance(rop_outs, list | tuple):
31703175
rop_outs = [rop_outs]
31713176
# Step 2. Figure out what corresponds to what in the scan
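
To see the path exercised here from the user side, a small illustrative sketch of the pushforward through a Scan-based cumulative sum, whose Jacobian-vector product is just the cumulative sum of the tangent::

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.gradient import Rop

    x = pt.vector("x")
    v = pt.vector("v")

    # y_t = y_{t-1} + x_t, i.e. the cumulative sum of x
    ys, _ = pytensor.scan(
        fn=lambda x_t, y_tm1: y_tm1 + x_t,
        sequences=[x],
        outputs_info=[pt.zeros((), dtype=x.dtype)],
    )

    jv = Rop(ys, x, v)  # default path: double application of Scan's L_op
    f = pytensor.function([x, v], jv)

    xval = np.ones(3, dtype=pytensor.config.floatX)
    vval = np.array([1.0, 2.0, 3.0], dtype=pytensor.config.floatX)
    np.testing.assert_allclose(f(xval, vval), np.cumsum(vval))  # [1., 3., 6.]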

‎tests/compile/test_builders.py‎

Lines changed: 24 additions & 8 deletions
@@ -306,7 +306,8 @@ def lop_ov(inps, outs, grads):
306306
@pytest.mark.parametrize(
307307
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
308308
)
309-
def test_rop(self, cls_ofg):
309+
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
310+
def test_rop(self, cls_ofg, use_op_rop_implementation):
310311
a = vector()
311312
M = matrix()
312313
b = dot(a, M)
@@ -315,7 +316,7 @@ def test_rop(self, cls_ofg):
315316
W = matrix()
316317
y = op_matmul(x, W)
317318
du = vector()
318-
dv = Rop(y, x, du)
319+
dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation)
319320
fn = function([x, W, du], dv)
320321
xval = np.random.random((16,)).astype(config.floatX)
321322
Wval = np.random.random((16, 16)).astype(config.floatX)
@@ -324,7 +325,8 @@ def test_rop(self, cls_ofg):
324325
dvval2 = fn(xval, Wval, duval)
325326
np.testing.assert_array_almost_equal(dvval2, dvval, 4)
326327

327-
def test_rop_multiple_outputs(self):
328+
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
329+
def test_rop_multiple_outputs(self, use_op_rop_implementation):
328330
a = vector()
329331
M = matrix()
330332
b = dot(a, M)
@@ -339,21 +341,21 @@ def test_rop_multiple_outputs(self):
339341
duval = np.random.random((16,)).astype(config.floatX)
340342

341343
y = op_matmul(x, W)[0]
342-
dv = Rop(y, x, du)
344+
dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation)
343345
fn = function([x, W, du], dv)
344346
result_dvval = fn(xval, Wval, duval)
345347
expected_dvval = np.dot(duval, Wval)
346348
np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4)
347349

348350
y = op_matmul(x, W)[1]
349-
dv = Rop(y, x, du)
351+
dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation)
350352
fn = function([x, W, du], dv)
351353
result_dvval = fn(xval, Wval, duval)
352354
expected_dvval = -np.dot(duval, Wval)
353355
np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4)
354356

355357
y = pt.add(*op_matmul(x, W))
356-
dv = Rop(y, x, du)
358+
dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation)
357359
fn = function([x, W, du], dv)
358360
result_dvval = fn(xval, Wval, duval)
359361
expected_dvval = np.zeros_like(np.dot(duval, Wval))
@@ -362,7 +364,16 @@ def test_rop_multiple_outputs(self):
362364
@pytest.mark.parametrize(
363365
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
364366
)
365-
def test_rop_override(self, cls_ofg):
367+
@pytest.mark.parametrize(
368+
"use_op_rop_implementation",
369+
[
370+
True,
371+
pytest.param(
372+
False, marks=pytest.mark.xfail(reason="Custom ROp is ignored")
373+
),
374+
],
375+
)
376+
def test_rop_override(self, cls_ofg, use_op_rop_implementation):
366377
x, y = vectors("xy")
367378

368379
def ro(inps, epts):
@@ -380,7 +391,12 @@ def ro(inps, epts):
380391
du, dv = vector("du"), vector("dv")
381392
for op in [op_mul, op_mul2]:
382393
zz = op_mul(xx, yy)
383-
dw = Rop(zz, [xx, yy], [du, dv])
394+
dw = Rop(
395+
zz,
396+
[xx, yy],
397+
[du, dv],
398+
use_op_rop_implementation=use_op_rop_implementation,
399+
)
384400
fn = function([xx, yy, du, dv], dw)
385401
vals = np.random.random((4, 32)).astype(config.floatX)
386402
dwval = fn(*vals)

‎tests/scan/test_basic.py‎

Lines changed: 18 additions & 10 deletions
@@ -1922,7 +1922,8 @@ def inner_fn():
19221922
fgrad = function([], g_sh)
19231923
assert fgrad() == 1
19241924

1925-
def test_R_op(self):
1925+
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
1926+
def test_R_op(self, use_op_rop_implementation):
19261927
seed = utt.fetch_seed()
19271928
rng = np.random.default_rng(seed)
19281929
floatX = config.floatX
@@ -1957,9 +1958,9 @@ def rnn_fn(_u, _y, _W):
19571958
eh0 = vector("eh0")
19581959
eW = matrix("eW")
19591960

1960-
nwo_u = Rop(o, _u, eu)
1961-
nwo_h0 = Rop(o, _h0, eh0)
1962-
nwo_W = Rop(o, _W, eW)
1961+
nwo_u = Rop(o, _u, eu, use_op_rop_implementation=use_op_rop_implementation)
1962+
nwo_h0 = Rop(o, _h0, eh0, use_op_rop_implementation=use_op_rop_implementation)
1963+
nwo_W = Rop(o, _W, eW, use_op_rop_implementation=use_op_rop_implementation)
19631964
fn_rop = function(
19641965
[u, h0, W, eu, eh0, eW], [nwo_u, nwo_h0, nwo_W], on_unused_input="ignore"
19651966
)
@@ -1997,7 +1998,8 @@ def rnn_fn(_u, _y, _W):
19971998
np.testing.assert_allclose(vnW, tnW, atol=1e-6)
19981999

19992000
@pytest.mark.slow
2000-
def test_R_op_2(self):
2001+
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
2002+
def test_R_op_2(self, use_op_rop_implementation):
20012003
seed = utt.fetch_seed()
20022004
rng = np.random.default_rng(seed)
20032005
floatX = config.floatX
@@ -2040,9 +2042,9 @@ def rnn_fn(_u, _y, _W):
20402042
eh0 = vector("eh0")
20412043
eW = matrix("eW")
20422044

2043-
nwo_u = Rop(o, _u, eu)
2044-
nwo_h0 = Rop(o, _h0, eh0)
2045-
nwo_W = Rop(o, _W, eW)
2045+
nwo_u = Rop(o, _u, eu, use_op_rop_implementation=use_op_rop_implementation)
2046+
nwo_h0 = Rop(o, _h0, eh0, use_op_rop_implementation=use_op_rop_implementation)
2047+
nwo_W = Rop(o, _W, eW, use_op_rop_implementation=use_op_rop_implementation)
20462048
fn_rop = function(
20472049
[u, h0, W, eu, eh0, eW], [nwo_u, nwo_h0, nwo_W, o], on_unused_input="ignore"
20482050
)
@@ -2078,7 +2080,8 @@ def rnn_fn(_u, _y, _W):
20782080
np.testing.assert_allclose(vnh0, tnh0, atol=1e-6)
20792081
np.testing.assert_allclose(vnW, tnW, atol=2e-6)
20802082

2081-
def test_R_op_mitmot(self):
2083+
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
2084+
def test_R_op_mitmot(self, use_op_rop_implementation):
20822085
# this test is a copy paste from the script given by Justin Bayer to
20832086
# reproduce this bug
20842087
# We have 2 parameter groups with the following shapes.
@@ -2126,7 +2129,12 @@ def test_R_op_mitmot(self):
21262129

21272130
p = dvector()
21282131
# TODO: We should test something about the Rop!
2129-
Rop(d_cost_wrt_pars, pars, p)
2132+
Rop(
2133+
d_cost_wrt_pars,
2134+
pars,
2135+
p,
2136+
use_op_rop_implementation=use_op_rop_implementation,
2137+
)
21302138

21312139
def test_second_derivative_disconnected_cost_with_mit_mot(self):
21322140
# This test is a regression test for a bug that was revealed

‎tests/tensor/rewriting/test_linalg.py‎

Lines changed: 5 additions & 1 deletion
@@ -49,9 +49,12 @@ def test_matrix_inverse_rop_lop():
4949
v = vector("v")
5050
y = MatrixInverse()(mx).sum(axis=0)
5151

52-
yv = pytensor.gradient.Rop(y, mx, mv)
52+
yv = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=True)
5353
rop_f = function([mx, mv], yv)
5454

55+
yv_via_lop = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=False)
56+
rop_via_lop_f = function([mx, mv], yv_via_lop)
57+
5558
sy, _ = pytensor.scan(
5659
lambda i, y, x, v: (pytensor.gradient.grad(y[i], x) * v).sum(),
5760
sequences=pt.arange(y.shape[0]),
@@ -65,6 +68,7 @@ def test_matrix_inverse_rop_lop():
6568

6669
v_ref = scan_f(vx, vv)
6770
np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol)
71+
np.testing.assert_allclose(rop_via_lop_f(vx, vv), v_ref, rtol=rtol)
6872

6973
with pytest.raises(ValueError):
7074
pytensor.gradient.Rop(

‎tests/test_rop.py‎

Lines changed: 56 additions & 14 deletions
@@ -88,7 +88,7 @@ def check_nondiff_rop(self, y, x, v):
8888
test that an error is raised.
8989
"""
9090
with pytest.raises(ValueError):
91-
Rop(y, x, v)
91+
Rop(y, x, v, use_op_rop_implementation=True)
9292

9393
def check_mat_rop_lop(self, y, out_shape):
9494
"""
@@ -116,8 +116,14 @@ def check_mat_rop_lop(self, y, out_shape):
116116
vv = np.asarray(
117117
self.rng.uniform(size=self.mat_in_shape), pytensor.config.floatX
118118
)
119-
yv = Rop(y, self.mx, self.mv)
119+
yv = Rop(y, self.mx, self.mv, use_op_rop_implementation=True)
120120
rop_f = function([self.mx, self.mv], yv, on_unused_input="ignore")
121+
122+
yv_through_lop = Rop(y, self.mx, self.mv, use_op_rop_implementation=False)
123+
rop_through_lop_f = function(
124+
[self.mx, self.mv], yv_through_lop, on_unused_input="ignore"
125+
)
126+
121127
sy, _ = pytensor.scan(
122128
lambda i, y, x, v: (grad(y[i], x) * v).sum(),
123129
sequences=pt.arange(y.shape[0]),
@@ -127,6 +133,7 @@ def check_mat_rop_lop(self, y, out_shape):
127133

128134
v_ref = scan_f(vx, vv)
129135
np.testing.assert_allclose(rop_f(vx, vv), v_ref)
136+
np.testing.assert_allclose(rop_through_lop_f(vx, vv), v_ref)
130137

131138
self.check_nondiff_rop(
132139
pytensor.clone_replace(y, replace={self.mx: break_op(self.mx)}),
@@ -156,8 +163,14 @@ def check_rop_lop(self, y, out_shape, check_nondiff_rop: bool = True):
156163
vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)
157164
vv = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)
158165

159-
yv = Rop(y, self.x, self.v)
166+
yv = Rop(y, self.x, self.v, use_op_rop_implementation=True)
160167
rop_f = function([self.x, self.v], yv, on_unused_input="ignore")
168+
169+
yv_through_lop = Rop(y, self.x, self.v, use_op_rop_implementation=False)
170+
rop_through_lop_f = function(
171+
[self.x, self.v], yv_through_lop, on_unused_input="ignore"
172+
)
173+
161174
J, _ = pytensor.scan(
162175
lambda i, y, x: grad(y[i], x),
163176
sequences=pt.arange(y.shape[0]),
@@ -168,6 +181,7 @@ def check_rop_lop(self, y, out_shape, check_nondiff_rop: bool = True):
168181

169182
v_ref = scan_f(vx, vv)
170183
np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol)
184+
np.testing.assert_allclose(rop_through_lop_f(vx, vv), v_ref, rtol=rtol)
171185

172186
if check_nondiff_rop:
173187
self.check_nondiff_rop(
@@ -255,12 +269,12 @@ def test_dot(self):
255269
insh = self.in_shape[0]
256270
vW = np.asarray(self.rng.uniform(size=(insh, insh)), pytensor.config.floatX)
257271
W = pytensor.shared(vW)
258-
# check_nondiff_rop reveals an error in how Rop handles non-differentiable paths
272+
# check_nondiff_rop reveals an error in how legacy Rop handles non-differentiable paths
259273
# See: test_Rop_partially_differentiable_paths
260274
self.check_rop_lop(dot(self.x, W), self.in_shape, check_nondiff_rop=False)
261275

262276
def test_elemwise0(self):
263-
# check_nondiff_rop reveals an error in how Rop handles non-differentiable paths
277+
# check_nondiff_rop reveals an error in how legacy Rop handles non-differentiable paths
264278
# See: test_Rop_partially_differentiable_paths
265279
self.check_rop_lop((self.x + 1) ** 2, self.in_shape, check_nondiff_rop=False)
266280

@@ -294,11 +308,18 @@ def test_alloc(self):
294308
self.mat_in_shape[0] * self.mat_in_shape[1] * self.in_shape[0],
295309
)
296310

297-
def test_invalid_input(self):
311+
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
312+
def test_invalid_input(self, use_op_rop_implementation):
298313
with pytest.raises(ValueError):
299-
Rop(0.0, [matrix()], [vector()])
314+
Rop(
315+
0.0,
316+
[matrix()],
317+
[vector()],
318+
use_op_rop_implementation=use_op_rop_implementation,
319+
)
300320

301-
def test_multiple_outputs(self):
321+
@pytest.mark.parametrize("use_op_rop_implementation", [True, False])
322+
def test_multiple_outputs(self, use_op_rop_implementation):
302323
m = matrix("m")
303324
v = vector("v")
304325
m_ = matrix("m_")
@@ -309,10 +330,20 @@ def test_multiple_outputs(self):
309330
m_val = self.rng.uniform(size=(3, 7)).astype(pytensor.config.floatX)
310331
v_val = self.rng.uniform(size=(7,)).astype(pytensor.config.floatX)
311332

312-
rop_out1 = Rop([m, v, m + v], [m, v], [m_, v_])
333+
rop_out1 = Rop(
334+
[m, v, m + v],
335+
[m, v],
336+
[m_, v_],
337+
use_op_rop_implementation=use_op_rop_implementation,
338+
)
313339
assert isinstance(rop_out1, list)
314340
assert len(rop_out1) == 3
315-
rop_out2 = Rop((m, v, m + v), [m, v], [m_, v_])
341+
rop_out2 = Rop(
342+
(m, v, m + v),
343+
[m, v],
344+
[m_, v_],
345+
use_op_rop_implementation=use_op_rop_implementation,
346+
)
316347
assert isinstance(rop_out2, tuple)
317348
assert len(rop_out2) == 3
318349

@@ -322,8 +353,11 @@ def test_multiple_outputs(self):
322353
f = pytensor.function([m, v, m_, v_], all_outs)
323354
f(mval, vval, m_val, v_val)
324355

325-
@pytest.mark.xfail()
326-
def test_Rop_partially_differentiable_paths(self):
356+
@pytest.mark.parametrize(
357+
"use_op_rop_implementation",
358+
[pytest.param(True, marks=pytest.mark.xfail()), False],
359+
)
360+
def test_Rop_partially_differentiable_paths(self, use_op_rop_implementation):
327361
# This test refers to a bug reported by Jeremiah Lowin on 18th Oct
328362
# 2013. The bug consists when through a dot operation there is only
329363
# one differentiable path (i.e. there is no gradient wrt to one of
@@ -336,7 +370,12 @@ def test_Rop_partially_differentiable_paths(self):
336370
grad(d, v),
337371
v,
338372
v,
339-
disconnected_outputs="raise",
373+
use_op_rop_implementation=use_op_rop_implementation,
374+
# 2025: This is a tricky case, the gradient of the gradient does not depend on v
375+
# although v still exists in the graph inside a `Second` operator.
376+
# The original test was checking that Rop wouldn't raise an error, but Lop does.
377+
# Since the correct behavior is ambiguous, I let both implementations off the hook.
378+
disconnected_outputs="raise" if use_op_rop_implementation else "ignore",
340379
)
341380

342381
# 2025: Here is an unambiguous test for the original commented issue:
@@ -348,10 +387,11 @@ def test_Rop_partially_differentiable_paths(self):
348387
out,
349388
[x],
350389
[x.type()],
390+
use_op_rop_implementation=use_op_rop_implementation,
351391
disconnected_outputs="raise",
352392
)
353393

354-
# More extensive testing shows that the Rop implementation FAILS to raise when
394+
# More extensive testing shows that the legacy Rop implementation FAILS to raise when
355395
# the cost is linked through strictly non-differentiable paths.
356396
# This is not Dot specific, we would observe the same with any operation where the gradient
357397
# with respect to one of the inputs does not depend on the original input (such as `mul`, `add`, ...)
@@ -361,6 +401,7 @@ def test_Rop_partially_differentiable_paths(self):
361401
out,
362402
[x],
363403
[x.type()],
404+
use_op_rop_implementation=use_op_rop_implementation,
364405
disconnected_outputs="raise",
365406
)
366407

@@ -371,5 +412,6 @@ def test_Rop_partially_differentiable_paths(self):
371412
out,
372413
[x],
373414
[x.type()],
415+
use_op_rop_implementation=use_op_rop_implementation,
374416
disconnected_outputs="raise",
375417
)
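
Since the commit message also mentions ``Max``'s ``R_op``, a quick illustrative sketch of the pushforward through ``pt.max`` on the default double-``Lop`` path (assumes a unique maximum, where the Jacobian is the indicator of the argmax)::

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.gradient import Rop

    x = pt.vector("x")
    v = pt.vector("v")
    y = pt.max(x)

    jv = Rop(y, x, v)  # default path: double application of Lop
    f = pytensor.function([x, v], jv)

    xval = np.array([0.5, 2.0, 1.5], dtype=pytensor.config.floatX)
    vval = np.array([10.0, 20.0, 30.0], dtype=pytensor.config.floatX)
    # For a unique maximum, d max(x) / dx is one at the argmax and zero elsewhere
    np.testing.assert_allclose(f(xval, vval), vval[np.argmax(xval)])  # 20.0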
