diff --git a/pytensor/gradient.py b/pytensor/gradient.py index 78862de7e1..9e25d4f77f 100644 --- a/pytensor/gradient.py +++ b/pytensor/gradient.py @@ -4,7 +4,7 @@ import warnings from collections.abc import Callable, Mapping, MutableSequence, Sequence from functools import partial, reduce -from typing import TYPE_CHECKING, Literal, TypeVar, Union +from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union import numpy as np @@ -12,9 +12,9 @@ from pytensor.compile.ops import ViewOp from pytensor.configdefaults import config from pytensor.graph import utils -from pytensor.graph.basic import Apply, NominalVariable, Variable +from pytensor.graph.basic import Apply, NominalVariable, Variable, io_toposort from pytensor.graph.null_type import NullType, null_type -from pytensor.graph.op import get_test_values +from pytensor.graph.op import Op, OutputStorageType, get_test_values from pytensor.graph.type import Type @@ -2292,3 +2292,227 @@ def grad_scale(x, multiplier): 0.416... """ return GradScale(multiplier)(x) + + +# =========================================== +# The following is more or less pseudocode... +# =========================================== + + +# Use transpose and forward mode autodiff to get reverse mode autodiff +# Ops that only define push_forward (Rop) could use this, which is nice +# because push_forward is usually easier to derive and think about. +def pull_back_through_transpose(outputs, inputs, output_cotangents): + tangents = [input.type() for input in inputs] + output_tangents = push_forward(outputs, inputs, tangents) + return linear_transpose(output_tangents, tangents, output_cotangents) + + +# Ops that only define pull_back (Lop) could use this to derive push_forward. +def push_forward_through_pull_back(outputs, inputs, tangents): + cotangents = [out.type("u") for out in outputs] + input_cotangents = pull_back(outputs, inputs, cotangents) + return pull_back(input_cotangents, cotangents, tangents) + + +def _push_forward_impl(outputs, inputs, input_tangents): + # Get the nodes in topological order and precompute + # a set of values that are used in the graph. + nodes = io_toposort(inputs, outputs) + used_values = set(outputs) + for node in reversed(nodes): + if any(output in used_values for output in node.outputs): + used_values.update(node.inputs) + + # Maybe a lazy gradient op could use this during rewrite time? + recorded_rewrites = {} + known_tangents = dict(zip(inputs, input_tangents, strict=True)) + for node in nodes: + tangents = [known_tangents.get(input, None) for input in node.inputs] + result_nums = [ + i for i in range(len(node.outputs)) if node.outputs[i] in used_values + ] + new_outputs, output_tangents = node.op.push_forward(node, tangents, result_nums) + if new_outputs is not None: + recorded_rewrites[node] = new_outputs + + for i, tangent in zip(result_nums, output_tangents, strict=True): + known_tangents[node.outputs[i]] = tangent + + return [known_tangents[output] for output in outputs] + + +def _pull_back_impl(outputs, inputs, output_cotangents): + known_cotangents = dict(zip(outputs, output_cotangents, strict=True)) + + nodes = io_toposort(inputs, outputs) + used_values = set(outputs) + for node in reversed(nodes): + if any(output in used_values for output in node.outputs): + used_values.update(node.inputs) + + # Maybe a lazy gradient op could use this during rewrite time? 
+    recorded_rewrites = {}
+    for node in reversed(nodes):
+        cotangents = [known_cotangents.get(output, None) for output in node.outputs]
+        argnums = [i for i in range(len(node.inputs)) if node.inputs[i] in used_values]
+        new_outputs, input_cotangents = node.op.pull_back(node, cotangents, argnums)
+        if new_outputs is not None:
+            recorded_rewrites[node] = new_outputs
+
+        for i, cotangent in zip(argnums, input_cotangents, strict=True):
+            if cotangent is None:
+                continue
+            input = node.inputs[i]
+            if input not in known_cotangents:
+                known_cotangents[input] = cotangent
+            else:
+                # TODO check that we are not broadcasting?
+                known_cotangents[input] += cotangent
+
+    return [known_cotangents[input] for input in inputs]
+
+
+def pullback_grad(cost, wrt):
+    """A new pt.grad that uses the pull_back function.
+
+    At some point we might want to replace pt.grad with this?
+    """
+    from pytensor.tensor import as_tensor_variable
+
+    # TODO: error checking, and allow a non-list wrt...
+    return pull_back([cost], wrt, [as_tensor_variable(1.0)])
+
+
+def linear_transpose(outputs, inputs, transposed_inputs):
+    """Given a linear function from inputs to outputs, return the transposed function."""
+    # Some loop over io_toposort in reverse order...
+    # Should look similar to pull_back?
+    raise NotImplementedError("Graph-level linear_transpose is not implemented yet.")
+
+
+class PullBackOp(Op):
+    __props__ = ("n_outputs", "n_inputs")
+
+    def __init__(self, n_outputs, n_inputs):
+        self.n_outputs = n_outputs
+        self.n_inputs = n_inputs
+        super().__init__()
+
+    def make_node(self, *all_inputs) -> Apply:
+        # all_inputs is [*outputs, *inputs, *output_cotangents]
+        if len(all_inputs) != 2 * self.n_outputs + self.n_inputs:
+            raise ValueError("Incorrect number of inputs")
+
+        inputs_output_cotangents = all_inputs[self.n_outputs :]
+        inputs = inputs_output_cotangents[: self.n_inputs]
+
+        input_cotangents = [input.type() for input in inputs]
+
+        # TODO proper check for continuous (differentiable) dtypes
+        continuous_dtypes = ["float64", "float32", "float16"]
+        for input in inputs:
+            if input.type.dtype not in continuous_dtypes:
+                raise ValueError(
+                    f"Cannot compute pull_back for non-continuous value {input}"
+                )
+
+        return Apply(self, all_inputs, input_cotangents)
+
+    def _get_pullback_primal_outputs(self, node):
+        return node.inputs[: self.n_outputs]
+
+    def _get_pullback_primal_inputs(self, node):
+        return node.inputs[self.n_outputs : self.n_outputs + self.n_inputs]
+
+    def _get_pullback_output_cotangents(self, node):
+        return node.inputs[self.n_outputs + self.n_inputs :]
+
+    def _get_pullback_input_cotangents(self, node):
+        return node.outputs
+
+    def _pullback_split_args(self, node):
+        return (
+            self._get_pullback_primal_outputs(node),
+            self._get_pullback_primal_inputs(node),
+            self._get_pullback_output_cotangents(node),
+        )
+
+    def perform(
+        self, node: Apply, inputs: Sequence[Any], output_storage: OutputStorageType
+    ) -> None:
+        raise NotImplementedError(
+            "PullBackOp cannot be executed; it needs to be removed in rewrites"
+        )
+
+    def infer_shape(self, fgraph, node, shapes):
+        return shapes[self.n_outputs + self.n_inputs :]
+
+
+class PushForwardOp(Op):
+    __props__ = ("n_outputs", "n_inputs")
+
+    def __init__(self, n_outputs, n_inputs):
+        self.n_outputs = n_outputs
+        self.n_inputs = n_inputs
+        super().__init__()
+
+    def make_node(self, *all_inputs) -> Apply:
+        # all_inputs is [*outputs, *inputs, *input_tangents]
+        if len(all_inputs) != self.n_outputs + 2 * self.n_inputs:
+            raise ValueError("Incorrect number of inputs")
+
+        outputs = all_inputs[: self.n_outputs]
+        inputs_input_tangents = all_inputs[self.n_outputs :]
+
+        inputs = inputs_input_tangents[: self.n_inputs]
+
+        output_tangents = [output.type() for output in outputs]
+
+        # TODO proper check for continuous (differentiable) dtypes
+        continuous_dtypes = ["float64", "float32", "float16"]
+        for input in inputs:
+            if input.type.dtype not in continuous_dtypes:
+                raise ValueError(
+                    f"Cannot compute push_forward for non-continuous value {input}"
+                )
+
+        return Apply(self, all_inputs, output_tangents)
+
+    def _get_push_forward_primal_outputs(self, node):
+        return node.inputs[: self.n_outputs]
+
+    def _get_push_forward_primal_inputs(self, node):
+        return node.inputs[self.n_outputs : self.n_outputs + self.n_inputs]
+
+    def _get_push_forward_output_tangents(self, node):
+        return node.outputs
+
+    def _get_push_forward_input_tangents(self, node):
+        return node.inputs[self.n_outputs + self.n_inputs :]
+
+    def _push_forward_split_args(self, node):
+        return (
+            self._get_push_forward_primal_outputs(node),
+            self._get_push_forward_primal_inputs(node),
+            self._get_push_forward_input_tangents(node),
+        )
+
+    def perform(
+        self, node: Apply, inputs: Sequence[Any], output_storage: OutputStorageType
+    ) -> None:
+        raise NotImplementedError(
+            "PushForwardOp cannot be executed; it needs to be removed in rewrites"
+        )
+
+    def infer_shape(self, fgraph, node, shapes):
+        return shapes[: self.n_outputs]
+
+
+def pull_back(outputs, inputs, output_cotangents):
+    op = PullBackOp(len(outputs), len(inputs))
+    return op(*outputs, *inputs, *output_cotangents, return_list=True)
+
+
+def push_forward(outputs, inputs, input_tangents):
+    op = PushForwardOp(len(outputs), len(inputs))
+    return op(*outputs, *inputs, *input_tangents, return_list=True)
diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py
index 160a65dd7a..acd7dec11f 100644
--- a/pytensor/graph/op.py
+++ b/pytensor/graph/op.py
@@ -6,7 +6,9 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    Optional,
     Protocol,
+    Tuple,
     TypeVar,
     cast,
 )
@@ -323,6 +325,119 @@ def __ne__(self, other: Any) -> bool:
     # just to self.add_tag_trace
     add_tag_trace = staticmethod(add_tag_trace)
 
+    def linear_transpose(
+        self,
+        node: Apply,
+        transposed_inputs: Sequence[Variable],
+        linear_inputs: Sequence[int],
+        linear_outputs: Sequence[int],
+    ) -> Sequence[Variable]:
+        """Transpose a linear function.
+
+        The map f from [node.inputs[i] for i in linear_inputs] to
+        [node.outputs[i] for i in linear_outputs], with the remaining inputs
+        held constant, must be linear. An Op can then implement this method
+        and return f^*(transposed_inputs).
+
+        Parameters
+        ----------
+        node: Apply
+            The point at which to do the transpose.
+        transposed_inputs:
+            The inputs for the transposed function.
+        linear_inputs:
+            Indices of input arguments to consider.
+        linear_outputs:
+            Indices of output arguments to consider.
+        """
+        raise NotImplementedError(
+            f"Linear transpose of {self} is not defined or not implemented."
+        )
+
+    def push_forward(
+        self,
+        node: Apply,
+        input_tangents: Sequence[Variable | None],
+        result_nums: Sequence[int],
+    ) -> Tuple[Sequence[Variable] | None, Sequence[Variable | None]]:
+        """Compute the push_forward of tangent vectors at the specified point.
+
+        Parameters
+        ----------
+        node: Apply
+            The point at which to compute the push_forward (i.e. at x = node.inputs
+            and f(x) = node.outputs).
+        input_tangents:
+            The values of the tangent vectors that we wish to map. Values that
+            are set to None are assumed to be constants.
+        result_nums:
+            Only compute the tangents of [node.outputs[i] for i in result_nums].
+
+        Returns
+        -------
+        alternative_outputs:
+            Optionally, a hint to the rewriter: alternative expressions for
+            the outputs of the op that may be used if the tangents are
+            computed as well.
+        output_tangents:
+            The tangents of the outputs specified in result_nums. If a value
+            is None, the corresponding output does not depend on the inputs
+            that had tangents provided.
+        """
+        from pytensor.gradient import DisconnectedType
+        from pytensor.graph.null_type import NullType
+        from pytensor.tensor.basic import zeros_like
+
+        tangents_filled = [
+            # TODO do the R_op methods also accept a disconnected_grad?
+            tangent if tangent is not None else zeros_like(input)
+            for tangent, input in zip(input_tangents, node.inputs, strict=True)
+        ]
+        output_tangents = self.R_op(node.inputs, tangents_filled)
+        output_tangents = [output_tangents[i] for i in result_nums]
+
+        mapped_output_tangents = []
+        for argnum, tangent in zip(result_nums, output_tangents, strict=True):
+            if isinstance(tangent.type, DisconnectedType):
+                mapped_output_tangents.append(None)
+            elif isinstance(tangent.type, NullType):
+                raise NotImplementedError(
+                    f"The push_forward of argument {argnum} of op "
+                    f"{self} is not implemented or not defined."
+                )
+            else:
+                mapped_output_tangents.append(tangent)
+        return (None, mapped_output_tangents)
+
+    def pull_back(
+        self,
+        node: Apply,
+        output_cotangents: Sequence[Variable | None],
+        argnums: Sequence[int],
+    ) -> Tuple[Sequence[Variable] | None, Sequence[Variable | None]]:
+        """Compute the pull_back of cotangent vectors at the specified point.
+
+        output_cotangents are the cotangents of node.outputs (None meaning
+        constant or unused) and argnums selects which input cotangents to
+        return. The return value mirrors push_forward.
+        """
+        from pytensor.gradient import DisconnectedType
+        from pytensor.graph.null_type import NullType
+        from pytensor.tensor.basic import zeros_like
+
+        cotangents_filled = [
+            # TODO do the L_op methods also accept a disconnected_grad?
+            cotangent if cotangent is not None else zeros_like(output)
+            for cotangent, output in zip(output_cotangents, node.outputs, strict=True)
+        ]
+
+        input_cotangents = self.L_op(node.inputs, node.outputs, cotangents_filled)
+        input_cotangents = [input_cotangents[i] for i in argnums]
+
+        mapped_input_cotangents = []
+        for argnum, cotangent in zip(argnums, input_cotangents, strict=True):
+            if isinstance(cotangent.type, DisconnectedType):
+                mapped_input_cotangents.append(None)
+            elif isinstance(cotangent.type, NullType):
+                raise NotImplementedError(
+                    f"The pull_back of argument {argnum} of op "
+                    f"{self} is not implemented or not defined."
+                )
+            else:
+                mapped_input_cotangents.append(cotangent)
+        return (None, mapped_input_cotangents)
+
     def grad(
         self, inputs: Sequence[Variable], output_grads: Sequence[Variable]
     ) -> list[Variable]:
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 06d023d780..1fbb180853 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -2366,27 +2366,32 @@ def local_log_add_exp(fgraph, node):
 
     TODO: in canonicalize, change log10 and log2 -> log
     """
+    z = node.inputs[0]
+    if not z.owner or z.owner.op != add:
+        return
 
-    if node.op == log:
-        z = node.inputs[0]
-        if z.owner and z.owner.op == add:
-            zi = z.owner.inputs
-            pre_exp = [x.owner.inputs[0] for x in zi if x.owner and x.owner.op == exp]
-            # all arguments to add are exp()
-            if len(pre_exp) == len(zi):
-                # Do not offset when max_pre = -np.inf, to avoid nan in the output
-                # Switch statement is placed directly inside add to break the self-symmetry
-                # of the returned output (otherwise the rewrite would not stabilize)
-                max_pre = reduce(maximum, pre_exp)
-                ret = max_pre + log(
-                    add(
-                        *[
-                            switch(isinf(max_pre), exp(max_pre), exp(p - max_pre))
-                            for p in pre_exp
-                        ]
-                    )
-                )
-                return [ret]
+    zi = z.owner.inputs
+    pre_exp = [x.owner.inputs[0] for x in zi if x.owner and x.owner.op == exp]
+
+    # all arguments to add are exp()
+    if len(pre_exp) != len(zi):
+        return
+
+    if len(zi) == 2:
+        a, b = pre_exp
+        replace_val = switch(a > b, a + log1p(exp(b - a)), b + log1p(exp(a - b)))
+        # Handle a == b explicitly (covers +/-inf, where b - a would be nan)
+        replace_val = switch(eq(a, b), a + log(2), replace_val)
+        return [replace_val]
+
+    # Do not offset when max_pre = -np.inf, to avoid nan in the output
+    # Switch statement is placed directly inside add to break the self-symmetry
+    # of the returned output (otherwise the rewrite would not stabilize)
+    max_pre = reduce(maximum, pre_exp)
+    ret = max_pre + log(
+        add(*[switch(isinf(max_pre), exp(max_pre), exp(p - max_pre)) for p in pre_exp])
+    )
+    return [ret]
 
 
 @register_stabilize
diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py
index c89049105f..43315b46c0 100644
--- a/pytensor/tensor/subtensor.py
+++ b/pytensor/tensor/subtensor.py
@@ -9,7 +9,7 @@
 import pytensor
 from pytensor import scalar as ps
 from pytensor.configdefaults import config
-from pytensor.gradient import DisconnectedType
+from pytensor.gradient import DisconnectedType, linear_transpose, push_forward
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.op import Op
 from pytensor.graph.replace import _vectorize_node
@@ -837,6 +837,27 @@ def infer_shape(self, fgraph, node, shapes):
         assert len(outshp) == node.outputs[0].ndim
         return [outshp]
 
+    def linear_transpose(self, node, transposed_inputs, linear_inputs, linear_outputs):
+        assert linear_inputs == [0]
+        assert linear_outputs == [0]
+        (transposed_input,) = transposed_inputs
+
+        x, *others = node.inputs
+        return [IncSubtensor(self.idx_list)(x.zeros_like(), transposed_input, *others)]
+
+    def push_forward(self, node, input_tangents, result_nums):
+        if len(result_nums) == 0:
+            return None, []
+
+        assert result_nums[0] == 0
+
+        value_tangent, *_ = input_tangents
+        if value_tangent is None:
+            return None, [None]
+
+        _, *others = node.inputs
+        return None, [self(value_tangent, *others)]
+
     def grad(self, inputs, grads):
         (gz,) = grads
         x = inputs[0]
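
Usage sketch (not part of the patch): assuming the pull_back and push_forward
helpers land in pytensor.gradient as written above, calling code could look like
the snippet below. The variables x, y, y_bar and x_dot are purely illustrative.

    import pytensor.tensor as pt
    from pytensor.gradient import pull_back, push_forward

    x = pt.dvector("x")
    y = pt.sin(x) * 2.0

    # Reverse mode: pull a cotangent of y back to a cotangent of x (a VJP).
    y_bar = pt.dvector("y_bar")
    (x_bar,) = pull_back([y], [x], [y_bar])

    # Forward mode: push a tangent of x forward to a tangent of y (a JVP).
    x_dot = pt.dvector("x_dot")
    (y_dot,) = push_forward([y], [x], [x_dot])

Both calls only insert lazy PullBackOp / PushForwardOp nodes; since their perform
methods raise, these graphs can be compiled only after a rewrite (for example one
built on _pull_back_impl / _push_forward_impl) has replaced the lazy nodes with
the actual derivative graphs obtained from each op's pull_back / push_forward
method.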