@@ -142,13 +142,50 @@ def __str__(self):
 disconnected_type = DisconnectedType()
 
 
-def Rop(
-    f: Variable | Sequence[Variable],
-    wrt: Variable | Sequence[Variable],
-    eval_points: Variable | Sequence[Variable],
+def pushforward_through_pullback(
+    outputs: Sequence[Variable],
+    inputs: Sequence[Variable],
+    tangents: Sequence[Variable],
     disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise",
     return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
-) -> Variable | None | Sequence[Variable | None]:
+) -> Sequence[Variable | None]:
+    """Compute the pushforward (Rop) through two applications of a pullback (Lop) operation.
+
+    References
+    ----------
+    .. [1] J. Towns, "A new trick for calculating Jacobian vector products", 2017.
+       Available: https://j-towns.github.io/2017/06/12/A-new-trick.html
+
+    """
+    # Cotangents are just auxiliary variables that should be pruned from the final graph,
+    # but that would require a graph rewrite before the user tries to compile a pytensor function.
+    # To avoid trouble we use .zeros_like() instead of .type(), which does not create a new root variable.
+    cotangents = [out.zeros_like(dtype=config.floatX) for out in outputs]  # type: ignore
+
+    input_cotangents = Lop(
+        f=outputs,
+        wrt=inputs,
+        eval_points=cotangents,
+        disconnected_inputs=disconnected_outputs,
+        return_disconnected="zero",
+    )
+
+    return Lop(
+        f=input_cotangents,  # type: ignore
+        wrt=cotangents,
+        eval_points=tangents,
+        disconnected_inputs="ignore",
+        return_disconnected=return_disconnected,
+    )
+
+
+def _rop_legacy(
+    f: Sequence[Variable],
+    wrt: Sequence[Variable],
+    eval_points: Sequence[Variable],
+    disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise",
+    return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
+) -> Sequence[Variable | None]:
     """Computes the R-operator applied to `f` with respect to `wrt` at `eval_points`.
 
     Mathematically this stands for the Jacobian of `f` right multiplied by the
@@ -190,38 +227,6 @@ def Rop(
     If `f` is a list/tuple, then return a list/tuple with the results.
     """
 
-    if not isinstance(wrt, list | tuple):
-        _wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
-    else:
-        _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt]
-
-    if not isinstance(eval_points, list | tuple):
-        _eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
-    else:
-        _eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points]
-
-    if not isinstance(f, list | tuple):
-        _f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)]
-    else:
-        _f = [pytensor.tensor.as_tensor_variable(x) for x in f]
-
-    if len(_wrt) != len(_eval_points):
-        raise ValueError("`wrt` must be the same length as `eval_points`.")
-
-    # Check that each element of wrt corresponds to an element
-    # of eval_points with the same dimensionality.
-    for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points, strict=True)):
-        try:
-            if wrt_elem.type.ndim != eval_point.type.ndim:
-                raise ValueError(
-                    f"Elements {i} of `wrt` and `eval_point` have mismatched dimensionalities: "
-                    f"{wrt_elem.type.ndim} and {eval_point.type.ndim}"
-                )
-        except AttributeError:
-            # wrt_elem and eval_point don't always have ndim like random type
-            # Tensor, Sparse have the ndim attribute
-            pass
-
     seen_nodes: dict[Apply, Sequence[Variable]] = {}
 
     def _traverse(node):
@@ -237,8 +242,8 @@ def _traverse(node):
         # inputs of the node
         local_eval_points = []
         for inp in inputs:
-            if inp in _wrt:
-                local_eval_points.append(_eval_points[_wrt.index(inp)])
+            if inp in wrt:
+                local_eval_points.append(eval_points[wrt.index(inp)])
             elif inp.owner is None:
                 try:
                     local_eval_points.append(inp.zeros_like())
@@ -292,13 +297,13 @@ def _traverse(node):
     # end _traverse
 
     # Populate the dictionary
-    for out in _f:
+    for out in f:
         _traverse(out.owner)
 
     rval: list[Variable | None] = []
-    for out in _f:
-        if out in _wrt:
-            rval.append(_eval_points[_wrt.index(out)])
+    for out in f:
+        if out in wrt:
+            rval.append(eval_points[wrt.index(out)])
         elif (
             seen_nodes.get(out.owner, None) is None
             or seen_nodes[out.owner][out.owner.outputs.index(out)] is None
@@ -337,6 +342,116 @@ def _traverse(node):
         else:
             rval.append(seen_nodes[out.owner][out.owner.outputs.index(out)])
 
+    return rval
+
+
+def Rop(
+    f: Variable | Sequence[Variable],
+    wrt: Variable | Sequence[Variable],
+    eval_points: Variable | Sequence[Variable],
+    disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise",
+    return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
+    use_op_rop_implementation: bool = False,
+) -> Variable | None | Sequence[Variable | None]:
+    """Computes the R-operator applied to `f` with respect to `wrt` at `eval_points`.
+
+    Mathematically this stands for the Jacobian of `f` right multiplied by the
+    `eval_points`.
+
+    By default, the R-operator is implemented as a double application of the L_operator [1]_.
+    In most cases this should be as performant as a specialized implementation of the R-operator.
+    However, PyTensor may sometimes fail to prune dead branches or fuse common expressions within composite operators,
+    such as Scan and OpFromGraph, that would be more easily avoidable in a direct implementation of the R-operator.
+
+    When this is a concern, it is possible to force `Rop` to use the specialized `Op.R_op` methods by passing
+    `use_op_rop_implementation=True`. Note that this will fail if the graph contains `Op`s that don't implement this method.
+
+    Parameters
+    ----------
+    f
+        The outputs of the computational graph to which the R-operator is
+        applied.
+    wrt
+        Variables for which the R-operator of `f` is computed.
+    eval_points
+        Points at which to evaluate each of the variables in `wrt`.
+    disconnected_outputs
+        Defines the behaviour if some of the variables in `f`
+        have no dependency on any of the variables in `wrt` (or if
+        all links are non-differentiable). The possible values are:
+
+        - ``'ignore'``: considers that the gradient on these parameters is zero.
+        - ``'warn'``: consider the gradient zero, and print a warning.
+        - ``'raise'``: raise `DisconnectedInputError`.
+
+    return_disconnected
+        - ``'zero'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
+          ``wrt[i].zeros_like()``.
+        - ``'none'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
+          ``None``
+        - ``'disconnected'`` : returns variables of type `DisconnectedType`
+    use_op_rop_implementation: bool, default=False
+        If `False` (default), we obtain Rop via double application of Lop.
+        If `True`, the legacy Rop implementation is used. The number of graphs that support this form
+        is much more restricted, and the generated graphs may be less optimized.
+
+    Returns
+    -------
+    :class:`~pytensor.graph.basic.Variable` or list/tuple of Variables
+        A symbolic expression obeying
+        ``R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]``,
+        where the indices in that expression are magic multidimensional
+        indices that specify both the position within a list and all
+        coordinates of the tensor elements.
+        If `f` is a list/tuple, then return a list/tuple with the results.
+
+    References
+    ----------
+    .. [1] J. Towns, "A new trick for calculating Jacobian vector products", 2017.
+       Available: https://j-towns.github.io/2017/06/12/A-new-trick.html
+    """
+
+    if not isinstance(wrt, list | tuple):
+        _wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
+    else:
+        _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt]
+
+    if not isinstance(eval_points, list | tuple):
+        _eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
+    else:
+        _eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points]
+
+    if not isinstance(f, list | tuple):
+        _f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)]
+    else:
+        _f = [pytensor.tensor.as_tensor_variable(x) for x in f]
+
+    if len(_wrt) != len(_eval_points):
+        raise ValueError("`wrt` must be the same length as `eval_points`.")
+
+    # Check that each element of wrt corresponds to an element
+    # of eval_points with the same dimensionality.
+    for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points, strict=True)):
+        try:
+            if wrt_elem.type.ndim != eval_point.type.ndim:
+                raise ValueError(
+                    f"Elements {i} of `wrt` and `eval_point` have mismatched dimensionalities: "
+                    f"{wrt_elem.type.ndim} and {eval_point.type.ndim}"
+                )
+        except AttributeError:
+            # wrt_elem and eval_point don't always have ndim like random type
+            # Tensor, Sparse have the ndim attribute
+            pass
+
+    if use_op_rop_implementation:
+        rval = _rop_legacy(
+            _f, _wrt, _eval_points, disconnected_outputs, return_disconnected
+        )
+    else:
+        rval = pushforward_through_pullback(
+            _f, _wrt, _eval_points, disconnected_outputs, return_disconnected
+        )
+
     using_list = isinstance(f, list)
     using_tuple = isinstance(f, tuple)
     return as_list_or_tuple(using_list, using_tuple, rval)
@@ -348,6 +463,7 @@ def Lop(
     eval_points: Variable | Sequence[Variable],
     consider_constant: Sequence[Variable] | None = None,
     disconnected_inputs: Literal["ignore", "warn", "raise"] = "raise",
+    return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
 ) -> Variable | None | Sequence[Variable | None]:
     """Computes the L-operator applied to `f` with respect to `wrt` at `eval_points`.
 
@@ -404,6 +520,7 @@ def Lop(
         consider_constant=consider_constant,
         wrt=_wrt,
         disconnected_inputs=disconnected_inputs,
+        return_disconnected=return_disconnected,
     )
 
     using_list = isinstance(wrt, list)
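For context beyond the patch itself, the sketch below illustrates the double-pullback trick that `pushforward_through_pullback` relies on, and checks it against the public `Rop` entry point. It is a hypothetical usage example, not part of the diff: it assumes a PyTensor build that includes this change (so `Rop` defaults to the double-Lop path), and the variable names (`x`, `A`, `y`, `v`, `u`) are illustrative.

```python
# Minimal sketch (not part of the diff): compute a Jacobian-vector product (pushforward)
# via two vector-Jacobian products (pullbacks), following J. Towns' trick, and compare
# the result with the public Rop entry point.
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.gradient import Lop, Rop

x = pt.vector("x")          # input, shape (n,)
A = pt.matrix("A")          # fixed weights, shape (m, n)
y = pt.tanh(pt.dot(A, x))   # outputs f(x), shape (m,)
v = pt.vector("v")          # tangent vector in input space, shape (n,)

# Auxiliary cotangent: zeros_like avoids introducing a new root variable,
# mirroring what pushforward_through_pullback does internally.
u = y.zeros_like()

# First pullback gives u^T J (linear in u); pulling that back w.r.t. u at v yields J @ v.
jvp_manual = Lop(Lop(y, wrt=x, eval_points=u), wrt=u, eval_points=v)

# Same quantity through the public API (double-Lop is the default after this change).
jvp_api = Rop(y, wrt=x, eval_points=v)

fn = pytensor.function([x, A, v], [jvp_manual, jvp_api])

rng = np.random.default_rng(0)
x_val = rng.normal(size=3).astype(pytensor.config.floatX)
A_val = rng.normal(size=(2, 3)).astype(pytensor.config.floatX)
v_val = rng.normal(size=3).astype(pytensor.config.floatX)

jvp1, jvp2 = fn(x_val, A_val, v_val)
np.testing.assert_allclose(jvp1, jvp2, rtol=1e-6)
```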