2
2
JᵀJ
3
3
Jᵀ
4
4
end
5
+
6
+ SciMLBase. isinplace (JᵀJ:: KrylovJ ᵀJ) = isinplace (JᵀJ. Jᵀ)
7
+
5
8
sparsity_detection_alg (_, _) = NoSparsityDetection ()
6
9
function sparsity_detection_alg (f, ad:: AbstractSparseADType )
7
10
if f. sparsity === nothing
@@ -67,12 +70,10 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
67
70
jac_cache = nothing
68
71
end
69
72
70
- # FIXME : To properly support needsJᵀJ without Jacobian, we need to implement
71
- # a reverse diff operation with the seed being `Jx`, this is not yet implemented
72
- J = if ! (linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
73
+ J = if ! (linsolve_needs_jac || alg_wants_jac)
73
74
if f. jvp === nothing
74
75
# We don't need to construct the Jacobian
75
- JacVec (uf, u; autodiff = __get_nonsparse_ad (alg. ad))
76
+ JacVec (uf, u; fu, autodiff = __get_nonsparse_ad (alg. ad))
76
77
else
77
78
if iip
78
79
jvp = (_, u, v) -> (du = similar (fu); f. jvp (du, v, u, p); du)
@@ -96,9 +97,9 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
96
97
du = _mutable_zero (u)
97
98
98
99
if needsJᵀJ
99
- # TODO : Pass in `jac_transpose_autodiff`
100
- JᵀJ, Jᵀfu = __init_JᵀJ (J, _vec (fu), uf, u;
101
- jac_autodiff = __get_nonsparse_ad (alg. ad))
100
+ JᵀJ, Jᵀfu = __init_JᵀJ (J, _vec (fu), uf, u; f,
101
+ vjp_autodiff = __get_nonsparse_ad ( _getproperty (alg, Val ( :vjp_autodiff ))),
102
+ jvp_autodiff = __get_nonsparse_ad (alg. ad))
102
103
end
103
104
104
105
if linsolve_init
@@ -141,26 +142,29 @@ function __init_JᵀJ(J::StaticArray, fu, args...; kwargs...)
141
142
JᵀJ = MArray {Tuple{size(J, 2), size(J, 2)}, eltype(J)} (undef)
142
143
return JᵀJ, J' * fu
143
144
end
144
- function __init_JᵀJ (J:: FunctionOperator , fu, uf, u, args... ;
145
- jac_transpose_autodiff = nothing , jac_autodiff = nothing , kwargs... )
146
- autodiff = __concrete_jac_transpose_autodiff (jac_transpose_autodiff, jac_autodiff, uf)
147
- Jᵀ = VecJac (uf, u; autodiff)
145
+ function __init_JᵀJ (J:: FunctionOperator , fu, uf, u, args... ; f = nothing ,
146
+ vjp_autodiff = nothing , jvp_autodiff = nothing , kwargs... )
147
+ # FIXME : Proper fix to this requires the FunctionOperator patch
148
+ if f != = nothing && f. vjp != = nothing
149
+ @warn " Currently we don't make use of user provided `jvp`. This is planned to be \
150
+ fixed in the near future."
151
+ end
152
+ autodiff = __concrete_vjp_autodiff (vjp_autodiff, jvp_autodiff, uf)
153
+ Jᵀ = VecJac (uf, u; fu, autodiff)
148
154
JᵀJ_op = SciMLOperators. cache_operator (Jᵀ * J, u)
149
155
JᵀJ = KrylovJᵀJ (JᵀJ_op, Jᵀ)
150
156
Jᵀfu = Jᵀ * fu
151
157
return JᵀJ, Jᵀfu
152
158
end
153
159
154
- SciMLBase. isinplace (JᵀJ:: KrylovJ ᵀJ) = isinplace (JᵀJ. Jᵀ)
155
-
156
- function __concrete_jac_transpose_autodiff (jac_transpose_autodiff, jac_autodiff, uf)
157
- if jac_transpose_autodiff === nothing
160
+ function __concrete_vjp_autodiff (vjp_autodiff, jvp_autodiff, uf)
161
+ if vjp_autodiff === nothing
158
162
if isinplace (uf)
159
163
# VecJac can be only FiniteDiff
160
164
return AutoFiniteDiff ()
161
165
else
162
166
# Short circuit if we see that FiniteDiff was used for J computation
163
- jac_autodiff isa AutoFiniteDiff && return jac_autodiff
167
+ jvp_autodiff isa AutoFiniteDiff && return jvp_autodiff
164
168
# Check if Zygote is loaded then use Zygote else use FiniteDiff
165
169
if haskey (Base. loaded_modules,
166
170
Base. PkgId (Base. UUID (" e88e6eb3-aa80-5325-afca-941959d7151f" ), " Zygote" ))
@@ -170,7 +174,13 @@ function __concrete_jac_transpose_autodiff(jac_transpose_autodiff, jac_autodiff,
170
174
end
171
175
end
172
176
else
173
- return __get_nonsparse_ad (jac_transpose_autodiff)
177
+ ad = __get_nonsparse_ad (vjp_autodiff)
178
+ if isinplace (uf) && ad isa AutoZygote
179
+ @warn " Attempting to use Zygote.jl for linesearch on an in-place problem. \
180
+ Falling back to finite differencing."
181
+ return AutoFiniteDiff ()
182
+ end
183
+ return ad
174
184
end
175
185
end
176
186
0 commit comments