Reuse LU decomposition in Solve #1396


Merged · 2 commits · May 21, 2025
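The gist of the change: when several Solve Ops share the same A (or its transpose), the LU decomposition is computed once and reused across all of them. A minimal sketch of the intended effect, not taken from the PR itself (the exact printed graph may differ):

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve

A = pt.matrix("A")
b1 = pt.vector("b1")
b2 = pt.vector("b2")

# Two solves against the same A, one of them transposed
x1 = solve(A, b1, assume_a="gen")
x2 = solve(A.T, b2, assume_a="gen")

fn = pytensor.function([A, b1, b2], [x1, x2])
pytensor.dprint(fn)  # expected: one LU factorization feeding two LU solves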
2 changes: 2 additions & 0 deletions pytensor/compile/mode.py
@@ -490,6 +490,8 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
"fusion",
"inplace",
"scan_save_mem_prealloc",
"reuse_lu_decomposition_multiple_solves",
"scan_split_non_sequence_lu_decomposition_solve",
],
),
)
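This hunk adds the two new rewrite names to a mode's exclusion list, which also makes them individually excludable by name. A hedged sketch of opting out of the new behavior with the standard Mode API (not shown in the PR):

from pytensor.compile.mode import get_default_mode

# Compile a function without the LU-reuse rewrite, e.g. to compare against the old behavior
mode = get_default_mode().excluding("reuse_lu_decomposition_multiple_solves")
# fn = pytensor.function([A, b1, b2], [x1, x2], mode=mode)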
10 changes: 4 additions & 6 deletions pytensor/scan/rewriting.py
@@ -2561,26 +2561,24 @@ def scan_push_out_dot1(fgraph, node):
    position=1,
)


scan_seqopt1.register(
    "scan_push_out_non_seq",
    in2out(scan_push_out_non_seq, ignore_newtrees=True),
    "scan_pushout_nonseqs_ops",  # For backcompat: so it can be tagged with old name
    "fast_run",
    "scan",
    "scan_pushout",
-    position=2,
+    position=3,
)


scan_seqopt1.register(
    "scan_push_out_seq",
    in2out(scan_push_out_seq, ignore_newtrees=True),
    "scan_pushout_seqs_ops",  # For backcompat: so it can be tagged with old name
    "fast_run",
    "scan",
    "scan_pushout",
-    position=3,
+    position=4,
)


@@ -2592,7 +2590,7 @@ def scan_push_out_dot1(fgraph, node):
"more_mem",
"scan",
"scan_pushout",
position=4,
position=5,
)


@@ -2605,7 +2603,7 @@ def scan_push_out_dot1(fgraph, node):
"more_mem",
"scan",
"scan_pushout",
position=5,
position=6,
)

Member: Is this ordering necessary?

Member (Author): Yeah, we want the rewrite that splits the LU to run before the pushout, which is the one that actually removes it from the inner graph. I could have used decimals, but it makes sense to have a whole number between the previous rewrite and this one.

scan_eqopt2.register(
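To inspect the ordering discussed above, the sequence database can print its registered rewrites: the new LU-split rewrite takes position 2 (registered at the end of pytensor/tensor/_linalg/solve/rewriting.py below), so the pushout rewrites shift to positions 3 through 6. A quick check, assuming print_summary is available on the database:

from pytensor.scan.rewriting import scan_seqopt1

# Lists each registered rewrite along with its position in the sequence
scan_seqopt1.print_summary()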
1 change: 1 addition & 0 deletions pytensor/tensor/__init__.py
@@ -114,6 +114,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:


# isort: off
import pytensor.tensor._linalg
from pytensor.tensor import linalg
from pytensor.tensor import special
from pytensor.tensor import signal
2 changes: 2 additions & 0 deletions pytensor/tensor/_linalg/__init__.py
@@ -0,0 +1,2 @@
# Register rewrites
import pytensor.tensor._linalg.solve
2 changes: 2 additions & 0 deletions pytensor/tensor/_linalg/solve/__init__.py
@@ -0,0 +1,2 @@
# Register rewrites in the database
import pytensor.tensor._linalg.solve.rewriting
198 changes: 198 additions & 0 deletions pytensor/tensor/_linalg/solve/rewriting.py
@@ -0,0 +1,198 @@
from collections.abc import Container
from copy import copy

from pytensor.graph import Constant, graph_inputs
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
from pytensor.scan.op import Scan
from pytensor.scan.rewriting import scan_seqopt1
from pytensor.tensor.basic import atleast_Nd
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
Member: Why did you choose to put these rewrites in tensor._linalg.solve.rewriting instead of in tensor.rewriting._linalg.solve? It breaks the usual pattern.

Member (Author): I don't think it does. For instance, the random rewrites are in tensor/random/rewriting, not tensor/rewriting/random.
from pytensor.tensor.slinalg import Solve, lu_factor, lu_solve
from pytensor.tensor.variable import TensorVariable


def decompose_A(A, assume_a):
    if assume_a == "gen":
        return lu_factor(A, check_finite=False)
    else:
        raise NotImplementedError

Codecov annotation on line 21 (the raise NotImplementedError): added line was not covered by tests.


def solve_lu_decomposed_system(A_decomp, b, b_ndim, assume_a, transposed=False):
    if assume_a == "gen":
        return lu_solve(A_decomp, b, b_ndim=b_ndim, trans=transposed)
    else:
        raise NotImplementedError

Codecov annotation on line 28 (the raise NotImplementedError): added line was not covered by tests.


def _split_lu_solve_steps(
    fgraph, node, *, eager: bool, allowed_assume_a: Container[str]
):
    if not isinstance(node.op.core_op, Solve):
        return None

    def get_root_A(a: TensorVariable) -> tuple[TensorVariable, bool]:
        # Find the root variable of the first input to Solve
        # If `a` is a left expand_dims or matrix transpose (DimShuffle variants),
        # the root variable is the pre-DimShuffled input.
        # Otherwise, `a` is considered the root variable.
        # We also return whether the root `a` is transposed.
        transposed = False
        if a.owner is not None and isinstance(a.owner.op, DimShuffle):
            if a.owner.op.is_left_expand_dims:
                [a] = a.owner.inputs
            elif is_matrix_transpose(a):
                [a] = a.owner.inputs
                transposed = True
        return a, transposed

    def find_solve_clients(var, assume_a):
        clients = []
        for cl, idx in fgraph.clients[var]:
            if (
                idx == 0
                and isinstance(cl.op, Blockwise)
                and isinstance(cl.op.core_op, Solve)
                and (cl.op.core_op.assume_a == assume_a)
            ):
                clients.append(cl)
            elif isinstance(cl.op, DimShuffle) and cl.op.is_left_expand_dims:
                # If it's a left expand_dims, recurse on the output
                clients.extend(find_solve_clients(cl.outputs[0], assume_a))
        return clients

    assume_a = node.op.core_op.assume_a

    if assume_a not in allowed_assume_a:
        return None

    A, _ = get_root_A(node.inputs[0])

    # Find Solve using A (or left expand_dims of A)
    # TODO: We could handle arbitrary shuffle of the batch dimensions, just need to propagate
    #  that to the A_decomp outputs
    A_solve_clients_and_transpose = [
        (client, False) for client in find_solve_clients(A, assume_a)
    ]

    # Find Solves using A.T
    for cl, _ in fgraph.clients[A]:
        if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out):
            A_T = cl.out
            A_solve_clients_and_transpose.extend(
                (client, True) for client in find_solve_clients(A_T, assume_a)
            )

    if not eager and len(A_solve_clients_and_transpose) == 1:
        # If there's a single use, don't do it... unless it's being broadcast in a Blockwise (or we're eager)
        # That's a "reuse" inside the inner vectorized loop
        batch_ndim = node.op.batch_ndim(node)
        (client, _) = A_solve_clients_and_transpose[0]
        original_A, b = client.inputs
        if not any(
            a_bcast and not b_bcast
            for a_bcast, b_bcast in zip(
                original_A.type.broadcastable[:batch_ndim],
                b.type.broadcastable[:batch_ndim],
                strict=True,
            )
        ):
            return None

    A_decomp = decompose_A(A, assume_a=assume_a)

    replacements = {}
    for client, transposed in A_solve_clients_and_transpose:
        _, b = client.inputs
        b_ndim = client.op.core_op.b_ndim
        new_x = solve_lu_decomposed_system(
            A_decomp, b, b_ndim=b_ndim, assume_a=assume_a, transposed=transposed
        )
        [old_x] = client.outputs
        new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype)
        copy_stack_trace(old_x, new_x)
        replacements[old_x] = new_x

    return replacements


def _scan_split_non_sequence_lu_decomposition_solve(
    fgraph, node, *, allowed_assume_a: Container[str]
):
    """If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step.

    The LU decomposition step can then be pushed out of the inner loop by the `scan_pushout_non_sequences` rewrite.
    """
    scan_op: Scan = node.op
    non_sequences = set(scan_op.inner_non_seqs(scan_op.inner_inputs))
    new_scan_fgraph = scan_op.fgraph

    changed = False
    while True:
        for inner_node in new_scan_fgraph.toposort():
            if (
                isinstance(inner_node.op, Blockwise)
                and isinstance(inner_node.op.core_op, Solve)
                and inner_node.op.core_op.assume_a in allowed_assume_a
            ):
                A, b = inner_node.inputs
                if all(
                    (isinstance(root_inp, Constant) or (root_inp in non_sequences))
                    for root_inp in graph_inputs([A])
                ):
                    if new_scan_fgraph is scan_op.fgraph:
                        # Clone the first time to avoid mutating the original fgraph
                        new_scan_fgraph, equiv = new_scan_fgraph.clone_get_equiv()
                        non_sequences = {equiv[non_seq] for non_seq in non_sequences}
                        inner_node = equiv[inner_node]  # type: ignore

                    replace_dict = _split_lu_solve_steps(
                        new_scan_fgraph,
                        inner_node,
                        eager=True,
                        allowed_assume_a=allowed_assume_a,
                    )
                    assert (
                        isinstance(replace_dict, dict) and len(replace_dict) > 0
                    ), "Rewrite failed"
                    new_scan_fgraph.replace_all(replace_dict.items())
                    changed = True
                    break  # Break to start over with a fresh toposort
        else:  # no_break
            break  # Nothing else changed

    if not changed:
        return

    # Return a new scan to indicate that a rewrite was done
    new_scan_op = copy(scan_op)
    new_scan_op.fgraph = new_scan_fgraph
    new_outs = new_scan_op.make_node(*node.inputs).outputs
    copy_stack_trace(node.outputs, new_outs)
    return new_outs


@register_specialize
@node_rewriter([Blockwise])
def reuse_lu_decomposition_multiple_solves(fgraph, node):
    return _split_lu_solve_steps(fgraph, node, eager=False, allowed_assume_a={"gen"})


@node_rewriter([Scan])
def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
    return _scan_split_non_sequence_lu_decomposition_solve(
        fgraph, node, allowed_assume_a={"gen"}
    )


scan_seqopt1.register(
    "scan_split_non_sequence_lu_decomposition_solve",
    in2out(scan_split_non_sequence_lu_decomposition_solve, ignore_newtrees=True),
    "fast_run",
    "scan",
    "scan_pushout",
    position=2,
)
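A sketch of the Scan case this rewrite targets, as a hypothetical example (names and graph output are illustrative): A enters the loop as a non-sequence, so scan_split_non_sequence_lu_decomposition_solve splits off the factorization inside the inner graph, and the scan_pushout rewrites can then hoist it out of the loop:

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve

A = pt.matrix("A")  # constant across iterations (a non-sequence)
B = pt.matrix("B")  # one row is solved per iteration

xs, _ = pytensor.scan(
    lambda b, A: solve(A, b, assume_a="gen"),
    sequences=[B],
    non_sequences=[A],
)
fn = pytensor.function([A, B], xs)
pytensor.dprint(fn)  # expected: the LU factorization sits outside the Scan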
7 changes: 7 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
@@ -75,6 +75,13 @@
        if ndims < 2:
            return False
        transpose_order = (*range(ndims - 2), ndims - 1, ndims - 2)

        # Allow expand_dims on the left of the transpose
        if (diff := len(node.op.new_order) - len(transpose_order)) > 0:
            transpose_order = (
                *(["x"] * diff),
                *transpose_order,
            )
        return node.op.new_order == transpose_order
    return False

Codecov annotation on line 81: added line was not covered by tests.

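The relaxed check means a matrix transpose whose dimensions were expanded on the left is still recognized as a transpose. A small illustration, assuming the helper is imported as in the new rewrite module above:

import pytensor.tensor as pt
from pytensor.tensor.rewriting.linalg import is_matrix_transpose

A = pt.matrix("A")
print(is_matrix_transpose(A.T))  # True, as before
print(is_matrix_transpose(A.dimshuffle("x", 1, 0)))  # now also True: transpose plus a left expand_dims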
Empty file added tests/tensor/linalg/__init__.py