Skip to content

Commit

Permalink
Merge pull request #564 from jClugstor/alias_API
Browse files Browse the repository at this point in the history
Use Aliasing API for alias_A and alias_b
  • Loading branch information
ChrisRackauckas authored Feb 5, 2025
2 parents 378f67f + ebda656 commit d757452
Show file tree
Hide file tree
Showing 12 changed files with 100 additions and 31 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ RecursiveArrayTools = "3.8"
RecursiveFactorization = "0.2.14"
Reexport = "1"
SafeTestsets = "0.1"
SciMLBase = "2.26.3"
SciMLBase = "2.70"
SciMLOperators = "0.3.7"
Setfield = "1"
SparseArrays = "1.10"
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/applelu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ for i in 1:length(ns)
for j in 1:length(algs)
bt = @belapsed solve(prob, $(algs[j])).u setup=(prob = LinearProblem(copy(A),
copy(b);
u0 = copy(u0),
alias_A = true,
alias_b = true))
u0 = copy(u0),
alias = LinearAliasSpecifier(alias_A = true, alias_b = true)
))
push!(res[j], luflop(n) / bt / 1e9)
end
end
Expand Down
3 changes: 1 addition & 2 deletions benchmarks/cudalu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ for i in 1:length(ns)
bt = @belapsed solve(prob, $(algs[j])).u setup=(prob = LinearProblem(copy(A),
copy(b);
u0 = copy(u0),
alias_A = true,
alias_b = true))
alias = LinearAliasSpecifier(alias_A = true, alias_b = true)))
push!(res[j], luflop(n) / bt / 1e9)
end
end
Expand Down
3 changes: 1 addition & 2 deletions benchmarks/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ for i in 1:length(ns)
bt = @belapsed solve(prob, $(algs[j])).u setup=(prob = LinearProblem(copy(A),
copy(b);
u0 = copy(u0),
alias_A = true,
alias_b = true))
alias = LinearAliasSpecifier(alias_A = true, alias_b = true)))
push!(res[j], luflop(n) / bt / 1e9)
end
end
Expand Down
3 changes: 1 addition & 2 deletions benchmarks/metallu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ for i in 1:length(ns)
bt = @belapsed solve(prob, $(algs[j])).u setup=(prob = LinearProblem(copy(A),
copy(b);
u0 = copy(u0),
alias_A = true,
alias_b = true))
alias = LinearAliasSpecifier(alias_A = true, alias_b = true)))
GC.gc()
push!(res[j], luflop(n) / bt / 1e9)
end
Expand Down
3 changes: 1 addition & 2 deletions benchmarks/sparselu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ function run_and_plot(; dims = [1, 2, 3], kmax = 12)
copy($A),
copy($b);
u0 = copy($u0),
alias_A = true,
alias_b = true))
alias = LinearAliasSpecifier(alias_A = true, alias_b = true)))
push!(res[dim][j], bt)
end
end
Expand Down
14 changes: 5 additions & 9 deletions docs/src/basics/common_solver_opts.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,11 @@ in order to give composability. These are also the options taken at `init` time.
The following are the options these algorithms take, along with their defaults.

## General Controls

- `alias_A::Bool`: Whether to alias the matrix `A` or use a copy by default. When `true`,
algorithms like LU-factorization can be faster by reusing the memory via `lu!`,
but care must be taken as the original input will be modified. Default is `true` if the
algorithm is known not to modify `A`, otherwise is `false`.
- `alias_b::Bool`: Whether to alias the matrix `b` or use a copy by default. When `true`,
algorithms can write and change `b` upon usage. Care must be taken as the
original input will be modified. Default is `true` if the algorithm is known not to
modify `b`, otherwise `false`.
- `alias::LinearAliasSpecifier`: Holds the fields `alias_A` and `alias_b` which specify
whether to alias the matrices `A` and `b` respectively. When these fields are `true`,
`A` and `b` can be written to and changed by the solver algorithm. When fields are `nothing`
the default behavior is used, which is to default to `true` when the algorithm is known
not to modify the matrices, and false otherwise.
- `verbose`: Whether to print extra information. Defaults to `false`.
- `assumptions`: Sets the assumptions of the operator in order to effect the default
choice algorithm. See the [Operator Assumptions page for more details](@ref assumptions).
Expand Down
43 changes: 41 additions & 2 deletions ext/LinearSolveHYPREExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using HYPRE: HYPRE, HYPREMatrix, HYPRESolver, HYPREVector
using LinearSolve: HYPREAlgorithm, LinearCache, LinearProblem, LinearSolve,
OperatorAssumptions, default_tol, init_cacheval, __issquare,
__conditioning, LinearSolveAdjoint
using SciMLBase: LinearProblem, SciMLBase
using SciMLBase: LinearProblem, LinearAliasSpecifier, SciMLBase
using UnPack: @unpack
using Setfield: @set!

Expand Down Expand Up @@ -55,7 +55,7 @@ end

function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
args...;
alias_A = false, alias_b = false,
alias = LinearAliasSpecifier(),
# TODO: Implement eltype for HYPREMatrix in HYPRE.jl? Looks useful
# even if it is not AbstractArray.
abstol = default_tol(prob.A isa HYPREMatrix ? HYPRE_Complex :
Expand All @@ -72,6 +72,45 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
kwargs...)
@unpack A, b, u0, p = prob

if haskey(kwargs, :alias_A) || haskey(kwargs, :alias_b)
aliases = LinearAliasSpecifier()

if haskey(kwargs, :alias_A)
message = "`alias_A` keyword argument is deprecated, to set `alias_A`,
please use an ODEAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_A = true))"
Base.depwarn(message, :init)
Base.depwarn(message, :solve)
aliases = LinearAliasSpecifier(alias_A = values(kwargs).alias_A)
end

if haskey(kwargs, :alias_b)
message = "`alias_b` keyword argument is deprecated, to set `alias_b`,
please use an ODEAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_b = true))"
Base.depwarn(message, :init)
Base.depwarn(message, :solve)
aliases = LinearAliasSpecifier(
alias_A = aliases.alias_A, alias_b = values(kwargs).alias_b)
end
else
if alias isa Bool
aliases = LinearAliasSpecifier(alias = alias)
else
aliases = alias
end
end

if isnothing(aliases.alias_A)
alias_A = false
else
alias_A = aliases.alias_A
end

if isnothing(aliases.alias_b)
alias_b = false
else
alias_b = aliases.alias_b
end

A = A isa HYPREMatrix ? A : HYPREMatrix(A)
b = b isa HYPREVector ? b : HYPREVector(b)
u0 = u0 isa HYPREVector ? u0 : (u0 === nothing ? nothing : HYPREVector(u0))
Expand Down
2 changes: 1 addition & 1 deletion src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using LinearAlgebra
using SparseArrays
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
using LazyArrays: @~, BroadcastArray
using SciMLBase: AbstractLinearAlgorithm
using SciMLBase: AbstractLinearAlgorithm, LinearAliasSpecifier
using SciMLOperators
using SciMLOperators: AbstractSciMLOperator, IdentityOperator
using Setfield
Expand Down
42 changes: 40 additions & 2 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,7 @@ __init_u0_from_Ab(::SMatrix{S1, S2}, b) where {S1, S2} = zeros(SVector{S2, eltyp

function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
args...;
alias_A = default_alias_A(alg, prob.A, prob.b),
alias_b = default_alias_b(alg, prob.A, prob.b),
alias = LinearAliasSpecifier(),
abstol = default_tol(real(eltype(prob.b))),
reltol = default_tol(real(eltype(prob.b))),
maxiters::Int = length(prob.b),
Expand All @@ -152,6 +151,45 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
kwargs...)
(; A, b, u0, p) = prob

if haskey(kwargs,:alias_A) || haskey(kwargs,:alias_b)
aliases = LinearAliasSpecifier()

if haskey(kwargs, :alias_A)
message = "`alias_A` keyword argument is deprecated, to set `alias_A`,
please use an ODEAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_A = true))"
Base.depwarn(message, :init)
Base.depwarn(message, :solve)
aliases = LinearAliasSpecifier(alias_A = values(kwargs).alias_A)
end

if haskey(kwargs, :alias_b)
message = "`alias_b` keyword argument is deprecated, to set `alias_b`,
please use an ODEAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_b = true))"
Base.depwarn(message, :init)
Base.depwarn(message, :solve)
aliases = LinearAliasSpecifier(alias_A = aliases.alias_A, alias_b = values(kwargs).alias_b)
end
else
if alias isa Bool
aliases = LinearAliasSpecifier(alias = alias)
else
aliases = alias
end
end

if isnothing(aliases.alias_A)
alias_A = default_alias_A(alg, prob.A, prob.b)
else
alias_A = aliases.alias_A
end

if isnothing(aliases.alias_b)
alias_b = default_alias_b(alg, prob.A, prob.b)
else
alias_b = aliases.alias_b
end


A = if alias_A || A isa SMatrix
A
elseif A isa Array
Expand Down
4 changes: 2 additions & 2 deletions test/gpu/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ prob2 = LinearProblem(transpose(A), b)

@testset "Adjoint/Transpose Type: $(alg)" for alg in (NormalCholeskyFactorization(),
CholeskyFactorization(), LUFactorization(), QRFactorization(), nothing)
sol = solve(prob1, alg; alias_A = false)
sol = solve(prob1, alg; alias = LinearAliasSpecifier(alias = LinearAliasSpecifier(alias_A = false)))
@test norm(A' * sol.u .- b) < 1e-5

sol = solve(prob2, alg; alias_A = false)
sol = solve(prob2, alg; alias = LinearAliasSpecifier(alias_A = false))
@test norm(transpose(A) * sol.u .- b) < 1e-5
end
6 changes: 3 additions & 3 deletions test/resolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),
alg in [LDLtFactorization] && (A = SymTridiagonal(A))
b = [1.0, 2.0]
prob = LinearProblem(A, b)
linsolve = init(prob, alg(), alias_A = false, alias_b = false)
linsolve = init(prob, alg(), alias = LinearAliasSpecifier(alias_A = false, alias_b = false))
@test solve!(linsolve).u [-2.0, 1.5]
@test !linsolve.isfresh
@test solve!(linsolve).u [-2.0, 1.5]
Expand All @@ -48,7 +48,7 @@ end
A = Diagonal([1.0, 4.0])
b = [1.0, 2.0]
prob = LinearProblem(A, b)
linsolve = init(prob, DiagonalFactorization(), alias_A = false, alias_b = false)
linsolve = init(prob, DiagonalFactorization(), alias = LinearAliasSpecifier(alias_A = false, alias_b = false))
@test solve!(linsolve).u [1.0, 0.5]
@test solve!(linsolve).u [1.0, 0.5]
A = Diagonal([1.0, 4.0])
Expand All @@ -59,7 +59,7 @@ A = Symmetric([1.0 2.0
2.0 1.0])
b = [1.0, 2.0]
prob = LinearProblem(A, b)
linsolve = init(prob, BunchKaufmanFactorization(), alias_A = false, alias_b = false)
linsolve = init(prob, BunchKaufmanFactorization(), alias = LinearAliasSpecifier(alias_A = false, alias_b = false))
@test solve!(linsolve).u [1.0, 0.0]
@test solve!(linsolve).u [1.0, 0.0]
A = Symmetric([1.0 2.0
Expand Down

0 comments on commit d757452

Please sign in to comment.