diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index f80dfaaf5c..63a1ba835b 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -490,6 +490,8 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
             "fusion",
             "inplace",
             "scan_save_mem_prealloc",
+            "reuse_lu_decomposition_multiple_solves",
+            "scan_split_non_sequence_lu_decomposition_solve",
         ],
     ),
 )
diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py
index b8e6b009d8..c49fbadce4 100644
--- a/pytensor/scan/rewriting.py
+++ b/pytensor/scan/rewriting.py
@@ -2561,7 +2561,6 @@ def scan_push_out_dot1(fgraph, node):
     position=1,
 )
 
-
 scan_seqopt1.register(
     "scan_push_out_non_seq",
     in2out(scan_push_out_non_seq, ignore_newtrees=True),
@@ -2569,10 +2568,9 @@ def scan_push_out_dot1(fgraph, node):
     "fast_run",
     "scan",
     "scan_pushout",
-    position=2,
+    position=3,
 )
 
-
 scan_seqopt1.register(
     "scan_push_out_seq",
     in2out(scan_push_out_seq, ignore_newtrees=True),
@@ -2580,7 +2578,7 @@ def scan_push_out_dot1(fgraph, node):
     "fast_run",
     "scan",
     "scan_pushout",
-    position=3,
+    position=4,
 )
 
 
@@ -2592,7 +2590,7 @@ def scan_push_out_dot1(fgraph, node):
     "more_mem",
     "scan",
     "scan_pushout",
-    position=4,
+    position=5,
 )
 
 
@@ -2605,7 +2603,7 @@ def scan_push_out_dot1(fgraph, node):
     "more_mem",
     "scan",
     "scan_pushout",
-    position=5,
+    position=6,
 )
 
 scan_eqopt2.register(
diff --git a/pytensor/tensor/__init__.py b/pytensor/tensor/__init__.py
index c6b421d003..ce590f8228 100644
--- a/pytensor/tensor/__init__.py
+++ b/pytensor/tensor/__init__.py
@@ -114,6 +114,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
 
 
 # isort: off
+import pytensor.tensor._linalg
 from pytensor.tensor import linalg
 from pytensor.tensor import special
 from pytensor.tensor import signal
diff --git a/pytensor/tensor/_linalg/__init__.py b/pytensor/tensor/_linalg/__init__.py
new file mode 100644
index 0000000000..767374b10b
--- /dev/null
+++ b/pytensor/tensor/_linalg/__init__.py
@@ -0,0 +1,2 @@
+# Register rewrites
+import pytensor.tensor._linalg.solve
diff --git a/pytensor/tensor/_linalg/solve/__init__.py b/pytensor/tensor/_linalg/solve/__init__.py
new file mode 100644
index 0000000000..1d85f4a66b
--- /dev/null
+++ b/pytensor/tensor/_linalg/solve/__init__.py
@@ -0,0 +1,2 @@
+# Register rewrites in the database
+import pytensor.tensor._linalg.solve.rewriting
diff --git a/pytensor/tensor/_linalg/solve/rewriting.py b/pytensor/tensor/_linalg/solve/rewriting.py
new file mode 100644
index 0000000000..ff1c74cdec
--- /dev/null
+++ b/pytensor/tensor/_linalg/solve/rewriting.py
@@ -0,0 +1,198 @@
+from collections.abc import Container
+from copy import copy
+
+from pytensor.graph import Constant, graph_inputs
+from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
+from pytensor.scan.op import Scan
+from pytensor.scan.rewriting import scan_seqopt1
+from pytensor.tensor.basic import atleast_Nd
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.elemwise import DimShuffle
+from pytensor.tensor.rewriting.basic import register_specialize
+from pytensor.tensor.rewriting.linalg import is_matrix_transpose
+from pytensor.tensor.slinalg import Solve, lu_factor, lu_solve
+from pytensor.tensor.variable import TensorVariable
+
+
+def decompose_A(A, assume_a):
+    if assume_a == "gen":
+        return lu_factor(A, check_finite=False)
+    else:
+        raise NotImplementedError
+
+
+def solve_lu_decomposed_system(A_decomp, b, b_ndim, assume_a, transposed=False):
+    if assume_a == "gen":
+        return lu_solve(A_decomp, b, b_ndim=b_ndim, trans=transposed)
+    else:
+        raise NotImplementedError
+
+
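+# Together, the two helpers above express the identity this rewrite exploits
+# (a sketch; only assume_a="gen" is currently handled):
+#     solve(A, b)   -> lu_solve(lu_factor(A), b)
+#     solve(A.T, b) -> lu_solve(lu_factor(A), b, trans=True)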
+def _split_lu_solve_steps(
+    fgraph, node, *, eager: bool, allowed_assume_a: Container[str]
+):
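+    """Replace Solves that share the same A by one lu_factor plus one lu_solve per Solve.
+
+    Returns a dict mapping each rewritten Solve output to its LU-based replacement,
+    or None if the rewrite does not apply.
+    """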
+    if not isinstance(node.op.core_op, Solve):
+        return None
+
+    def get_root_A(a: TensorVariable) -> tuple[TensorVariable, bool]:
+        # Find the root variable of the first input to Solve
+        # If `a` is a left expand_dims or matrix transpose (DimShuffle variants),
+        # the root variable is the pre-DimShuffled input.
+        # Otherwise, `a` is considered the root variable.
+        # We also return whether the root `a` is transposed.
+        transposed = False
+        if a.owner is not None and isinstance(a.owner.op, DimShuffle):
+            if a.owner.op.is_left_expand_dims:
+                [a] = a.owner.inputs
+            elif is_matrix_transpose(a):
+                [a] = a.owner.inputs
+                transposed = True
+        return a, transposed
+
+    def find_solve_clients(var, assume_a):
+        clients = []
+        for cl, idx in fgraph.clients[var]:
+            if (
+                idx == 0
+                and isinstance(cl.op, Blockwise)
+                and isinstance(cl.op.core_op, Solve)
+                and (cl.op.core_op.assume_a == assume_a)
+            ):
+                clients.append(cl)
+            elif isinstance(cl.op, DimShuffle) and cl.op.is_left_expand_dims:
+                # If it's a left expand_dims, recurse on the output
+                clients.extend(find_solve_clients(cl.outputs[0], assume_a))
+        return clients
+
+    assume_a = node.op.core_op.assume_a
+
+    if assume_a not in allowed_assume_a:
+        return None
+
+    A, _ = get_root_A(node.inputs[0])
+
+    # Find Solves using A (or a left expand_dims of A)
+    # TODO: We could handle arbitrary shuffles of the batch dimensions; we would just need
+    #  to propagate that to the A_decomp outputs
+    A_solve_clients_and_transpose = [
+        (client, False) for client in find_solve_clients(A, assume_a)
+    ]
+
+    # Find Solves using A.T
+    for cl, _ in fgraph.clients[A]:
+        if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out):
+            A_T = cl.out
+            A_solve_clients_and_transpose.extend(
+                (client, True) for client in find_solve_clients(A_T, assume_a)
+            )
+
+    if not eager and len(A_solve_clients_and_transpose) == 1:
+        # If there's a single use, don't apply the rewrite... unless the solve is being
+        # broadcast in a Blockwise (or we're eager). Broadcasting amounts to a "reuse"
+        # of A inside the inner vectorized loop.
+        batch_ndim = node.op.batch_ndim(node)
+        (client, _) = A_solve_clients_and_transpose[0]
+        original_A, b = client.inputs
+        if not any(
+            a_bcast and not b_bcast
+            for a_bcast, b_bcast in zip(
+                original_A.type.broadcastable[:batch_ndim],
+                b.type.broadcastable[:batch_ndim],
+                strict=True,
+            )
+        ):
+            return None
+
+    A_decomp = decompose_A(A, assume_a=assume_a)
+
+    replacements = {}
+    for client, transposed in A_solve_clients_and_transpose:
+        _, b = client.inputs
+        b_ndim = client.op.core_op.b_ndim
+        new_x = solve_lu_decomposed_system(
+            A_decomp, b, b_ndim=b_ndim, assume_a=assume_a, transposed=transposed
+        )
+        [old_x] = client.outputs
+        new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype)
+        copy_stack_trace(old_x, new_x)
+        replacements[old_x] = new_x
+
+    return replacements
+
+
+def _scan_split_non_sequence_lu_decomposition_solve(
+    fgraph, node, *, allowed_assume_a: Container[str]
+):
+    """If the A of a Solve within a Scan depends only on constants and non-sequences, split the LU decomposition step.
+
+    The LU decomposition step can then be pushed out of the inner loop by the `scan_push_out_non_seq` rewrite.
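+
+    A sketch of the kind of graph this targets (a minimal example, mirroring the tests):
+
+        xs, _ = scan(
+            lambda xtm1, A: solve(A, xtm1, assume_a="gen"),
+            outputs_info=[x0],
+            non_sequences=[A],
+            n_steps=10,
+        )
+
+    After this rewrite the inner graph contains only the lu_solve step; the
+    lu_factor(A) node depends solely on the non-sequence A, so the push-out
+    rewrite can hoist it out of the loop.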
+    """
+    scan_op: Scan = node.op
+    non_sequences = set(scan_op.inner_non_seqs(scan_op.inner_inputs))
+    new_scan_fgraph = scan_op.fgraph
+
+    changed = False
+    while True:
+        for inner_node in new_scan_fgraph.toposort():
+            if (
+                isinstance(inner_node.op, Blockwise)
+                and isinstance(inner_node.op.core_op, Solve)
+                and inner_node.op.core_op.assume_a in allowed_assume_a
+            ):
+                A, b = inner_node.inputs
+                if all(
+                    (isinstance(root_inp, Constant) or (root_inp in non_sequences))
+                    for root_inp in graph_inputs([A])
+                ):
+                    if new_scan_fgraph is scan_op.fgraph:
+                        # Clone the first time to avoid mutating the original fgraph
+                        new_scan_fgraph, equiv = new_scan_fgraph.clone_get_equiv()
+                        non_sequences = {equiv[non_seq] for non_seq in non_sequences}
+                        inner_node = equiv[inner_node]  # type: ignore
+
+                    replace_dict = _split_lu_solve_steps(
+                        new_scan_fgraph,
+                        inner_node,
+                        eager=True,
+                        allowed_assume_a=allowed_assume_a,
+                    )
+                    assert (
+                        isinstance(replace_dict, dict) and len(replace_dict) > 0
+                    ), "Rewrite failed"
+                    new_scan_fgraph.replace_all(replace_dict.items())
+                    changed = True
+                    break  # Break to start over with a fresh toposort
+        else:  # no_break
+            break  # Nothing else changed
+
+    if not changed:
+        return
+
+    # Return a new scan to indicate that a rewrite was done
+    new_scan_op = copy(scan_op)
+    new_scan_op.fgraph = new_scan_fgraph
+    new_outs = new_scan_op.make_node(*node.inputs).outputs
+    copy_stack_trace(node.outputs, new_outs)
+    return new_outs
+
+
+@register_specialize
+@node_rewriter([Blockwise])
+def reuse_lu_decomposition_multiple_solves(fgraph, node):
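+    """Reuse one LU decomposition across multiple Solves that share the same A (or its transpose).
+
+    A minimal sketch of a graph this targets (variable names are illustrative):
+
+        x1 = solve(A, b1, assume_a="gen")
+        x2 = solve(A.T, b2, assume_a="gen")
+
+    Both solves are replaced by lu_solve nodes that share a single lu_factor(A).
+    """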
+    return _split_lu_solve_steps(fgraph, node, eager=False, allowed_assume_a={"gen"})
+
+
+@node_rewriter([Scan])
+def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
+    return _scan_split_non_sequence_lu_decomposition_solve(
+        fgraph, node, allowed_assume_a={"gen"}
+    )
+
+
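+# Registered at position 2, just before scan_push_out_non_seq (now at position 3),
+# so the lu_factor it splits out can still be pushed out of the loop by the
+# subsequent pushout rewrites.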
+scan_seqopt1.register(
+    "scan_split_non_sequence_lu_decomposition_solve",
+    in2out(scan_split_non_sequence_lu_decomposition_solve, ignore_newtrees=True),
+    "fast_run",
+    "scan",
+    "scan_pushout",
+    position=2,
+)
diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index cd202fe3ed..af42bee236 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -75,6 +75,13 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
         if ndims < 2:
             return False
         transpose_order = (*range(ndims - 2), ndims - 1, ndims - 2)
+
+        # Allow expand_dims on the left of the transpose
+        if (diff := len(transpose_order) - len(node.op.new_order)) > 0:
+            transpose_order = (
+                *(["x"] * diff),
+                *transpose_order,
+            )
         return node.op.new_order == transpose_order
     return False
 
diff --git a/tests/tensor/linalg/__init__.py b/tests/tensor/linalg/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/tensor/linalg/test_rewriting.py b/tests/tensor/linalg/test_rewriting.py
new file mode 100644
index 0000000000..6f04fac5fb
--- /dev/null
+++ b/tests/tensor/linalg/test_rewriting.py
@@ -0,0 +1,163 @@
+import numpy as np
+import pytest
+
+from pytensor import config, function, scan
+from pytensor.compile.mode import get_default_mode
+from pytensor.gradient import grad
+from pytensor.scan.op import Scan
+from pytensor.tensor._linalg.solve.rewriting import (
+    reuse_lu_decomposition_multiple_solves,
+    scan_split_non_sequence_lu_decomposition_solve,
+)
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.linalg import solve
+from pytensor.tensor.slinalg import LUFactor, Solve, SolveTriangular
+from pytensor.tensor.type import tensor
+
+
+def count_vanilla_solve_nodes(nodes) -> int:
+    return sum(
+        (
+            isinstance(node.op, Solve)
+            or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Solve))
+        )
+        for node in nodes
+    )
+
+
+def count_lu_decom_nodes(nodes) -> int:
+    return sum(
+        (
+            isinstance(node.op, LUFactor)
+            or (
+                isinstance(node.op, Blockwise) and isinstance(node.op.core_op, LUFactor)
+            )
+        )
+        for node in nodes
+    )
+
+
+def count_lu_solve_nodes(nodes) -> int:
+    count = sum(
+        (
+            isinstance(node.op, SolveTriangular)
+            or (
+                isinstance(node.op, Blockwise)
+                and isinstance(node.op.core_op, SolveTriangular)
+            )
+        )
+        for node in nodes
+    )
+    # Each LU solve is implemented with two SolveTriangular nodes
+    return count // 2
+
+
+@pytest.mark.parametrize("transposed", (False, True))
+def test_lu_decomposition_reused_forward_and_gradient(transposed):
+    rewrite_name = reuse_lu_decomposition_multiple_solves.__name__
+    mode = get_default_mode()
+
+    A = tensor("A", shape=(2, 2))
+    b = tensor("b", shape=(2, 3))
+
+    x = solve(A, b, assume_a="gen", transposed=transposed)
+    grad_x_wrt_A = grad(x.sum(), A)
+    fn_no_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.excluding(rewrite_name))
+    no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes
+    assert count_vanilla_solve_nodes(no_opt_nodes) == 2
+    assert count_lu_decom_nodes(no_opt_nodes) == 0
+    assert count_lu_solve_nodes(no_opt_nodes) == 0
+
+    fn_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.including(rewrite_name))
+    opt_nodes = fn_opt.maker.fgraph.apply_nodes
+    assert count_vanilla_solve_nodes(opt_nodes) == 0
+    assert count_lu_decom_nodes(opt_nodes) == 1
+    assert count_lu_solve_nodes(opt_nodes) == 2
+
+    # Make sure results are correct
+    rng = np.random.default_rng(31)
+    A_test = rng.random(A.type.shape, dtype=A.type.dtype)
+    b_test = rng.random(b.type.shape, dtype=b.type.dtype)
+    resx0, resg0 = fn_no_opt(A_test, b_test)
+    resx1, resg1 = fn_opt(A_test, b_test)
+    rtol = 1e-7 if config.floatX == "float64" else 1e-6
+    np.testing.assert_allclose(resx0, resx1, rtol=rtol)
+    np.testing.assert_allclose(resg0, resg1, rtol=rtol)
+
+
+@pytest.mark.parametrize("transposed", (False, True))
+def test_lu_decomposition_reused_blockwise(transposed):
+    rewrite_name = reuse_lu_decomposition_multiple_solves.__name__
+    mode = get_default_mode()
+
+    A = tensor("A", shape=(2, 2))
+    b = tensor("b", shape=(2, 2, 3))
+
+    x = solve(A, b, transposed=transposed)
+    fn_no_opt = function([A, b], [x], mode=mode.excluding(rewrite_name))
+    no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes
+    assert count_vanilla_solve_nodes(no_opt_nodes) == 1
+    assert count_lu_decom_nodes(no_opt_nodes) == 0
+    assert count_lu_solve_nodes(no_opt_nodes) == 0
+
+    fn_opt = function([A, b], [x], mode=mode.including(rewrite_name))
+    opt_nodes = fn_opt.maker.fgraph.apply_nodes
+    assert count_vanilla_solve_nodes(opt_nodes) == 0
+    assert count_lu_decom_nodes(opt_nodes) == 1
+    assert count_lu_solve_nodes(opt_nodes) == 1
+
+    # Make sure results are correct
+    rng = np.random.default_rng(31)
+    A_test = rng.random(A.type.shape, dtype=A.type.dtype)
+    b_test = rng.random(b.type.shape, dtype=b.type.dtype)
+    resx0 = fn_no_opt(A_test, b_test)
+    resx1 = fn_opt(A_test, b_test)
+    np.testing.assert_allclose(resx0, resx1)
+
+
+@pytest.mark.parametrize("transposed", (False, True))
+def test_lu_decomposition_reused_scan(transposed):
+    rewrite_name = scan_split_non_sequence_lu_decomposition_solve.__name__
+    mode = get_default_mode()
+
+    A = tensor("A", shape=(2, 2))
+    x0 = tensor("x0", shape=(2, 3))
+
+    xs, _ = scan(
+        lambda xtm1, A: solve(A, xtm1, assume_a="general", transposed=transposed),
+        outputs_info=[x0],
+        non_sequences=[A],
+        n_steps=10,
+    )
+
+    fn_no_opt = function(
+        [A, x0],
+        [xs],
+        mode=mode.excluding(rewrite_name),
+    )
+    [no_opt_scan_node] = [
+        node for node in fn_no_opt.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
+    ]
+    no_opt_nodes = no_opt_scan_node.op.fgraph.apply_nodes
+    assert count_vanilla_solve_nodes(no_opt_nodes) == 1
+    assert count_lu_decom_nodes(no_opt_nodes) == 0
+    assert count_lu_solve_nodes(no_opt_nodes) == 0
+
+    fn_opt = function([A, x0], [xs], mode=mode.including("scan", rewrite_name))
+    [opt_scan_node] = [
+        node for node in fn_opt.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
+    ]
+    opt_nodes = opt_scan_node.op.fgraph.apply_nodes
+    assert count_vanilla_solve_nodes(opt_nodes) == 0
+    # The LU decomp is outside of the scan!
+    assert count_lu_decom_nodes(opt_nodes) == 0
+    assert count_lu_solve_nodes(opt_nodes) == 1
+
+    # Make sure results are correct
+    rng = np.random.default_rng(170)
+    A_test = rng.random(A.type.shape, dtype=A.type.dtype)
+    x0_test = rng.random(x0.type.shape, dtype=x0.type.dtype)
+    resx0 = fn_no_opt(A_test, x0_test)
+    resx1 = fn_opt(A_test, x0_test)
+    rtol = 1e-7 if config.floatX == "float64" else 1e-6
+    np.testing.assert_allclose(resx0, resx1, rtol=rtol)
diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py
index a140e07846..cbaf27da29 100644
--- a/tests/tensor/test_blockwise.py
+++ b/tests/tensor/test_blockwise.py
@@ -328,7 +328,7 @@ class BlockwiseOpTester:
 
     @classmethod
     def setup_class(cls):
-        seed = sum(map(ord, str(cls.core_op)))
+        seed = sum(map(ord, str(cls.core_op) + cls.signature))
         cls.rng = np.random.default_rng(seed)
         cls.params_sig, cls.outputs_sig = _parse_gufunc_signature(cls.signature)
         if cls.batcheable_axes is None:
@@ -579,7 +579,10 @@ def test_solve(self, solve_fn, batched_A, batched_b):
         else:
             x = solve_fn(A, b, b_ndim=1)
 
-        mode = get_default_mode().excluding("batched_vector_b_solve_to_matrix_b_solve")
+        mode = get_default_mode().excluding(
+            "batched_vector_b_solve_to_matrix_b_solve",
+            "reuse_lu_decomposition_multiple_solves",
+        )
         fn = function([In(A, mutable=True), In(b, mutable=True)], x, mode=mode)
 
         op = fn.maker.fgraph.outputs[0].owner.op