Skip to content

Commit 88dc1f6

Browse files
Merge pull request #570 from SciML/sparsearraysext
Make SparseArrays an extension
2 parents d757452 + 3f62e1b commit 88dc1f6

23 files changed

+2555
-394
lines changed

Project.toml

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1212
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
1313
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1414
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
15-
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
1615
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
1716
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
1817
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
@@ -21,13 +20,11 @@ MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
2120
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
2221
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2322
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
24-
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
2523
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2624
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2725
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
2826
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2927
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
30-
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
3128
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
3229
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3330

@@ -45,6 +42,8 @@ KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
4542
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
4643
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
4744
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
45+
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
46+
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
4847

4948
[extensions]
5049
LinearSolveBandedMatricesExt = "BandedMatrices"
@@ -60,6 +59,9 @@ LinearSolveKrylovKitExt = "KrylovKit"
6059
LinearSolveMetalExt = "Metal"
6160
LinearSolvePardisoExt = "Pardiso"
6261
LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"
62+
LinearSolveRecursiveFactorizationExt = "RecursiveFactorization"
63+
LinearSolveSparseArraysExt = "SparseArrays"
64+
LinearSolveSparspakExt = "Sparspak"
6365

6466
[compat]
6567
AllocCheck = "0.2"
@@ -84,7 +86,6 @@ HYPRE = "1.4.0"
8486
InteractiveUtils = "1.10"
8587
IterativeSolvers = "0.9.3"
8688
JET = "0.8.28, 0.9"
87-
KLU = "0.6"
8889
KernelAbstractions = "0.9.27"
8990
Krylov = "0.9"
9091
KrylovKit = "0.8, 0.9"
@@ -140,11 +141,13 @@ MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
140141
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
141142
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
142143
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
144+
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
143145
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
146+
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
144147
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
145148
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
146149
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
147150
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
148151

149152
[targets]
150-
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote"]
153+
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization", "Sparspak"]

benchmarks/applelu.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ for i in 1:length(ns)
3939
for j in 1:length(algs)
4040
bt = @belapsed solve(prob, $(algs[j])).u setup=(prob = LinearProblem(copy(A),
4141
copy(b);
42-
u0 = copy(u0),
42+
u0 = copy(u0),
4343
alias = LinearAliasSpecifier(alias_A = true, alias_b = true)
44-
))
44+
))
4545
push!(res[j], luflop(n) / bt / 1e9)
4646
end
4747
end

docs/src/advanced/developing.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ basic machinery. A simplified version is:
1717
```julia
1818
struct MyLUFactorization{P} <: LinearSolve.SciMLLinearSolveAlgorithm end
1919

20-
function LinearSolve.init_cacheval(alg::MyLUFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol,
20+
function LinearSolve.init_cacheval(
21+
alg::MyLUFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol,
2122
verbose::Bool, assump::LinearSolve.OperatorAssumptions)
2223
lu!(convert(AbstractMatrix, A))
2324
end
@@ -41,7 +42,8 @@ need to cache their own things, and so there's one value `cacheval` that is
4142
for the algorithms to modify. The function:
4243

4344
```julia
44-
init_cacheval(alg::MyLUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose, assump)
45+
init_cacheval(
46+
alg::MyLUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose, assump)
4547
```
4648

4749
is what is called at `init` time to create the first `cacheval`. Note that this

docs/src/basics/common_solver_opts.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ in order to give composability. These are also the options taken at `init` time.
66
The following are the options these algorithms take, along with their defaults.
77

88
## General Controls
9+
910
- `alias::LinearAliasSpecifier`: Holds the fields `alias_A` and `alias_b` which specify
1011
whether to alias the matrices `A` and `b` respectively. When these fields are `true`,
1112
`A` and `b` can be written to and changed by the solver algorithm. When fields are `nothing`
12-
the default behavior is used, which is to default to `true` when the algorithm is known
13+
the default behavior is used, which is to default to `true` when the algorithm is known
1314
not to modify the matrices, and false otherwise.
1415
- `verbose`: Whether to print extra information. Defaults to `false`.
1516
- `assumptions`: Sets the assumptions of the operator in order to effect the default
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
module LinearSolveRecursiveFactorizationExt
2+
3+
using LinearSolve
4+
using LinearSolve.LinearAlgebra, LinearSolve.ArrayInterface, RecursiveFactorization
5+
6+
LinearSolve.userecursivefactorization(A::Union{Nothing, AbstractMatrix}) = true
7+
8+
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::RFLUFactorization{P, T};
9+
kwargs...) where {P, T}
10+
A = cache.A
11+
A = convert(AbstractMatrix, A)
12+
fact, ipiv = LinearSolve.@get_cacheval(cache, :RFLUFactorization)
13+
if cache.isfresh
14+
if length(ipiv) != min(size(A)...)
15+
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
16+
end
17+
fact = RecursiveFactorization.lu!(A, ipiv, Val(P), Val(T), check = false)
18+
cache.cacheval = (fact, ipiv)
19+
20+
if !LinearAlgebra.issuccess(fact)
21+
return SciMLBase.build_linear_solution(
22+
alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
23+
end
24+
25+
cache.isfresh = false
26+
end
27+
y = ldiv!(cache.u, LinearSolve.@get_cacheval(cache, :RFLUFactorization)[1], cache.b)
28+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
29+
end
30+
31+
end

ext/LinearSolveSparseArraysExt.jl

Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
module LinearSolveSparseArraysExt
2+
3+
using LinearSolve, LinearAlgebra
4+
using SparseArrays
5+
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
6+
7+
# Can't `using KLU` because cannot have a dependency in there without
8+
# requiring the user does `using KLU`
9+
# But there's no reason to require it because SparseArrays will already
10+
# load SuiteSparse and thus all of the underlying KLU code
11+
include("../src/KLU/klu.jl")
12+
13+
LinearSolve.issparsematrixcsc(A::AbstractSparseMatrixCSC) = true
14+
LinearSolve.issparsematrix(A::AbstractSparseArray) = true
15+
function LinearSolve.make_SparseMatrixCSC(A::AbstractSparseArray)
16+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A))
17+
end
18+
function LinearSolve.makeempty_SparaseMatrixCSC(A::AbstractSparseArray)
19+
SparseMatrixCSC(0, 0, [1], Int[], eltype(A)[])
20+
end
21+
22+
function LinearSolve.init_cacheval(alg::RFLUFactorization,
23+
A::Union{AbstractSparseArray, LinearSolve.SciMLOperators.AbstractSciMLOperator}, b, u, Pl, Pr,
24+
maxiters::Int,
25+
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
26+
nothing, nothing
27+
end
28+
29+
function LinearSolve.init_cacheval(
30+
alg::QRFactorization, A::Symmetric{<:Number, <:SparseMatrixCSC}, b, u, Pl, Pr,
31+
maxiters::Int, abstol, reltol, verbose::Bool,
32+
assumptions::OperatorAssumptions)
33+
return nothing
34+
end
35+
36+
function LinearSolve.handle_sparsematrixcsc_lu(A::AbstractSparseMatrixCSC)
37+
lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
38+
check = false)
39+
end
40+
41+
function LinearSolve.defaultalg(
42+
A::Symmetric{<:Number, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool})
43+
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.CHOLMODFactorization)
44+
end
45+
46+
function LinearSolve.defaultalg(A::AbstractSparseMatrixCSC{Tv, Ti}, b,
47+
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
48+
if assump.issq
49+
DefaultLinearSolver(DefaultAlgorithmChoice.SparspakFactorization)
50+
else
51+
error("Generic number sparse factorization for non-square is not currently handled")
52+
end
53+
end
54+
55+
function LinearSolve.init_cacheval(alg::GenericFactorization,
56+
A::Union{Hermitian{T, <:SparseMatrixCSC},
57+
Symmetric{T, <:SparseMatrixCSC}}, b, u, Pl, Pr,
58+
maxiters::Int, abstol, reltol, verbose::Bool,
59+
assumptions::OperatorAssumptions) where {T}
60+
newA = copy(convert(AbstractMatrix, A))
61+
LinearSolve.do_factorization(alg, newA, b, u)
62+
end
63+
64+
const PREALLOCATED_UMFPACK = SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(0, 0, [1],
65+
Int[], Float64[]))
66+
67+
function LinearSolve.init_cacheval(
68+
alg::UMFPACKFactorization, A::SparseMatrixCSC{Float64, Int}, b, u,
69+
Pl, Pr,
70+
maxiters::Int, abstol, reltol,
71+
verbose::Bool, assumptions::OperatorAssumptions)
72+
PREALLOCATED_UMFPACK
73+
end
74+
75+
function LinearSolve.init_cacheval(
76+
alg::UMFPACKFactorization, A::AbstractSparseArray, b, u, Pl, Pr,
77+
maxiters::Int, abstol,
78+
reltol,
79+
verbose::Bool, assumptions::OperatorAssumptions)
80+
A = convert(AbstractMatrix, A)
81+
zerobased = SparseArrays.getcolptr(A)[1] == 0
82+
return SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(size(A)..., getcolptr(A),
83+
rowvals(A), nonzeros(A)))
84+
end
85+
86+
function SciMLBase.solve!(
87+
cache::LinearSolve.LinearCache, alg::UMFPACKFactorization; kwargs...)
88+
A = cache.A
89+
A = convert(AbstractMatrix, A)
90+
if cache.isfresh
91+
cacheval = LinearSolve.@get_cacheval(cache, :UMFPACKFactorization)
92+
if alg.reuse_symbolic
93+
# Caches the symbolic factorization: https://github.com/JuliaLang/julia/pull/33738
94+
if alg.check_pattern && pattern_changed(cacheval, A)
95+
fact = lu(
96+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
97+
nonzeros(A)),
98+
check = false)
99+
else
100+
fact = lu!(cacheval,
101+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
102+
nonzeros(A)), check = false)
103+
end
104+
else
105+
fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
106+
check = false)
107+
end
108+
cache.cacheval = fact
109+
cache.isfresh = false
110+
end
111+
112+
F = LinearSolve.@get_cacheval(cache, :UMFPACKFactorization)
113+
if F.status == SparseArrays.UMFPACK.UMFPACK_OK
114+
y = ldiv!(cache.u, F, cache.b)
115+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
116+
else
117+
SciMLBase.build_linear_solution(
118+
alg, cache.u, nothing, cache; retcode = ReturnCode.Infeasible)
119+
end
120+
end
121+
122+
const PREALLOCATED_KLU = KLU.KLUFactorization(SparseMatrixCSC(0, 0, [1], Int[],
123+
Float64[]))
124+
125+
function LinearSolve.init_cacheval(
126+
alg::KLUFactorization, A::SparseMatrixCSC{Float64, Int}, b, u, Pl,
127+
Pr,
128+
maxiters::Int, abstol, reltol,
129+
verbose::Bool, assumptions::OperatorAssumptions)
130+
PREALLOCATED_KLU
131+
end
132+
133+
function LinearSolve.init_cacheval(
134+
alg::KLUFactorization, A::AbstractSparseArray, b, u, Pl, Pr,
135+
maxiters::Int, abstol,
136+
reltol,
137+
verbose::Bool, assumptions::OperatorAssumptions)
138+
A = convert(AbstractMatrix, A)
139+
return KLU.KLUFactorization(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
140+
nonzeros(A)))
141+
end
142+
143+
# TODO: guard this against errors
144+
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::KLUFactorization; kwargs...)
145+
A = cache.A
146+
A = convert(AbstractMatrix, A)
147+
if cache.isfresh
148+
cacheval = LinearSolve.@get_cacheval(cache, :KLUFactorization)
149+
if alg.reuse_symbolic
150+
if alg.check_pattern && pattern_changed(cacheval, A)
151+
fact = KLU.klu(
152+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
153+
nonzeros(A)),
154+
check = false)
155+
else
156+
fact = KLU.klu!(cacheval, nonzeros(A), check = false)
157+
end
158+
else
159+
# New fact each time since the sparsity pattern can change
160+
# and thus it needs to reallocate
161+
fact = KLU.klu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
162+
nonzeros(A)))
163+
end
164+
cache.cacheval = fact
165+
cache.isfresh = false
166+
end
167+
F = LinearSolve.@get_cacheval(cache, :KLUFactorization)
168+
if F.common.status == KLU.KLU_OK
169+
y = ldiv!(cache.u, F, cache.b)
170+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
171+
else
172+
SciMLBase.build_linear_solution(
173+
alg, cache.u, nothing, cache; retcode = ReturnCode.Infeasible)
174+
end
175+
end
176+
177+
const PREALLOCATED_CHOLMOD = cholesky(SparseMatrixCSC(0, 0, [1], Int[], Float64[]))
178+
179+
function LinearSolve.init_cacheval(alg::CHOLMODFactorization,
180+
A::Union{SparseMatrixCSC{T, Int}, Symmetric{T, SparseMatrixCSC{T, Int}}}, b, u,
181+
Pl, Pr,
182+
maxiters::Int, abstol, reltol,
183+
verbose::Bool, assumptions::OperatorAssumptions) where {T <:
184+
Union{Float32, Float64}}
185+
PREALLOCATED_CHOLMOD
186+
end
187+
188+
function LinearSolve.init_cacheval(alg::NormalCholeskyFactorization,
189+
A::Union{AbstractSparseArray, LinearSolve.GPUArraysCore.AnyGPUArray,
190+
Symmetric{<:Number, <:AbstractSparseArray}}, b, u, Pl, Pr,
191+
maxiters::Int, abstol, reltol, verbose::Bool,
192+
assumptions::OperatorAssumptions)
193+
LinearSolve.ArrayInterface.cholesky_instance(convert(AbstractMatrix, A))
194+
end
195+
196+
# Specialize QR for the non-square case
197+
# Missing ldiv! definitions: https://github.com/JuliaSparse/SparseArrays.jl/issues/242
198+
function LinearSolve._ldiv!(x::Vector,
199+
A::Union{SparseArrays.QR, LinearAlgebra.QRCompactWY,
200+
SparseArrays.SPQR.QRSparse,
201+
SparseArrays.CHOLMOD.Factor}, b::Vector)
202+
x .= A \ b
203+
end
204+
205+
function LinearSolve._ldiv!(x::AbstractVector,
206+
A::Union{SparseArrays.QR, LinearAlgebra.QRCompactWY,
207+
SparseArrays.SPQR.QRSparse,
208+
SparseArrays.CHOLMOD.Factor}, b::AbstractVector)
209+
x .= A \ b
210+
end
211+
212+
# Ambiguity removal
213+
function LinearSolve._ldiv!(::LinearSolve.SVector,
214+
A::Union{SparseArrays.CHOLMOD.Factor, LinearAlgebra.QR,
215+
LinearAlgebra.QRCompactWY, SparseArrays.SPQR.QRSparse},
216+
b::AbstractVector)
217+
(A \ b)
218+
end
219+
function LinearSolve._ldiv!(::LinearSolve.SVector,
220+
A::Union{SparseArrays.CHOLMOD.Factor, LinearAlgebra.QR,
221+
LinearAlgebra.QRCompactWY, SparseArrays.SPQR.QRSparse},
222+
b::LinearSolve.SVector)
223+
(A \ b)
224+
end
225+
226+
function pattern_changed(fact, A::SparseArrays.SparseMatrixCSC)
227+
!(SparseArrays.decrement(SparseArrays.getcolptr(A)) ==
228+
fact.colptr && SparseArrays.decrement(SparseArrays.getrowval(A)) ==
229+
fact.rowval)
230+
end
231+
232+
function LinearSolve.defaultalg(
233+
A::AbstractSparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti}, b,
234+
assump::OperatorAssumptions{Bool}) where {Ti}
235+
if assump.issq
236+
if length(b) <= 10_000 && length(nonzeros(A)) / length(A) < 2e-4
237+
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KLUFactorization)
238+
else
239+
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.UMFPACKFactorization)
240+
end
241+
else
242+
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.QRFactorization)
243+
end
244+
end
245+
246+
LinearSolve.PrecompileTools.@compile_workload begin
247+
A = sprand(4, 4, 0.3) + I
248+
b = rand(4)
249+
prob = LinearProblem(A, b)
250+
sol = solve(prob, KLUFactorization())
251+
sol = solve(prob, UMFPACKFactorization())
252+
end
253+
254+
end

0 commit comments

Comments
 (0)