Skip to content

Make SparseArrays an extension #570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
2c0d8ab
Make RecursiveFactorization.jl optional
ChrisRackauckas Feb 1, 2025
75d2813
get it working
ChrisRackauckas Feb 1, 2025
b62f277
Update ext/LinearSolveRecursiveFactorization.jl
ChrisRackauckas Feb 1, 2025
d81e8a2
Update src/default.jl
ChrisRackauckas Feb 1, 2025
aa865ba
Update src/factorization.jl
ChrisRackauckas Feb 1, 2025
b8cd21b
Update src/extension_algs.jl
ChrisRackauckas Feb 1, 2025
1c2eae4
Update src/extension_algs.jl
ChrisRackauckas Feb 1, 2025
6606dd4
Update src/extension_algs.jl
ChrisRackauckas Feb 1, 2025
45d94fd
Update src/default.jl
ChrisRackauckas Feb 1, 2025
cd3a29e
add RecursiveFactorization in tests
ChrisRackauckas Feb 1, 2025
80e6614
Update and rename LinearSolveRecursiveFactorization.jl to LinearSolve…
ChrisRackauckas Feb 5, 2025
c280c46
Update LinearSolveRecursiveFactorizationExt.jl
ChrisRackauckas Feb 5, 2025
50e8cbe
Update LinearSolveRecursiveFactorizationExt.jl
ChrisRackauckas Feb 5, 2025
4e358f6
Update LinearSolveRecursiveFactorizationExt.jl
ChrisRackauckas Feb 5, 2025
0278ecd
Update LinearSolveRecursiveFactorizationExt.jl
ChrisRackauckas Feb 5, 2025
e18d864
namespace PreallocatedLU
ChrisRackauckas Feb 5, 2025
a689fd6
one more
ChrisRackauckas Feb 5, 2025
8faa4e6
one more
ChrisRackauckas Feb 5, 2025
7d1f54a
namespace
ChrisRackauckas Feb 5, 2025
c93a6b2
namespace
ChrisRackauckas Feb 5, 2025
41a786b
fix default
ChrisRackauckas Feb 5, 2025
7569440
fix inference on recfact load
ChrisRackauckas Feb 5, 2025
8704062
don't double
ChrisRackauckas Feb 5, 2025
3ad68c7
Update src/extension_algs.jl
ChrisRackauckas Feb 5, 2025
4d3a346
Update src/extension_algs.jl
ChrisRackauckas Feb 5, 2025
486d924
Update src/default.jl
ChrisRackauckas Feb 5, 2025
1fad945
Update ext/LinearSolveRecursiveFactorizationExt.jl
ChrisRackauckas Feb 5, 2025
e77fb72
Make RecursiveFactorization.jl optional
ChrisRackauckas Feb 1, 2025
4f6225c
WIP: Make SparseArrays an extension
ChrisRackauckas Feb 1, 2025
536611a
finishing touches
ChrisRackauckas Feb 5, 2025
3f62e1b
format
ChrisRackauckas Feb 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Expand All @@ -21,13 +20,11 @@ MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

Expand All @@ -45,6 +42,8 @@ KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"

[extensions]
LinearSolveBandedMatricesExt = "BandedMatrices"
Expand All @@ -60,6 +59,9 @@ LinearSolveKrylovKitExt = "KrylovKit"
LinearSolveMetalExt = "Metal"
LinearSolvePardisoExt = "Pardiso"
LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"
LinearSolveRecursiveFactorizationExt = "RecursiveFactorization"
LinearSolveSparseArraysExt = "SparseArrays"
LinearSolveSparspakExt = "Sparspak"

[compat]
AllocCheck = "0.2"
Expand All @@ -84,7 +86,6 @@ HYPRE = "1.4.0"
InteractiveUtils = "1.10"
IterativeSolvers = "0.9.3"
JET = "0.8.28, 0.9"
KLU = "0.6"
KernelAbstractions = "0.9.27"
Krylov = "0.9"
KrylovKit = "0.8, 0.9"
Expand Down Expand Up @@ -140,11 +141,13 @@ MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote"]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization", "Sparspak"]
4 changes: 2 additions & 2 deletions benchmarks/applelu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ for i in 1:length(ns)
for j in 1:length(algs)
bt = @belapsed solve(prob, $(algs[j])).u setup=(prob = LinearProblem(copy(A),
copy(b);
u0 = copy(u0),
u0 = copy(u0),
alias = LinearAliasSpecifier(alias_A = true, alias_b = true)
))
))
push!(res[j], luflop(n) / bt / 1e9)
end
end
Expand Down
6 changes: 4 additions & 2 deletions docs/src/advanced/developing.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ basic machinery. A simplified version is:
```julia
struct MyLUFactorization{P} <: LinearSolve.SciMLLinearSolveAlgorithm end

function LinearSolve.init_cacheval(alg::MyLUFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol,
function LinearSolve.init_cacheval(
alg::MyLUFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol,
verbose::Bool, assump::LinearSolve.OperatorAssumptions)
lu!(convert(AbstractMatrix, A))
end
Expand All @@ -41,7 +42,8 @@ need to cache their own things, and so there's one value `cacheval` that is
for the algorithms to modify. The function:

```julia
init_cacheval(alg::MyLUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose, assump)
init_cacheval(
alg::MyLUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose, assump)
```

is what is called at `init` time to create the first `cacheval`. Note that this
Expand Down
3 changes: 2 additions & 1 deletion docs/src/basics/common_solver_opts.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ in order to give composability. These are also the options taken at `init` time.
The following are the options these algorithms take, along with their defaults.

## General Controls

- `alias::LinearAliasSpecifier`: Holds the fields `alias_A` and `alias_b` which specify
whether to alias the matrices `A` and `b` respectively. When these fields are `true`,
`A` and `b` can be written to and changed by the solver algorithm. When fields are `nothing`
the default behavior is used, which is to default to `true` when the algorithm is known
the default behavior is used, which is to default to `true` when the algorithm is known
not to modify the matrices, and false otherwise.
- `verbose`: Whether to print extra information. Defaults to `false`.
- `assumptions`: Sets the assumptions of the operator in order to effect the default
Expand Down
31 changes: 31 additions & 0 deletions ext/LinearSolveRecursiveFactorizationExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
module LinearSolveRecursiveFactorizationExt

using LinearSolve
using LinearSolve.LinearAlgebra, LinearSolve.ArrayInterface, RecursiveFactorization

LinearSolve.userecursivefactorization(A::Union{Nothing, AbstractMatrix}) = true

function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::RFLUFactorization{P, T};
kwargs...) where {P, T}
A = cache.A
A = convert(AbstractMatrix, A)
fact, ipiv = LinearSolve.@get_cacheval(cache, :RFLUFactorization)
if cache.isfresh
if length(ipiv) != min(size(A)...)
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
end
fact = RecursiveFactorization.lu!(A, ipiv, Val(P), Val(T), check = false)
cache.cacheval = (fact, ipiv)

if !LinearAlgebra.issuccess(fact)
return SciMLBase.build_linear_solution(
alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
end

cache.isfresh = false
end
y = ldiv!(cache.u, LinearSolve.@get_cacheval(cache, :RFLUFactorization)[1], cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

end
254 changes: 254 additions & 0 deletions ext/LinearSolveSparseArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
module LinearSolveSparseArraysExt

using LinearSolve, LinearAlgebra
using SparseArrays
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr

# Can't `using KLU` because cannot have a dependency in there without
# requiring the user does `using KLU`
# But there's no reason to require it because SparseArrays will already
# load SuiteSparse and thus all of the underlying KLU code
include("../src/KLU/klu.jl")

LinearSolve.issparsematrixcsc(A::AbstractSparseMatrixCSC) = true
LinearSolve.issparsematrix(A::AbstractSparseArray) = true
function LinearSolve.make_SparseMatrixCSC(A::AbstractSparseArray)
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A))
end
function LinearSolve.makeempty_SparaseMatrixCSC(A::AbstractSparseArray)
SparseMatrixCSC(0, 0, [1], Int[], eltype(A)[])
end

function LinearSolve.init_cacheval(alg::RFLUFactorization,
A::Union{AbstractSparseArray, LinearSolve.SciMLOperators.AbstractSciMLOperator}, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
nothing, nothing
end

function LinearSolve.init_cacheval(
alg::QRFactorization, A::Symmetric{<:Number, <:SparseMatrixCSC}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
return nothing
end

function LinearSolve.handle_sparsematrixcsc_lu(A::AbstractSparseMatrixCSC)
lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
check = false)
end

function LinearSolve.defaultalg(
A::Symmetric{<:Number, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool})
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.CHOLMODFactorization)
end

function LinearSolve.defaultalg(A::AbstractSparseMatrixCSC{Tv, Ti}, b,
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
if assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.SparspakFactorization)
else
error("Generic number sparse factorization for non-square is not currently handled")
end
end

function LinearSolve.init_cacheval(alg::GenericFactorization,
A::Union{Hermitian{T, <:SparseMatrixCSC},
Symmetric{T, <:SparseMatrixCSC}}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions) where {T}
newA = copy(convert(AbstractMatrix, A))
LinearSolve.do_factorization(alg, newA, b, u)
end

const PREALLOCATED_UMFPACK = SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(0, 0, [1],
Int[], Float64[]))

function LinearSolve.init_cacheval(
alg::UMFPACKFactorization, A::SparseMatrixCSC{Float64, Int}, b, u,
Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
PREALLOCATED_UMFPACK
end

function LinearSolve.init_cacheval(
alg::UMFPACKFactorization, A::AbstractSparseArray, b, u, Pl, Pr,
maxiters::Int, abstol,
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
A = convert(AbstractMatrix, A)
zerobased = SparseArrays.getcolptr(A)[1] == 0
return SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(size(A)..., getcolptr(A),
rowvals(A), nonzeros(A)))
end

function SciMLBase.solve!(
cache::LinearSolve.LinearCache, alg::UMFPACKFactorization; kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
cacheval = LinearSolve.@get_cacheval(cache, :UMFPACKFactorization)
if alg.reuse_symbolic
# Caches the symbolic factorization: https://github.com/JuliaLang/julia/pull/33738
if alg.check_pattern && pattern_changed(cacheval, A)
fact = lu(
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)),
check = false)
else
fact = lu!(cacheval,
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)), check = false)
end
else
fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
check = false)
end
cache.cacheval = fact
cache.isfresh = false
end

F = LinearSolve.@get_cacheval(cache, :UMFPACKFactorization)
if F.status == SparseArrays.UMFPACK.UMFPACK_OK
y = ldiv!(cache.u, F, cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
else
SciMLBase.build_linear_solution(
alg, cache.u, nothing, cache; retcode = ReturnCode.Infeasible)
end
end

const PREALLOCATED_KLU = KLU.KLUFactorization(SparseMatrixCSC(0, 0, [1], Int[],
Float64[]))

function LinearSolve.init_cacheval(
alg::KLUFactorization, A::SparseMatrixCSC{Float64, Int}, b, u, Pl,
Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
PREALLOCATED_KLU
end

function LinearSolve.init_cacheval(
alg::KLUFactorization, A::AbstractSparseArray, b, u, Pl, Pr,
maxiters::Int, abstol,
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
A = convert(AbstractMatrix, A)
return KLU.KLUFactorization(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)))
end

# TODO: guard this against errors
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::KLUFactorization; kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
cacheval = LinearSolve.@get_cacheval(cache, :KLUFactorization)
if alg.reuse_symbolic
if alg.check_pattern && pattern_changed(cacheval, A)
fact = KLU.klu(
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)),
check = false)
else
fact = KLU.klu!(cacheval, nonzeros(A), check = false)
end
else
# New fact each time since the sparsity pattern can change
# and thus it needs to reallocate
fact = KLU.klu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)))
end
cache.cacheval = fact
cache.isfresh = false
end
F = LinearSolve.@get_cacheval(cache, :KLUFactorization)
if F.common.status == KLU.KLU_OK
y = ldiv!(cache.u, F, cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
else
SciMLBase.build_linear_solution(
alg, cache.u, nothing, cache; retcode = ReturnCode.Infeasible)
end
end

const PREALLOCATED_CHOLMOD = cholesky(SparseMatrixCSC(0, 0, [1], Int[], Float64[]))

function LinearSolve.init_cacheval(alg::CHOLMODFactorization,
A::Union{SparseMatrixCSC{T, Int}, Symmetric{T, SparseMatrixCSC{T, Int}}}, b, u,
Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions) where {T <:
Union{Float32, Float64}}
PREALLOCATED_CHOLMOD
end

function LinearSolve.init_cacheval(alg::NormalCholeskyFactorization,
A::Union{AbstractSparseArray, LinearSolve.GPUArraysCore.AnyGPUArray,
Symmetric{<:Number, <:AbstractSparseArray}}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
LinearSolve.ArrayInterface.cholesky_instance(convert(AbstractMatrix, A))
end

# Specialize QR for the non-square case
# Missing ldiv! definitions: https://github.com/JuliaSparse/SparseArrays.jl/issues/242
function LinearSolve._ldiv!(x::Vector,
A::Union{SparseArrays.QR, LinearAlgebra.QRCompactWY,
SparseArrays.SPQR.QRSparse,
SparseArrays.CHOLMOD.Factor}, b::Vector)
x .= A \ b
end

function LinearSolve._ldiv!(x::AbstractVector,
A::Union{SparseArrays.QR, LinearAlgebra.QRCompactWY,
SparseArrays.SPQR.QRSparse,
SparseArrays.CHOLMOD.Factor}, b::AbstractVector)
x .= A \ b
end

# Ambiguity removal
function LinearSolve._ldiv!(::LinearSolve.SVector,
A::Union{SparseArrays.CHOLMOD.Factor, LinearAlgebra.QR,
LinearAlgebra.QRCompactWY, SparseArrays.SPQR.QRSparse},
b::AbstractVector)
(A \ b)
end
function LinearSolve._ldiv!(::LinearSolve.SVector,
A::Union{SparseArrays.CHOLMOD.Factor, LinearAlgebra.QR,
LinearAlgebra.QRCompactWY, SparseArrays.SPQR.QRSparse},
b::LinearSolve.SVector)
(A \ b)
end

function pattern_changed(fact, A::SparseArrays.SparseMatrixCSC)
!(SparseArrays.decrement(SparseArrays.getcolptr(A)) ==
fact.colptr && SparseArrays.decrement(SparseArrays.getrowval(A)) ==
fact.rowval)
end

function LinearSolve.defaultalg(
A::AbstractSparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti}, b,
assump::OperatorAssumptions{Bool}) where {Ti}
if assump.issq
if length(b) <= 10_000 && length(nonzeros(A)) / length(A) < 2e-4
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KLUFactorization)
else
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.UMFPACKFactorization)
end
else
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.QRFactorization)
end
end

LinearSolve.PrecompileTools.@compile_workload begin
A = sprand(4, 4, 0.3) + I
b = rand(4)
prob = LinearProblem(A, b)
sol = solve(prob, KLUFactorization())
sol = solve(prob, UMFPACKFactorization())
end

end
Loading
Loading