
Compute pushforward (Rop) via double application of pullback (Lop) and fix Scan and Max gradient bugs #1207

Merged (5 commits) on Feb 17, 2025

1 change: 0 additions & 1 deletion doc/extending/op.rst
@@ -506,4 +506,3 @@ These are the function required to work with :func:`pytensor.gradient.grad`.
the outputs) back to their corresponding shapes and return them as the
output of the :meth:`Op.R_op` method.

:ref:`List of op with r op support <R_op_list>`.
76 changes: 0 additions & 76 deletions doc/library/gradient.rst

This file was deleted.

2 changes: 0 additions & 2 deletions doc/library/tensor/basic.rst
@@ -1791,5 +1791,3 @@ Gradient / Differentiation
:members: grad
:noindex:

See the :ref:`gradient <libdoc_gradient>` page for complete documentation
of the gradient module.
21 changes: 16 additions & 5 deletions doc/tutorial/gradients.rst
@@ -86,9 +86,7 @@ of symbolic differentiation).
``i`` of the output list is the gradient of the first argument of
`pt.grad` with respect to the ``i``-th element of the list given as second argument.
The first argument of `pt.grad` has to be a scalar (a tensor
of size 1). For more information on the semantics of the arguments of
`pt.grad` and details about the implementation, see
:ref:`this<libdoc_gradient>` section of the library.
of size 1).

Additional information on the inner workings of differentiation may also be
found in the more advanced tutorial :ref:`Extending PyTensor<extending>`.
@@ -204,7 +202,21 @@ you need to do something similar to this:
>>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1])
array([ 2., 2.])

:ref:`List <R_op_list>` of Op that implement Rop.
By default, the R-operator is implemented as a double application of the L-operator
(see `reference <https://j-towns.github.io/2017/06/12/A-new-trick.html>`_).
In most cases this should be as performant as a specialized implementation of the R-operator.
However, PyTensor may sometimes fail to prune dead branches or to fuse common expressions within composite operators,
such as Scan and OpFromGraph, that a direct implementation of the R-operator would more easily avoid.
A manual version of this double application is sketched after the example below.

When this is a concern, it is possible to force `Rop` to use the specialized `Op.R_op` methods by passing
`use_op_rop_implementation=True`. Note that this will fail if the graph contains `Op`s that don't implement this method.


>>> JV = pytensor.gradient.Rop(y, W, V, use_op_rop_implementation=True)
>>> f = pytensor.function([W, V, x], JV)
>>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1])
array([ 2., 2.])
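
For illustration, here is a minimal sketch of the double application described above, written with two explicit calls to `pytensor.gradient.Lop`; the dummy variable `u` is introduced only for this sketch and is not part of the documented example.

import pytensor
import pytensor.tensor as pt
from pytensor.gradient import Lop

W = pt.dmatrix("W")
V = pt.dmatrix("V")
x = pt.dvector("x")
y = pt.dot(x, W)

# First application: u^T J, where u is a dummy variable with the type of y.
u = pt.dvector("u")
JtU = Lop(y, wrt=W, eval_points=u)

# Second application: u^T J is linear in u, so differentiating it with respect
# to u and evaluating at V recovers J V, i.e. the R-operator applied to V.
JV = Lop(JtU, wrt=u, eval_points=V)

f = pytensor.function([W, V, x], JV)
f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0, 1])  # -> array([2., 2.])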


L-operator
----------
@@ -234,7 +246,6 @@ array([[ 0., 0.],
as the input parameter, while the result of the R-operator has a shape similar
to that of the output.

:ref:`List of op with r op support <R_op_list>`.

Hessian times a Vector
======================
13 changes: 12 additions & 1 deletion pytensor/compile/builders.py
@@ -340,6 +340,12 @@ def __init__(
``None``, this will be used as the connection_pattern for this
:class:`Op`.

.. warning::

`rop_overrides` is ignored when `pytensor.gradient.Rop` is called with
`use_op_rop_implementation=False` (the default). In that case the L_op
is applied twice to obtain a mathematically equivalent R_op.

strict: bool, default False
If true, it raises when any variables needed to compute the inner graph
are not provided as explicit inputs. This can only happen for graphs with
@@ -641,7 +647,12 @@ def _build_and_cache_rop_op(self):
return rop_overrides

eval_points = [inp_t() for inp_t in self.input_types]
fn_rop = partial(Rop, wrt=inner_inputs, eval_points=eval_points)
fn_rop = partial(
Rop,
wrt=inner_inputs,
eval_points=eval_points,
use_op_rop_implementation=True,
)

callable_args = (inner_inputs, eval_points)
if rop_overrides is None:
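
To illustrate the docstring warning above, a rough sketch of how `rop_overrides` interacts with `use_op_rop_implementation`; the override callable's `(inputs, eval_points) -> [output eval_points]` signature is assumed from the surrounding code, and the override deliberately returns twice the true Jacobian-vector product so the two code paths can be told apart.

import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import Rop

x = pt.dvector("x")
v = pt.dvector("v")

# OpFromGraph computing x**2, with an R_op override that returns twice the true
# Jacobian-vector product (2 * 2 * x * v instead of 2 * x * v).
square = OpFromGraph(
    [x],
    [x**2],
    rop_overrides=lambda inputs, eval_points: [2 * 2 * inputs[0] * eval_points[0]],
)
out = square(x)

# Default: the override is ignored and the pushforward is built from two pullbacks.
jv_default = Rop(out, wrt=x, eval_points=v)

# Forcing the specialized implementation dispatches to the override above.
jv_override = Rop(out, wrt=x, eval_points=v, use_op_rop_implementation=True)

f = pytensor.function([x, v], [jv_default, jv_override])
f([3.0], [1.0])  # expected: [array([6.]), array([12.])]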