
Commit aceab09

Decompose Tridiagonal Solve into core steps

1 parent 3da78fa, commit aceab09

5 files changed (+282, -8 lines)

pytensor/tensor/_linalg/solve/rewriting.py

Lines changed: 10 additions & 2 deletions
@@ -4,6 +4,10 @@
 from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
 from pytensor.scan.op import Scan
 from pytensor.scan.rewriting import scan_seqopt1
+from pytensor.tensor._linalg.solve.tridiagonal import (
+    tridiagonal_lu_factor,
+    tridiagonal_lu_solve,
+)
 from pytensor.tensor.basic import atleast_Nd
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
@@ -16,13 +20,17 @@
 def decompose_A(A, assume_a):
     if assume_a == "gen":
         return lu_factor(A, check_finite=False)
+    elif assume_a == "tridiagonal":
+        return tridiagonal_lu_factor(A)
     else:
         raise NotImplementedError


 def solve_lu_decomposed_system(A_decomp, b, b_ndim, assume_a, transposed=False):
     if assume_a == "gen":
         return lu_solve(A_decomp, b, b_ndim=b_ndim, trans=transposed)
+    elif assume_a == "tridiagonal":
+        return tridiagonal_lu_solve(A_decomp, b, b_ndim=b_ndim, transposed=transposed)
     else:
         raise NotImplementedError

@@ -58,7 +66,7 @@ def find_solve_clients(var, assume_a):

     assume_a = node.op.core_op.assume_a

-    if assume_a != "gen":
+    if assume_a not in {"gen", "tridiagonal"}:
         return None

     A, _ = get_root_A(node.inputs[0])
@@ -139,7 +147,7 @@ def scan_pushout_solve_lu_decomposition(fgraph, node):
         if (
             isinstance(inner_node.op, Blockwise)
             and isinstance(inner_node.op.core_op, Solve)
-            and inner_node.op.core_op.assume_a == "gen"
+            and inner_node.op.core_op.assume_a not in {"gen", "tridiagonal"}
         ):
             # TODO: Move transpose from graph to Low level solve Op
             A, _ = inner_node.inputs
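The split into decompose_A and solve_lu_decomposed_system gives the LU-reuse rewrites a single factor-once / solve-many contract for the new "tridiagonal" case. As a rough illustration (not part of the commit; the graph names and the explicit function compilation are assumptions), the two helpers can be called directly:

# Illustrative sketch (not in the commit): factor a tridiagonal A once,
# then reuse the factorization for two right-hand sides.
import pytensor
import pytensor.tensor as pt
from pytensor.tensor._linalg.solve.rewriting import (
    decompose_A,
    solve_lu_decomposed_system,
)

A = pt.matrix("A", dtype="float64")   # assumed tridiagonal at runtime
b1 = pt.vector("b1", dtype="float64")
b2 = pt.vector("b2", dtype="float64")

# Factor once (gttrf under the hood), then solve twice (gttrs)
A_decomp = decompose_A(A, assume_a="tridiagonal")
x1 = solve_lu_decomposed_system(A_decomp, b1, b_ndim=1, assume_a="tridiagonal")
x2 = solve_lu_decomposed_system(A_decomp, b2, b_ndim=1, assume_a="tridiagonal")

fn = pytensor.function([A, b1, b2], [x1, x2])
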
pytensor/tensor/_linalg/solve/tridiagonal.py (new file)

Lines changed: 154 additions & 0 deletions

import numpy as np
from scipy.linalg import get_lapack_funcs

from pytensor.graph import Apply, Op
from pytensor.tensor.basic import as_tensor, diagonal
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import tensor, vector


class LUFactorTridiagonal(Op):
    """Compute LU factorization of a tridiagonal matrix (lapack gttrf)"""

    __props__ = (
        "overwrite_dl",
        "overwrite_d",
        "overwrite_du",
    )
    gufunc_signature = "(dl),(d),(dl)->(dl),(d),(dl),(du2),(d)"

    def __init__(self, overwrite_dl=False, overwrite_d=False, overwrite_du=False):
        self.destroy_map = dm = {}
        if overwrite_dl:
            dm[0] = [0]
        if overwrite_d:
            dm[1] = [1]
        if overwrite_du:
            dm[2] = [2]
        self.overwrite_dl = overwrite_dl
        self.overwrite_d = overwrite_d
        self.overwrite_du = overwrite_du
        super().__init__()

    def make_node(self, dl, d, du):
        dl, d, du = map(as_tensor, (dl, d, du))

        if not all(inp.type.ndim == 1 for inp in (dl, d, du)):
            raise ValueError("Diagonals must be vectors")

        ndl, nd, ndu = (inp.type.shape[-1] for inp in (dl, d, du))
        n = (
            ndl + 1
            if ndl is not None
            else (nd if nd is not None else (ndu + 1 if ndu is not None else None))
        )
        dummy_arrays = [np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du)]
        out_dtype = get_lapack_funcs("gttrf", dummy_arrays).dtype
        outputs = [
            vector(shape=(None if n is None else (n - 1),), dtype=out_dtype),
            vector(shape=(n,), dtype=out_dtype),
            vector(shape=(None if n is None else n - 1,), dtype=out_dtype),
            vector(shape=(None if n is None else n - 2,), dtype=out_dtype),
            vector(shape=(n,), dtype=np.int32),
        ]
        return Apply(self, [dl, d, du], outputs)

    def perform(self, node, inputs, output_storage):
        gttrf = get_lapack_funcs("gttrf", dtype=node.outputs[0].type.dtype)
        dl, d, du, du2, ipiv, _ = gttrf(
            *inputs,
            overwrite_dl=self.overwrite_dl,
            overwrite_d=self.overwrite_d,
            overwrite_du=self.overwrite_du,
        )
        output_storage[0][0] = dl
        output_storage[1][0] = d
        output_storage[2][0] = du
        output_storage[3][0] = du2
        output_storage[4][0] = ipiv


class SolveLUFactorTridiagonal(Op):
    """Solve a system of linear equations with a tridiagonal coefficient matrix (lapack gttrs)."""

    __props__ = ("b_ndim", "overwrite_b", "transposed")

    def __init__(self, b_ndim: int, transposed: bool, overwrite_b=False):
        if b_ndim not in (1, 2):
            raise ValueError("b_ndim must be 1 or 2")
        if b_ndim == 1:
            self.gufunc_signature = "(dl),(d),(dl),(du2),(d),(d)->(d)"
        else:
            self.gufunc_signature = "(dl),(d),(dl),(du2),(d),(d,rhs)->(d,rhs)"
        if overwrite_b:
            self.destroy_map = {0: [5]}
        self.b_ndim = b_ndim
        self.transposed = transposed
        self.overwrite_b = overwrite_b
        super().__init__()

    def make_node(self, dl, d, du, du2, ipiv, b):
        dl, d, du, du2, ipiv, b = map(as_tensor, (dl, d, du, du2, ipiv, b))

        if b.type.ndim != self.b_ndim:
            raise ValueError("Wrong number of dimensions for input b.")

        if not all(inp.type.ndim == 1 for inp in (dl, d, du, du2, ipiv)):
            raise ValueError("Inputs must be vectors")

        ndl, nd, ndu, ndu2, nipiv = (
            inp.type.shape[-1] for inp in (dl, d, du, du2, ipiv)
        )
        nb = b.type.shape[0]
        n = (
            ndl + 1
            if ndl is not None
            else (
                nd
                if nd is not None
                else (
                    ndu + 1
                    if ndu is not None
                    else (
                        ndu2 + 2
                        if ndu2 is not None
                        else (nipiv if nipiv is not None else nb)
                    )
                )
            )
        )
        dummy_arrays = [
            np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du, du2, ipiv)
        ]
        # Seems to always be float64?
        out_dtype = get_lapack_funcs("gttrs", dummy_arrays).dtype
        if self.b_ndim == 1:
            output_shape = (n,)
        else:
            output_shape = (n, b.type.shape[-1])

        outputs = [tensor(shape=output_shape, dtype=out_dtype)]
        return Apply(self, [dl, d, du, du2, ipiv, b], outputs)

    def perform(self, node, inputs, output_storage):
        gttrs = get_lapack_funcs("gttrs", dtype=node.outputs[0].type.dtype)
        x, _ = gttrs(
            *inputs,
            overwrite_b=self.overwrite_b,
            trans="N" if not self.transposed else "T",
        )
        output_storage[0][0] = x


def tridiagonal_lu_factor(a):
    # Return the LU decomposition of A implied by a tridiagonal solve
    dl, d, du = (diagonal(a, offset=o, axis1=-2, axis2=-1) for o in (-1, 0, 1))
    dl, d, du, du2, ipiv = Blockwise(LUFactorTridiagonal())(dl, d, du)
    return dl, d, du, du2, ipiv


def tridiagonal_lu_solve(a_diagonals, b, *, b_ndim: int, transposed: bool = False):
    dl, d, du, du2, ipiv = a_diagonals
    return Blockwise(SolveLUFactorTridiagonal(b_ndim=b_ndim, transposed=transposed))(
        dl, d, du, du2, ipiv, b
    )
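As a quick sanity check (not part of the commit), the two helpers can be compared against a dense solve on a small random tridiagonal system; the matrix, seed, and tolerances below are illustrative:

# Sketch: tridiagonal LU factor + solve vs. a dense NumPy solve.
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor._linalg.solve.tridiagonal import (
    tridiagonal_lu_factor,
    tridiagonal_lu_solve,
)

rng = np.random.default_rng(0)
n = 5
# Build a random tridiagonal matrix and a right-hand side
A_val = (
    np.diag(rng.normal(size=n))
    + np.diag(rng.normal(size=n - 1), 1)
    + np.diag(rng.normal(size=n - 1), -1)
)
b_val = rng.normal(size=n)

A = pt.matrix("A", dtype="float64")
b = pt.vector("b", dtype="float64")
x = tridiagonal_lu_solve(tridiagonal_lu_factor(A), b, b_ndim=1)

fn = pytensor.function([A, b], x)
np.testing.assert_allclose(
    fn(A_val, b_val), np.linalg.solve(A_val, b_val), rtol=1e-7, atol=1e-12
)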

pytensor/tensor/rewriting/subtensor.py

Lines changed: 81 additions & 0 deletions
@@ -20,6 +20,7 @@
 from pytensor.scalar import constant as scalar_constant
 from pytensor.tensor.basic import (
     Alloc,
+    ExtractDiag,
     Join,
     MakeVector,
     ScalarFromTensor,
@@ -28,6 +29,7 @@
     as_tensor,
     cast,
     concatenate,
+    full,
     get_scalar_constant_value,
     get_underlying_scalar_constant_value,
     register_infer_shape,
@@ -2163,3 +2165,82 @@ def ravel_multidimensional_int_idx(fgraph, node):
     "numba",
     use_db_name_as_tag=False,  # Not included if only "specialize" is requested
 )
+
+
+@register_canonicalize
+@register_stabilize
+@register_specialize
+@node_rewriter([ExtractDiag])
+def extract_diag_of_diagonal_set_subtensor(fgraph, node):
+    def is_constant_arange(var) -> bool:
+        if not (isinstance(var, TensorConstant) and var.type.ndim == 1):
+            return False
+
+        data = var.data
+        start, stop = data[0], data[-1] + 1
+        return data.size == (stop - start) and (data == np.arange(start, stop)).all()
+
+    [diag_x] = node.inputs
+    if not (
+        diag_x.owner is not None
+        and isinstance(diag_x.owner.op, AdvancedIncSubtensor)
+        and diag_x.owner.op.set_instead_of_inc
+    ):
+        return None
+
+    x, y, *idxs = diag_x.owner.inputs
+
+    if not (
+        x.type.ndim >= 2
+        and None not in x.type.shape[-2:]
+        and x.type.shape[-2] == x.type.shape[-1]
+    ):
+        # For now we only support the rewrite when x has a static square shape
+        return None
+
+    op = node.op
+    if op.axis2 > len(idxs):
+        return None
+
+    # Check that all non-axis indices are full slices
+    axis = {op.axis1, op.axis2}
+    if not all(is_full_slice(idx) for i, idx in enumerate(idxs) if i not in axis):
+        return None
+
+    # Check that the axis indices are the aranges we would expect from setting on the diagonal
+    axis1_idx = idxs[op.axis1]
+    axis2_idx = idxs[op.axis2]
+    if not (is_constant_arange(axis1_idx) and is_constant_arange(axis2_idx)):
+        return None
+
+    dim_length = x.type.shape[-1]
+    offset = op.offset
+    start_stop1 = (axis1_idx.data[0], axis1_idx.data[-1] + 1)
+    start_stop2 = (axis2_idx.data[0], axis2_idx.data[-1] + 1)
+    orig_start1, orig_start2 = start_stop1[0], start_stop2[0]
+
+    if offset < 0:
+        # The logic for checking whether we are selecting a diagonal with a negative offset
+        # is the same as for a positive offset, with the axes swapped
+        start_stop1, start_stop2 = start_stop2, start_stop1
+        offset = -offset
+
+    start1, stop1 = start_stop1
+    start2, stop2 = start_stop2
+    if (
+        start1 == 0
+        and start2 == offset
+        and stop1 == dim_length - offset
+        and stop2 == dim_length
+    ):
+        # We are extracting the just-written diagonal
+        if y.type.ndim == 0 or y.type.shape[-1] == 1:
+            # We may need to broadcast y
+            y = full((*x.shape[:-2], dim_length - offset), y, dtype=x.type.dtype)
+        return [y]
+    elif (orig_start2 - orig_start1) != op.offset:
+        # Some other diagonal was written; ignore it
+        return [op(x)]
+    else:
+        # A portion, but not the whole diagonal, was written; don't do anything
        return None
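In effect, the rewrite lets a diagonal read cancel against the diagonal write that produced it. A rough sketch of the targeted pattern (illustrative only; the rewrite passes requested mirror the test added later in this commit):

# Sketch: reading back a diagonal that was just set should collapse to a constant fill.
import numpy as np
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.full((4, 4), np.nan)
idx = pt.arange(4)
x = x[idx, idx].set(1.0)              # write the main diagonal
d = x.diagonal(axis1=-2, axis2=-1)    # immediately read it back

# After shape inference and canonicalization, the read should reduce to pt.full((4,), 1.0)
[d_rewritten] = rewrite_graph([d], include=("ShapeOpt", "canonicalize"))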

pytensor/tensor/subtensor.py

Lines changed: 1 addition & 6 deletions
@@ -3021,12 +3021,7 @@ def make_node(self, x, y, *inputs):
         return Apply(
             self,
             (x, y, *new_inputs),
-            [
-                tensor(
-                    dtype=x.type.dtype,
-                    shape=tuple(1 if s == 1 else None for s in x.type.shape),
-                )
-            ],
+            [x.type()],
         )

     def perform(self, node, inputs, out_):
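With this change, the output of an advanced set/inc keeps the full static shape of x rather than only its broadcastable dimensions, which is what lets the new ExtractDiag rewrite see a static square shape. A small illustration (the shapes shown are what the change implies; the exact Op class producing the output is assumed to be the AdvancedIncSubtensor shown above):

# Sketch: the output type of an advanced set now mirrors x's static shape.
import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 6, 6))
idx = pt.arange(6)
out = x[..., idx, idx].set(0.0)

print(out.type.shape)  # (2, 6, 6) with [x.type()]; previously (None, None, None)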

tests/tensor/rewriting/test_subtensor.py

Lines changed: 36 additions & 0 deletions
@@ -1,3 +1,5 @@
+import random
+
 import numpy as np
 import pytest

@@ -2415,3 +2417,37 @@ def test_unknown_step(self):
         f(test_x, -2),
         test_x[0:3:-2, -1:-6:2, ::],
     )
+
+
+def test_extract_diag_of_diagonal_set_subtensor():
+    A = pt.full((2, 6, 6), np.nan)
+    rows = pt.arange(A.shape[-2])
+    cols = pt.arange(A.shape[-1])
+    write_offsets = [-2, -1, 0, 1, 2]
+    # Randomize order of write operations, to make sure rewrite is not sensitive to it
+    random.shuffle(write_offsets)
+    for offset in write_offsets:
+        value = offset + 0.1 * offset
+        if offset == 0:
+            A = A[..., rows, cols].set(value)
+        elif offset > 0:
+            A = A[..., rows[:-offset], cols[offset:]].set(value)
+        else:
+            offset = -offset
+            A = A[..., rows[offset:], cols[:-offset]].set(value)
+    # Add a partial diagonal along offset 3
+    A = A[..., rows[1:-3], cols[4:]].set(np.pi)
+
+    read_offsets = [-2, -1, 0, 1, 2, 3]
+    outs = [A.diagonal(offset=offset, axis1=-2, axis2=-1) for offset in read_offsets]
+    rewritten_outs = rewrite_graph(outs, include=("ShapeOpt", "canonicalize"))
+
+    # Every output should just be an Alloc with value
+    expected_outs = []
+    for offset in read_offsets[:-1]:
+        value = np.asarray(offset + 0.1 * offset, dtype=A.type.dtype)
+        expected_outs.append(pt.full((np.int64(2), np.int8(6 - abs(offset))), value))
+    # The partial diagonal shouldn't be rewritten
+    expected_outs.append(outs[-1])
+
+    assert equal_computations(rewritten_outs, expected_outs)
