Skip to content

Commit 9652c5d

Browse files
committed
Decompose Tridiagonal Solve into core steps
1 parent 59d9b7d commit 9652c5d

File tree

8 files changed

+430
-47
lines changed

8 files changed

+430
-47
lines changed

pytensor/compile/mode.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,8 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
477477
"fusion",
478478
"inplace",
479479
"scan_save_mem_prealloc",
480+
"reuse_lu_decomposition_multiple_solves",
481+
"scan_split_non_sequence_lu_decomposition_solve",
480482
],
481483
),
482484
)

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from numpy import ndarray
77
from scipy import linalg
88

9+
from pytensor.link.numba.dispatch import numba_funcify
910
from pytensor.link.numba.dispatch.basic import numba_njit
1011
from pytensor.link.numba.dispatch.linalg._LAPACK import (
1112
_LAPACK,
@@ -20,6 +21,10 @@
2021
_solve_check,
2122
_trans_char_to_int,
2223
)
24+
from pytensor.tensor._linalg.solve.tridiagonal import (
25+
LUFactorTridiagonal,
26+
SolveLUFactorTridiagonal,
27+
)
2328

2429

2530
@numba_njit
@@ -297,3 +302,48 @@ def impl(
297302
return X
298303

299304
return impl
305+
306+
307+
@numba_funcify.register(LUFactorTridiagonal)
def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
    """Build the Numba implementation of ``LUFactorTridiagonal``.

    The returned jitted function takes the three diagonals ``(dl, d, du)``,
    copies each one unless the op is flagged to overwrite it, factorizes them
    with ``_gttrf`` and returns ``(dl, d, du, du2, ipiv)``. The sixth value
    returned by ``_gttrf`` is discarded.
    """
    # Resolve the flags once, outside the jitted closure.
    copy_dl = not op.overwrite_dl
    copy_d = not op.overwrite_d
    copy_du = not op.overwrite_du

    @numba_njit(cache=False)
    def lu_factor_tridiagonal(dl, d, du):
        # Protect every input the op is not allowed to overwrite.
        if copy_dl:
            dl = dl.copy()
        if copy_d:
            d = d.copy()
        if copy_du:
            du = du.copy()

        dl_fact, d_fact, du_fact, du2, ipiv, _info = _gttrf(dl, d, du)
        return dl_fact, d_fact, du_fact, du2, ipiv

    return lu_factor_tridiagonal
326+
327+
328+
@numba_funcify.register(SolveLUFactorTridiagonal)
def numba_funcify_SolveLUFactorTridiagonal(
    op: SolveLUFactorTridiagonal, node, **kwargs
):
    """Build the Numba implementation of ``SolveLUFactorTridiagonal``.

    The returned jitted function forwards ``(dl, d, du, du2, ipiv, b)`` to
    ``_gttrs`` together with the op's ``overwrite_b`` and ``transposed``
    settings, and returns only the first of the two values ``_gttrs`` yields.
    """
    may_overwrite_b = op.overwrite_b
    use_transpose = op.transposed

    @numba_njit(cache=False)
    def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
        solution, _info = _gttrs(
            dl, d, du, du2, ipiv, b, overwrite_b=may_overwrite_b, trans=use_transpose
        )
        return solution

    return solve_lu_factor_tridiagonal

pytensor/tensor/_linalg/solve/rewriting.py

Lines changed: 69 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1+
from collections.abc import Container
12
from copy import copy
23

4+
from pytensor import compile
35
from pytensor.graph import Constant, graph_inputs
46
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
57
from pytensor.scan.op import Scan
68
from pytensor.scan.rewriting import scan_seqopt1
9+
from pytensor.tensor._linalg.solve.tridiagonal import (
10+
tridiagonal_lu_factor,
11+
tridiagonal_lu_solve,
12+
)
713
from pytensor.tensor.basic import atleast_Nd
814
from pytensor.tensor.blockwise import Blockwise
915
from pytensor.tensor.elemwise import DimShuffle
@@ -16,21 +22,24 @@
1622
def decompose_A(A, assume_a):
    """Return the decomposition of ``A`` that matches ``assume_a``.

    Parameters
    ----------
    A
        Matrix to decompose.
    assume_a
        ``"gen"`` uses ``lu_factor`` (with ``check_finite=False``);
        ``"tridiagonal"`` uses ``tridiagonal_lu_factor``.

    Raises
    ------
    NotImplementedError
        If ``assume_a`` is neither ``"gen"`` nor ``"tridiagonal"``.
    """
    if assume_a == "gen":
        return lu_factor(A, check_finite=False)
    if assume_a == "tridiagonal":
        return tridiagonal_lu_factor(A)
    # Name the offending value so a misconfigured rewrite is easy to diagnose.
    raise NotImplementedError(
        f"No decomposition is implemented for assume_a={assume_a!r}"
    )
2129

2230

2331
def solve_lu_decomposed_system(A_decomp, b, b_ndim, assume_a, transposed=False):
    """Solve a system from an already decomposed ``A``.

    Parameters
    ----------
    A_decomp
        Decomposition of ``A``, passed through unchanged to the solver.
    b
        Right-hand side.
    b_ndim
        Forwarded to the solver as ``b_ndim``.
    assume_a
        ``"gen"`` dispatches to ``lu_solve``; ``"tridiagonal"`` dispatches to
        ``tridiagonal_lu_solve``.
    transposed
        Forwarded as ``trans`` (``lu_solve``) or ``transposed``
        (``tridiagonal_lu_solve``).

    Raises
    ------
    NotImplementedError
        If ``assume_a`` is neither ``"gen"`` nor ``"tridiagonal"``.
    """
    if assume_a == "gen":
        return lu_solve(A_decomp, b, b_ndim=b_ndim, trans=transposed)
    if assume_a == "tridiagonal":
        return tridiagonal_lu_solve(A_decomp, b, b_ndim=b_ndim, transposed=transposed)
    # Name the offending value so a misconfigured rewrite is easy to diagnose.
    raise NotImplementedError(
        f"No decomposed solve is implemented for assume_a={assume_a!r}"
    )
2838

2939

30-
_SPLITTABLE_SOLVE_ASSUME_A = {"gen"}
31-
32-
33-
def _split_lu_solve_steps(fgraph, node, *, eager: bool):
40+
def _split_lu_solve_steps(
41+
fgraph, node, *, eager: bool, allowed_assume_a: Container[str]
42+
):
3443
if not isinstance(node.op.core_op, Solve):
3544
return None
3645

@@ -66,7 +75,7 @@ def find_solve_clients(var, assume_a):
6675

6776
assume_a = node.op.core_op.assume_a
6877

69-
if assume_a not in _SPLITTABLE_SOLVE_ASSUME_A:
78+
if assume_a not in allowed_assume_a:
7079
return None
7180

7281
A, _ = get_root_A(node.inputs[0])
@@ -119,19 +128,9 @@ def find_solve_clients(var, assume_a):
119128
return replacements
120129

121130

122-
@register_specialize
123-
@node_rewriter([Blockwise])
124-
def reuse_lu_decomposition_multiple_solves(fgraph, node):
125-
return _split_lu_solve_steps(fgraph, node, eager=False)
126-
127-
128-
@node_rewriter([Blockwise])
129-
def eager_split_lu_solve_steps(fgraph, node):
130-
return _split_lu_solve_steps(fgraph, node, eager=True)
131-
132-
133-
@node_rewriter([Scan])
134-
def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
131+
def _scan_split_non_sequence_lu_decomposition_solve(
132+
fgraph, node, *, allowed_assume_a: Container[str]
133+
):
135134
"""If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step.
136135

137136
The LU decomposition step can then be pushed out of the inner loop by the `scan_pushout_non_sequences` rewrite.
@@ -146,7 +145,7 @@ def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
146145
if (
147146
isinstance(inner_node.op, Blockwise)
148147
and isinstance(inner_node.op.core_op, Solve)
149-
and inner_node.op.core_op.assume_a in _SPLITTABLE_SOLVE_ASSUME_A
148+
and inner_node.op.core_op.assume_a in allowed_assume_a
150149
):
151150
A, b = inner_node.inputs
152151
if all(
@@ -159,8 +158,11 @@ def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
159158
non_sequences = {equiv[non_seq] for non_seq in non_sequences}
160159
inner_node = equiv[inner_node]
161160

162-
replace_dict = eager_split_lu_solve_steps.transform(
163-
new_scan_fgraph, inner_node
161+
replace_dict = _split_lu_solve_steps(
162+
new_scan_fgraph,
163+
inner_node,
164+
eager=True,
165+
allowed_assume_a=allowed_assume_a,
164166
)
165167
assert (
166168
isinstance(replace_dict, dict) and len(replace_dict) > 0
@@ -182,11 +184,56 @@ def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
182184
return new_outs
183185

184186

187+
@register_specialize
@node_rewriter([Blockwise])
def reuse_lu_decomposition_multiple_solves(fgraph, node):
    """Run ``_split_lu_solve_steps`` non-eagerly on ``gen`` and ``tridiagonal`` solves."""
    supported = {"gen", "tridiagonal"}
    return _split_lu_solve_steps(
        fgraph,
        node,
        eager=False,
        allowed_assume_a=supported,
    )
193+
194+
195+
@node_rewriter([Scan])
def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
    """Scan-level LU/solve split covering ``gen`` and ``tridiagonal`` solves.

    Thin wrapper that delegates to
    ``_scan_split_non_sequence_lu_decomposition_solve``.
    """
    supported = {"gen", "tridiagonal"}
    return _scan_split_non_sequence_lu_decomposition_solve(
        fgraph,
        node,
        allowed_assume_a=supported,
    )
200+
201+
185202
scan_seqopt1.register(
186-
scan_split_non_sequence_lu_decomposition_solve.__name__,
203+
"scan_split_non_sequence_lu_decomposition_solve",
187204
in2out(scan_split_non_sequence_lu_decomposition_solve, ignore_newtrees=True),
188205
"fast_run",
189206
"scan",
190207
"scan_pushout",
191208
position=2,
192209
)
210+
211+
212+
# JAX cannot decompose tridiagonal matrices, so this variant only accepts "gen".
@node_rewriter([Blockwise])
def reuse_lu_decomposition_multiple_solves_jax(fgraph, node):
    """Run ``_split_lu_solve_steps`` non-eagerly, restricted to ``gen`` solves."""
    supported = {"gen"}
    return _split_lu_solve_steps(
        fgraph,
        node,
        eager=False,
        allowed_assume_a=supported,
    )


# Registered under the "jax" tag in the "specialize" database.
compile.optdb["specialize"].register(
    reuse_lu_decomposition_multiple_solves_jax.__name__,
    reuse_lu_decomposition_multiple_solves_jax,
    "jax",
)
223+
224+
225+
@node_rewriter([Scan])
def scan_split_non_sequence_lu_decomposition_solve_jax(fgraph, node):
    """Scan-level LU/solve split restricted to ``gen`` solves.

    Thin wrapper that delegates to
    ``_scan_split_non_sequence_lu_decomposition_solve``.
    """
    supported = {"gen"}
    return _scan_split_non_sequence_lu_decomposition_solve(
        fgraph,
        node,
        allowed_assume_a=supported,
    )


# Registered with the "scan", "scan_pushout" and "jax" tags at position 2.
scan_seqopt1.register(
    scan_split_non_sequence_lu_decomposition_solve_jax.__name__,
    in2out(scan_split_non_sequence_lu_decomposition_solve_jax, ignore_newtrees=True),
    "scan",
    "scan_pushout",
    "jax",
    position=2,
)
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import numpy as np
2+
from scipy.linalg import get_lapack_funcs
3+
4+
from pytensor.graph import Apply, Op
5+
from pytensor.tensor.basic import as_tensor, diagonal
6+
from pytensor.tensor.blockwise import Blockwise
7+
from pytensor.tensor.type import tensor, vector
8+
9+
10+
class LUFactorTridiagonal(Op):
    """LU-factorize a tridiagonal matrix given as three diagonal vectors (LAPACK gttrf).

    Inputs are the vectors ``(dl, d, du)``. Outputs are the five vectors
    ``(dl, d, du, du2, ipiv)`` returned by ``gttrf``; ``ipiv`` is int32.
    """

    __props__ = ("overwrite_dl", "overwrite_d", "overwrite_du")
    gufunc_signature = "(dl),(d),(dl)->(dl),(d),(dl),(du2),(d)"

    def __init__(self, overwrite_dl=False, overwrite_d=False, overwrite_du=False):
        flags = (overwrite_dl, overwrite_d, overwrite_du)
        # Output i destroys input i whenever that input may be overwritten.
        self.destroy_map = {idx: [idx] for idx, flag in enumerate(flags) if flag}
        self.overwrite_dl, self.overwrite_d, self.overwrite_du = flags
        super().__init__()

    def make_node(self, dl, d, du):
        """Validate the diagonals and build the output variables."""
        diagonals = [as_tensor(x) for x in (dl, d, du)]

        if any(diag.type.ndim != 1 for diag in diagonals):
            raise ValueError("Diagonals must be vectors")

        dl, d, du = diagonals
        len_dl, len_d, len_du = (diag.type.shape[-1] for diag in diagonals)

        # Take the static size n from the first diagonal whose length is known.
        if len_dl is not None:
            n = len_dl + 1
        elif len_d is not None:
            n = len_d
        elif len_du is not None:
            n = len_du + 1
        else:
            n = None

        # Let scipy pick the LAPACK flavour (and thus output dtype) for these inputs.
        dummy_arrays = [np.zeros((), dtype=diag.type.dtype) for diag in diagonals]
        out_dtype = get_lapack_funcs("gttrf", dummy_arrays).dtype

        n_off = None if n is None else n - 1
        n_du2 = None if n is None else n - 2
        outputs = [
            vector(shape=(n_off,), dtype=out_dtype),
            vector(shape=(n,), dtype=out_dtype),
            vector(shape=(n_off,), dtype=out_dtype),
            vector(shape=(n_du2,), dtype=out_dtype),
            vector(shape=(n,), dtype=np.int32),
        ]
        return Apply(self, [dl, d, du], outputs)

    def perform(self, node, inputs, output_storage):
        """Call ``gttrf`` and store its first five results as the outputs."""
        gttrf = get_lapack_funcs("gttrf", dtype=node.outputs[0].type.dtype)
        *factors, _info = gttrf(
            *inputs,
            overwrite_dl=self.overwrite_dl,
            overwrite_d=self.overwrite_d,
            overwrite_du=self.overwrite_du,
        )
        # factors is (dl, d, du, du2, ipiv), in output order.
        for storage, factor in zip(output_storage, factors):
            storage[0] = factor
69+
70+
71+
class SolveLUFactorTridiagonal(Op):
    """Solve a system of linear equations with a tridiagonal coefficient matrix (lapack gttrs).

    Inputs are ``(dl, d, du, du2, ipiv, b)``: five vectors followed by the
    right-hand side ``b``, which is a vector when ``b_ndim == 1`` and a matrix
    when ``b_ndim == 2``. The single output has the static shape ``(n,)`` or
    ``(n, b.shape[-1])`` respectively.
    """

    __props__ = ("b_ndim", "overwrite_b", "transposed")

    def __init__(self, b_ndim: int, transposed: bool, overwrite_b=False):
        if b_ndim not in (1, 2):
            raise ValueError("b_ndim must be 1 or 2")
        if b_ndim == 1:
            self.gufunc_signature = "(dl),(d),(dl),(du2),(d),(d)->(d)"
        else:
            self.gufunc_signature = "(dl),(d),(dl),(du2),(d),(d,rhs)->(d,rhs)"
        if overwrite_b:
            # The output destroys input 5, i.e. ``b``.
            self.destroy_map = {0: [5]}
        self.b_ndim = b_ndim
        self.transposed = transposed
        self.overwrite_b = overwrite_b
        super().__init__()

    def make_node(self, dl, d, du, du2, ipiv, b):
        """Validate the inputs and build the output variable.

        Raises
        ------
        ValueError
            If ``b`` does not have ``b_ndim`` dimensions, or if any of the
            first five inputs is not a vector.
        """
        dl, d, du, du2, ipiv, b = map(as_tensor, (dl, d, du, du2, ipiv, b))

        if b.type.ndim != self.b_ndim:
            raise ValueError("Wrong number of dimensions for input b.")

        factors = (dl, d, du, du2, ipiv)
        if any(inp.type.ndim != 1 for inp in factors):
            raise ValueError("Inputs must be vectors")

        ndl, nd, ndu, ndu2, nipiv = (inp.type.shape[-1] for inp in factors)
        nb = b.type.shape[0]

        # Take the static size n from the first input whose length is known,
        # falling back to the leading dimension of b (which may itself be None).
        if ndl is not None:
            n = ndl + 1
        elif nd is not None:
            n = nd
        elif ndu is not None:
            n = ndu + 1
        elif ndu2 is not None:
            n = ndu2 + 2
        elif nipiv is not None:
            n = nipiv
        else:
            n = nb

        dummy_arrays = [np.zeros((), dtype=inp.type.dtype) for inp in factors]
        # Seems to always be float64?
        # NOTE(review): possibly because the integer ipiv is part of
        # dummy_arrays and b is not; confirm against scipy's dtype selection
        # before changing which inputs determine the output dtype.
        out_dtype = get_lapack_funcs("gttrs", dummy_arrays).dtype

        if self.b_ndim == 1:
            output_shape = (n,)
        else:
            output_shape = (n, b.type.shape[-1])

        outputs = [tensor(shape=output_shape, dtype=out_dtype)]
        return Apply(self, [dl, d, du, du2, ipiv, b], outputs)

    def perform(self, node, inputs, output_storage):
        """Call ``gttrs`` and store the solution as the output."""
        gttrs = get_lapack_funcs("gttrs", dtype=node.outputs[0].type.dtype)
        x, _ = gttrs(
            *inputs,
            overwrite_b=self.overwrite_b,
            trans="N" if not self.transposed else "T",
        )
        output_storage[0][0] = x
141+
142+
143+
def tridiagonal_lu_factor(a):
    """Return the decomposition of ``a`` implied by a tridiagonal solve.

    Extracts the diagonals at offsets -1, 0 and 1 over the last two axes of
    ``a`` and applies a blockwise ``LUFactorTridiagonal`` to them.

    Returns
    -------
    tuple
        ``(dl, d, du, du2, ipiv)``.
    """
    lower, main, upper = [
        diagonal(a, offset=offset, axis1=-2, axis2=-1) for offset in (-1, 0, 1)
    ]
    factor_op = Blockwise(LUFactorTridiagonal())
    dl, d, du, du2, ipiv = factor_op(lower, main, upper)
    return dl, d, du, du2, ipiv
148+
149+
150+
def tridiagonal_lu_solve(a_diagonals, b, *, b_ndim: int, transposed: bool = False):
    """Solve against ``b`` using the five factors in ``a_diagonals``.

    ``a_diagonals`` must unpack into exactly ``(dl, d, du, du2, ipiv)``. The
    solve is a blockwise ``SolveLUFactorTridiagonal`` configured with
    ``b_ndim`` and ``transposed``.
    """
    dl, d, du, du2, ipiv = a_diagonals
    solve_op = Blockwise(
        SolveLUFactorTridiagonal(b_ndim=b_ndim, transposed=transposed)
    )
    return solve_op(dl, d, du, du2, ipiv, b)

0 commit comments

Comments
 (0)