
Commit b79228d

Add tests
1 parent 8f70a4e commit b79228d

8 files changed: +126 additions, -46 deletions

Project.toml

Lines changed: 2 additions & 2 deletions
@@ -50,7 +50,7 @@ FiniteDiff = "2"
 ForwardDiff = "0.10.3"
 LeastSquaresOptim = "0.8"
 LineSearches = "7"
-LinearAlgebra = "1.9"
+LinearAlgebra = "<0.0.1, 1"
 LinearSolve = "2.12"
 NonlinearProblemLibrary = "0.1"
 PrecompileTools = "1"
@@ -59,7 +59,7 @@ Reexport = "0.2, 1"
 SciMLBase = "2.8.2"
 SciMLOperators = "0.3"
 SimpleNonlinearSolve = "0.1.23"
-SparseArrays = "1.9"
+SparseArrays = "<0.0.1, 1"
 SparseDiffTools = "2.12"
 StaticArraysCore = "1.4"
 UnPack = "1.0"

src/gaussnewton.jl

Lines changed: 10 additions & 3 deletions
@@ -29,23 +29,30 @@ for large-scale and numerically-difficult nonlinear least squares problems.
 - `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
   which means that no line search is performed. Algorithms from `LineSearches.jl` can be
   used here directly, and they will be converted to the correct `LineSearch`.
+- `vjp_autodiff`: the automatic differentiation backend used for vector-Jacobian products.
+  This is applicable if the linear solver doesn't require a concrete Jacobian, e.g.,
+  Krylov methods. Defaults to `nothing`, which means that if the problem is out of place
+  and `Zygote` is loaded, `AutoZygote` is used; in all other cases, `FiniteDiff` is used.
 """
 @concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
     ad::AD
     linsolve
     precs
     linesearch
+    vjp_autodiff
 end
 
 function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
-    return GaussNewton{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch)
+    return GaussNewton{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch, alg.vjp_autodiff)
 end
 
 function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
-        linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
+        linesearch = LineSearch(), precs = DEFAULT_PRECS, vjp_autodiff = nothing,
+        adkwargs...)
     ad = default_adargs_to_adtype(; adkwargs...)
     linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
-    return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch)
+    return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch,
+        vjp_autodiff)
 end
 
 @concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
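
A hedged usage sketch of the new keyword: `GaussNewton` paired with a Krylov linear solver, so no concrete Jacobian is materialized and `vjp_autodiff` drives the JᵀJ operator. The residual, initial guess, and backend choice below are illustrative only, not taken from this commit.

using NonlinearSolve, LinearSolve, ADTypes

# Illustrative out-of-place residual for a small least-squares problem.
residual(θ, p) = [θ[1]^2 - p, θ[1] * θ[2] - 1.0]
prob = NonlinearLeastSquaresProblem(NonlinearFunction{false}(residual), [1.0, 1.0], 2.0)

# A Krylov linear solver means the Jacobian is only used via operator products;
# vector-Jacobian products use the requested backend (finite differences here).
sol = solve(prob,
    GaussNewton(; linsolve = KrylovJL_GMRES(), vjp_autodiff = AutoFiniteDiff());
    abstol = 1e-8)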

src/jacobian.jl

Lines changed: 27 additions & 17 deletions
@@ -2,6 +2,9 @@
     JᵀJ
     Jᵀ
 end
+
+SciMLBase.isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ)
+
 sparsity_detection_alg(_, _) = NoSparsityDetection()
 function sparsity_detection_alg(f, ad::AbstractSparseADType)
     if f.sparsity === nothing
@@ -67,12 +70,10 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
         jac_cache = nothing
     end
 
-    # FIXME: To properly support needsJᵀJ without Jacobian, we need to implement
-    # a reverse diff operation with the seed being `Jx`, this is not yet implemented
-    J = if !(linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
+    J = if !(linsolve_needs_jac || alg_wants_jac)
         if f.jvp === nothing
             # We don't need to construct the Jacobian
-            JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad))
+            JacVec(uf, u; fu, autodiff = __get_nonsparse_ad(alg.ad))
         else
             if iip
                 jvp = (_, u, v) -> (du = similar(fu); f.jvp(du, v, u, p); du)
@@ -96,9 +97,9 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
     du = _mutable_zero(u)
 
     if needsJᵀJ
-        # TODO: Pass in `jac_transpose_autodiff`
-        JᵀJ, Jᵀfu = __init_JᵀJ(J, _vec(fu), uf, u;
-            jac_autodiff = __get_nonsparse_ad(alg.ad))
+        JᵀJ, Jᵀfu = __init_JᵀJ(J, _vec(fu), uf, u; f,
+            vjp_autodiff = __get_nonsparse_ad(_getproperty(alg, Val(:vjp_autodiff))),
+            jvp_autodiff = __get_nonsparse_ad(alg.ad))
     end
 
     if linsolve_init
@@ -141,26 +142,29 @@ function __init_JᵀJ(J::StaticArray, fu, args...; kwargs...)
     JᵀJ = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)
     return JᵀJ, J' * fu
 end
-function __init_JᵀJ(J::FunctionOperator, fu, uf, u, args...;
-        jac_transpose_autodiff = nothing, jac_autodiff = nothing, kwargs...)
-    autodiff = __concrete_jac_transpose_autodiff(jac_transpose_autodiff, jac_autodiff, uf)
-    Jᵀ = VecJac(uf, u; autodiff)
+function __init_JᵀJ(J::FunctionOperator, fu, uf, u, args...; f = nothing,
+        vjp_autodiff = nothing, jvp_autodiff = nothing, kwargs...)
+    # FIXME: Proper fix to this requires the FunctionOperator patch
+    if f !== nothing && f.vjp !== nothing
+        @warn "Currently we don't make use of user provided `jvp`. This is planned to be \
+            fixed in the near future."
+    end
+    autodiff = __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
+    Jᵀ = VecJac(uf, u; fu, autodiff)
     JᵀJ_op = SciMLOperators.cache_operator(Jᵀ * J, u)
     JᵀJ = KrylovJᵀJ(JᵀJ_op, Jᵀ)
     Jᵀfu = Jᵀ * fu
     return JᵀJ, Jᵀfu
 end
 
-SciMLBase.isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ)
-
-function __concrete_jac_transpose_autodiff(jac_transpose_autodiff, jac_autodiff, uf)
-    if jac_transpose_autodiff === nothing
+function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
+    if vjp_autodiff === nothing
         if isinplace(uf)
             # VecJac can be only FiniteDiff
             return AutoFiniteDiff()
         else
             # Short circuit if we see that FiniteDiff was used for J computation
-            jac_autodiff isa AutoFiniteDiff && return jac_autodiff
+            jvp_autodiff isa AutoFiniteDiff && return jvp_autodiff
             # Check if Zygote is loaded then use Zygote else use FiniteDiff
             if haskey(Base.loaded_modules,
                 Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote"))
@@ -170,7 +174,13 @@ function __concrete_jac_transpose_autodiff(jac_transpose_autodiff, jac_autodiff,
             end
         end
     else
-        return __get_nonsparse_ad(jac_transpose_autodiff)
+        ad = __get_nonsparse_ad(vjp_autodiff)
+        if isinplace(uf) && ad isa AutoZygote
+            @warn "Attempting to use Zygote.jl for linesearch on an in-place problem. \
+                Falling back to finite differencing."
+            return AutoFiniteDiff()
+        end
+        return ad
     end
 end
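
As a standalone illustration of the default-selection heuristic in `__concrete_vjp_autodiff` above, the sketch below shows the "is Zygote already loaded?" check the out-of-place path relies on. The variable names are illustrative; the UUID is Zygote's, as in the source.

using ADTypes: AutoZygote, AutoFiniteDiff

# Check whether Zygote is already loaded without taking a hard dependency on it.
zygote_loaded = haskey(Base.loaded_modules,
    Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote"))

# Prefer Zygote for out-of-place problems when available; otherwise fall back to
# finite differencing.
vjp_backend = zygote_loaded ? AutoZygote() : AutoFiniteDiff()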

src/linesearch.jl

Lines changed: 4 additions & 0 deletions
@@ -131,6 +131,10 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) whe
     end
 
     function g!(u, fu)
+        if f.jvp !== nothing
+            @warn "Currently we don't make use of user provided `jvp` in linesearch. This \
+                is planned to be fixed in the near future." maxlog=1
+        end
         op = VecJac(SciMLBase.JacobianWrapper(f, p), u; fu = fu1, autodiff)
         if iip
             mul!(g₀, op, fu)
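
The `g!` closure above assembles the gradient of the line-search merit function g(u) = ½‖f(u)‖², which equals J(u)ᵀ f(u); applying the `VecJac` operator to `fu` computes exactly that product. A self-contained sanity check of the identity with an illustrative residual and ForwardDiff (not part of this commit):

using ForwardDiff, LinearAlgebra

f_oop(u) = [u[1]^2 - 2.0, u[2]^2 - 2.0]   # illustrative residual
u = [1.2, 0.8]

J = ForwardDiff.jacobian(f_oop, u)
grad_from_vjp = J' * f_oop(u)                                   # Jᵀf, what g! builds
grad_direct = ForwardDiff.gradient(v -> sum(abs2, f_oop(v)) / 2, u)

@assert grad_from_vjp ≈ grad_direct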

src/trustRegion.jl

Lines changed: 11 additions & 6 deletions
@@ -141,9 +141,12 @@ for large-scale and numerically-difficult nonlinear systems.
   `expand_threshold < r` (with `r` defined in `shrink_threshold`). Defaults to `2.0`.
 - `max_shrink_times`: the maximum number of times to shrink the trust region radius in a
   row, `max_shrink_times` is exceeded, the algorithm returns. Defaults to `32`.
+- `vjp_autodiff`: the automatic differentiation backend used for vector-Jacobian products.
+  This is applicable if the linear solver doesn't require a concrete Jacobian, e.g.,
+  Krylov methods. Defaults to `nothing`, which means that if the problem is out of place
+  and `Zygote` is loaded, `AutoZygote` is used; in all other cases, `FiniteDiff` is used.
 """
-@concrete struct TrustRegion{CJ, AD, MTR} <:
-                 AbstractNewtonAlgorithm{CJ, AD}
+@concrete struct TrustRegion{CJ, AD, MTR} <: AbstractNewtonAlgorithm{CJ, AD}
     ad::AD
     linsolve
     precs
@@ -156,25 +159,27 @@ for large-scale and numerically-difficult nonlinear systems.
     shrink_factor::MTR
     expand_factor::MTR
     max_shrink_times::Int
+    vjp_autodiff
 end
 
 function set_ad(alg::TrustRegion{CJ}, ad) where {CJ}
     return TrustRegion{CJ}(ad, alg.linsolve, alg.precs, alg.radius_update_scheme,
         alg.max_trust_radius, alg.initial_trust_radius, alg.step_threshold,
         alg.shrink_threshold, alg.expand_threshold, alg.shrink_factor, alg.expand_factor,
-        alg.max_shrink_times)
+        alg.max_shrink_times, alg.vjp_autodiff)
 end
 
 function TrustRegion(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
     radius_update_scheme::RadiusUpdateSchemes.T = RadiusUpdateSchemes.Simple, # defaults to conventional radius update
     max_trust_radius::Real = 0 // 1, initial_trust_radius::Real = 0 // 1,
     step_threshold::Real = 1 // 10000, shrink_threshold::Real = 1 // 4,
     expand_threshold::Real = 3 // 4, shrink_factor::Real = 1 // 4,
-    expand_factor::Real = 2 // 1, max_shrink_times::Int = 32, adkwargs...)
+    expand_factor::Real = 2 // 1, max_shrink_times::Int = 32, vjp_autodiff = nothing,
+    adkwargs...)
     ad = default_adargs_to_adtype(; adkwargs...)
     return TrustRegion{_unwrap_val(concrete_jac)}(ad, linsolve, precs, radius_update_scheme,
         max_trust_radius, initial_trust_radius, step_threshold, shrink_threshold,
-        expand_threshold, shrink_factor, expand_factor, max_shrink_times)
+        expand_threshold, shrink_factor, expand_factor, max_shrink_times, vjp_autodiff)
 end
 
 @concrete mutable struct TrustRegionCache{iip, trustType, floatType} <:
@@ -422,7 +427,7 @@ function retrospective_step!(cache::TrustRegionCache)
     @unpack H, g, du = cache
 
     return -(get_loss(fu_prev) - get_loss(fu)) /
-           (dot(du, g) + dot(du, H, du) / 2)
+           (dot(_vec(du), _vec(g)) + __lr_mul(Val(isinplace(cache)), H, _vec(du)) / 2)
 end
 
 function trust_region_step!(cache::TrustRegionCache)
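
A hedged usage sketch mirroring the new test combinations: `TrustRegion` with a Krylov linear solver and an explicit `vjp_autodiff` backend on a small out-of-place problem. The residual, initial guess, and tolerance are illustrative, not taken from the commit.

using NonlinearSolve, LinearSolve, ADTypes

quadratic(u, p) = u .* u .- p                  # roots at ±sqrt(p)
prob = NonlinearProblem{false}(quadratic, [1.0, 1.0], 2.0)

# With KrylovJL_GMRES no dense Jacobian (or JᵀJ) is formed; the normal-equations
# operator is built from vector-Jacobian products using the requested backend.
sol = solve(prob,
    TrustRegion(; linsolve = KrylovJL_GMRES(), vjp_autodiff = AutoFiniteDiff());
    abstol = 1e-6)
# sol.u ≈ [sqrt(2), sqrt(2)]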

src/utils.jl

Lines changed: 6 additions & 0 deletions
@@ -305,6 +305,12 @@ end
 __issingular(x::AbstractMatrix{T}) where {T} = cond(x) > inv(sqrt(eps(real(T))))
 __issingular(x) = false ## If SciMLOperator and such
 
+# Safe getproperty
+@generated function _getproperty(s::S, ::Val{X}) where {S, X}
+    hasfield(S, X) && return :(s.$X)
+    return :(nothing)
+end
+
 # If factorization is LU then perform that and update the linsolve cache
 # else check if the matrix is singular
 function _try_factorize_and_check_singular!(linsolve, X)
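
A self-contained illustration of the safe-getproperty helper added above; `Demo` is a toy struct used only for this sketch.

# Same helper as above, reproduced so the snippet runs on its own.
@generated function _getproperty(s::S, ::Val{X}) where {S, X}
    hasfield(S, X) && return :(s.$X)
    return :(nothing)
end

struct Demo
    vjp_autodiff
end

_getproperty(Demo(:backend), Val(:vjp_autodiff))  # returns :backend
_getproperty(Demo(:backend), Val(:linesearch))    # returns nothing instead of erroring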

test/basictests.jl

Lines changed: 25 additions & 14 deletions
@@ -142,41 +142,52 @@ end
 # --- TrustRegion tests ---
 
 @testset "TrustRegion" begin
-    function benchmark_nlsolve_oop(f, u0, p = 2.0; radius_update_scheme, kwargs...)
+    function benchmark_nlsolve_oop(f, u0, p = 2.0; radius_update_scheme, linsolve = nothing,
+            vjp_autodiff = nothing, kwargs...)
         prob = NonlinearProblem{false}(f, u0, p)
-        return solve(prob, TrustRegion(; radius_update_scheme); abstol = 1e-9, kwargs...)
+        return solve(prob, TrustRegion(; radius_update_scheme, linsolve, vjp_autodiff);
+            abstol = 1e-9, kwargs...)
     end
 
-    function benchmark_nlsolve_iip(f, u0, p = 2.0; radius_update_scheme, kwargs...)
+    function benchmark_nlsolve_iip(f, u0, p = 2.0; radius_update_scheme, linsolve = nothing,
+            vjp_autodiff = nothing, kwargs...)
         prob = NonlinearProblem{true}(f, u0, p)
-        return solve(prob, TrustRegion(; radius_update_scheme); abstol = 1e-9, kwargs...)
+        return solve(prob, TrustRegion(; radius_update_scheme, linsolve, vjp_autodiff);
+            abstol = 1e-9, kwargs...)
     end
 
     radius_update_schemes = [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.NocedalWright,
        RadiusUpdateSchemes.NLsolve, RadiusUpdateSchemes.Hei, RadiusUpdateSchemes.Yuan,
        RadiusUpdateSchemes.Fan, RadiusUpdateSchemes.Bastin]
     u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
+    linear_solvers = [nothing, LUFactorization(), KrylovJL_GMRES()]
 
-    @testset "[OOP] u0: $(typeof(u0)) radius_update_scheme: $(radius_update_scheme)" for u0 in u0s,
-        radius_update_scheme in radius_update_schemes
+    @testset "[OOP] u0: $(typeof(u0)) radius_update_scheme: $(radius_update_scheme) linear_solver: $(linsolve)" for u0 in u0s,
+        radius_update_scheme in radius_update_schemes, linsolve in linear_solvers
+
+        !(u0 isa Array) && linsolve !== nothing && continue
+
+        abstol = ifelse(linsolve isa KrylovJL, 1e-6, 1e-9)
 
-        sol = benchmark_nlsolve_oop(quadratic_f, u0; radius_update_scheme)
+        sol = benchmark_nlsolve_oop(quadratic_f, u0; radius_update_scheme, linsolve, abstol)
         @test SciMLBase.successful_retcode(sol)
-        @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
+        @test all(abs.(sol.u .* sol.u .- 2) .< abstol)
 
         cache = init(NonlinearProblem{false}(quadratic_f, u0, 2.0),
-            TrustRegion(; radius_update_scheme); abstol = 1e-9)
+            TrustRegion(; radius_update_scheme, linsolve); abstol)
         @test (@ballocated solve!($cache)) < 200
     end
 
-    @testset "[IIP] u0: $(typeof(u0)) radius_update_scheme: $(radius_update_scheme)" for u0 in ([
-            1.0, 1.0],), radius_update_scheme in radius_update_schemes
-        sol = benchmark_nlsolve_iip(quadratic_f!, u0; radius_update_scheme)
+    @testset "[IIP] u0: $(typeof(u0)) radius_update_scheme: $(radius_update_scheme) linear_solver: $(linsolve)" for u0 in ([
+            1.0, 1.0],), radius_update_scheme in radius_update_schemes, linsolve in linear_solvers
+        abstol = ifelse(linsolve isa KrylovJL, 1e-6, 1e-9)
+        sol = benchmark_nlsolve_iip(quadratic_f!, u0; radius_update_scheme, linsolve,
+            abstol)
         @test SciMLBase.successful_retcode(sol)
-        @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
+        @test all(abs.(sol.u .* sol.u .- 2) .< abstol)
 
         cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0),
-            TrustRegion(; radius_update_scheme); abstol = 1e-9)
+            TrustRegion(; radius_update_scheme); abstol)
         @test (@ballocated solve!($cache)) ≤ 64
     end

test/nonlinear_least_squares.jl

Lines changed: 41 additions & 4 deletions
@@ -1,4 +1,4 @@
-using NonlinearSolve, LinearSolve, LinearAlgebra, Test, Random, ForwardDiff
+using NonlinearSolve, LinearSolve, LinearAlgebra, Test, Random, ForwardDiff, Zygote
 import FastLevenbergMarquardt, LeastSquaresOptim
 
 true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
@@ -27,9 +27,16 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
         resid_prototype = zero(y_target)), θ_init, x)
 
 nlls_problems = [prob_oop, prob_iip]
-solvers = vec(Any[GaussNewton(; linsolve, linesearch)
-                  for linsolve in [nothing, LUFactorization()],
-                  linesearch in [Static(), BackTracking(), HagerZhang(), StrongWolfe(), MoreThuente()]])
+solvers = []
+for linsolve in [nothing, LUFactorization(), KrylovJL_GMRES()]
+    vjp_autodiffs = linsolve isa KrylovJL ? [nothing, AutoZygote(), AutoFiniteDiff()] :
+                    [nothing]
+    for linesearch in [Static(), BackTracking(), HagerZhang(), StrongWolfe(), MoreThuente()],
+        vjp_autodiff in vjp_autodiffs
+
+        push!(solvers, GaussNewton(; linsolve, linesearch, vjp_autodiff))
+    end
+end
 append!(solvers,
     [
         LevenbergMarquardt(),
@@ -45,6 +52,36 @@ for prob in nlls_problems, solver in solvers
     @test norm(sol.resid) < 1e-6
 end
 
+# This is just for testing that we can use vjp provided by the user
+function vjp(v, θ, p)
+    resid = zeros(length(p))
+    J = ForwardDiff.jacobian((resid, θ) -> loss_function(resid, θ, p), resid, θ)
+    return vec(v' * J)
+end
+
+function vjp!(Jv, v, θ, p)
+    resid = zeros(length(p))
+    J = ForwardDiff.jacobian((resid, θ) -> loss_function(resid, θ, p), resid, θ)
+    mul!(vec(Jv), v', J)
+    return nothing
+end
+
+probs = [
+    NonlinearLeastSquaresProblem(NonlinearFunction{true}(loss_function;
+            resid_prototype = zero(y_target), vjp = vjp!), θ_init, x),
+    NonlinearLeastSquaresProblem(NonlinearFunction{false}(loss_function;
+            resid_prototype = zero(y_target), vjp = vjp), θ_init, x),
+]
+
+for prob in probs, solver in solvers
+    !(solver isa GaussNewton) && continue
+    !(solver.linsolve isa KrylovJL) && continue
+    @test_warn "Currently we don't make use of user provided `jvp`. This is planned to be \
+        fixed in the near future." sol=solve(prob, solver; maxiters = 10000, abstol = 1e-8)
+    sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
+    @test norm(sol.resid) < 1e-6
+end
+
 function jac!(J, θ, p)
     resid = zeros(length(p))
     ForwardDiff.jacobian!(J, (resid, θ) -> loss_function(resid, θ, p), resid, θ)
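
As a quick, hedged consistency check on the two hand-written pullbacks added above (not part of this commit's test file), the out-of-place and in-place versions should agree on a random seed vector:

v = randn(length(y_target))
Jᵀv_oop = vjp(v, θ_init, x)            # out-of-place pullback
Jᵀv_iip = similar(θ_init)
vjp!(Jᵀv_iip, v, θ_init, x)            # in-place pullback
@test Jᵀv_oop ≈ Jᵀv_iip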
