Skip to content

Commit

Permalink
fix: hessian through nonlinear solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 1, 2024
1 parent 748fb09 commit c2438c7
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function NonlinearSolveBase.additional_incompatible_backend_check(
end

Utils.value(::Type{Dual{T, V, N}}) where {T, V, N} = V
Utils.value(x::Dual) = Utils.value(ForwardDiff.value(x))
Utils.value(x::Dual) = ForwardDiff.value(x)
Utils.value(x::AbstractArray{<:Dual}) = Utils.value.(x)

function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
Expand Down
54 changes: 54 additions & 0 deletions test/forward_ad_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,57 @@ end
end
end
end

@testitem "NLLS Hessian SciML/NonlinearSolve.jl#445" tags=[:core] begin
using ForwardDiff, FiniteDiff

function objfn(F, init, params)
th1, th2 = init
px, py, l1, l2 = params
F[1] = l1 * cos(th1) + l2 * cos(th1 + th2) - px
F[2] = l1 * sin(th1) + l2 * sin(th1 + th2) - py
return F
end

function solve_nlprob(pxpy)
px, py = pxpy
theta1 = pi / 4
theta2 = pi / 4
initial_guess = [theta1; theta2]
l1 = 60
l2 = 60
p = [px; py; l1; l2]
prob = NonlinearLeastSquaresProblem(
NonlinearFunction(objfn, resid_prototype = zeros(2)),
initial_guess, p
)
resu = solve(
prob,
reltol = 1e-12, abstol = 1e-12
)
th1, th2 = resu.u
cable1_base = [-90; 0; 0]
cable2_base = [-150; 0; 0]
cable3_base = [150; 0; 0]
cable1_top = [l1 * cos(th1) / 2; l1 * sin(th1) / 2; 0]
cable23_top = [l1 * cos(th1) + l2 * cos(th1 + th2) / 2;
l1 * sin(th1) + l2 * sin(th1 + th2) / 2; 0]
c1_length = sqrt((cable1_top[1] - cable1_base[1])^2 +
(cable1_top[2] - cable1_base[2])^2)
c2_length = sqrt((cable23_top[1] - cable2_base[1])^2 +
(cable23_top[2] - cable2_base[2])^2)
c3_length = sqrt((cable23_top[1] - cable3_base[1])^2 +
(cable23_top[2] - cable3_base[2])^2)
return c1_length + c2_length + c3_length
end

grad1 = ForwardDiff.gradient(solve_nlprob, [34.0, 87.0])
grad2 = FiniteDiff.finite_difference_gradient(solve_nlprob, [34.0, 87.0])

@test grad1 grad2 atol = 1e-3

hess1 = ForwardDiff.hessian(solve_nlprob, [34.0, 87.0])
hess2 = FiniteDiff.finite_difference_hessian(solve_nlprob, [34.0, 87.0])

@test hess1 hess2 atol = 1e-3
end

0 comments on commit c2438c7

Please sign in to comment.