Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(HomotopyContinuation): enable more performant jacobians with Enzyme #528

Merged
merged 2 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion lib/NonlinearSolveHomotopyContinuation/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ CommonSolve = "0.2.4"
ConcreteStructs = "0.2.3"
DifferentiationInterface = "0.6.27"
DocStringExtensions = "0.9.3"
Enzyme = "0.13"
HomotopyContinuation = "2.12.0"
LinearAlgebra = "1.10"
NonlinearSolve = "4"
Expand All @@ -35,8 +36,9 @@ julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "NonlinearSolve"]
test = ["Aqua", "Test", "NonlinearSolve", "Enzyme"]
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,12 @@ end

HomotopyContinuationJL(; kwargs...) = HomotopyContinuationJL{false}(; kwargs...)

function HomotopyContinuationJL(alg::HomotopyContinuationJL{R}; kwargs...) where {R}
HomotopyContinuationJL{R}(; autodiff = alg.autodiff, alg.kwargs..., kwargs...)
end

include("interface_types.jl")
include("jacobian_handling.jl")
include("solve.jl")

end
112 changes: 9 additions & 103 deletions lib/NonlinearSolveHomotopyContinuation/src/interface_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,47 +4,6 @@
struct OutOfPlace <: HomotopySystemVariant end
struct Scalar <: HomotopySystemVariant end

"""
$(TYPEDEF)

A simple struct that wraps a polynomial function which takes complex input and returns
complex output in a form that supports automatic differentiation. If the wrapped
function if ``f: \\mathbb{C}^n \\rightarrow \\mathbb{C}^n`` then it is assumed
the input arrays are real-valued and have length ``2n``. They are `reinterpret`ed
into complex arrays and passed into the function. This struct has an in-place signature
regardless of the signature of ``f``.
"""
@concrete struct ComplexJacobianWrapper{variant <: HomotopySystemVariant}
f
end

function (cjw::ComplexJacobianWrapper{Inplace})(
u::AbstractVector{T}, x::AbstractVector{T}, p) where {T}
x = reinterpret(Complex{T}, x)
u = reinterpret(Complex{T}, u)
cjw.f(u, x, p)
u = parent(u)
return u
end

function (cjw::ComplexJacobianWrapper{OutOfPlace})(
u::AbstractVector{T}, x::AbstractVector{T}, p) where {T}
x = reinterpret(Complex{T}, x)
u_tmp = cjw.f(x, p)
u_tmp = reinterpret(T, u_tmp)
copyto!(u, u_tmp)
return u
end

function (cjw::ComplexJacobianWrapper{Scalar})(
u::AbstractVector{T}, x::AbstractVector{T}, p) where {T}
x = reinterpret(Complex{T}, x)
u_tmp = cjw.f(x[1], p)
u[1] = real(u_tmp)
u[2] = imag(u_tmp)
return u
end

"""
$(TYPEDEF)

Expand All @@ -62,34 +21,24 @@
"""
f
"""
The jacobian function, if provided to the `NonlinearProblem` being solved. Otherwise,
a `ComplexJacobianWrapper` wrapping `f` used for automatic differentiation.
A function which calculates both the polynomial and the jacobian. Must be a function
of the form `f(u, U, x, p)` where `x` is the current unknowns and `p` is the parameter
object, writing the value of the polynomial to `u` and the jacobian to `U`. Must be able
to handle complex `x`.
"""
jac
"""
The parameter object.
"""
p
"""
The ADType for automatic differentiation.
"""
autodiff
"""
The result from `DifferentiationInterface.prepare_jacobian`.
"""
prep
"""
HomotopyContinuation.jl's symbolic variables for the system.
"""
vars
"""
The `TaylorDiff.TaylorScalar` objects used to compute the taylor series of `f`.
"""
taylorvars
"""
Preallocated intermediate buffers used for calculating the jacobian.
"""
jacobian_buffers
end

Base.size(sys::HomotopySystemWrapper) = (length(sys.vars), length(sys.vars))
Expand All @@ -112,54 +61,11 @@
end

function HC.ModelKit.evaluate_and_jacobian!(
u, U, sys::HomotopySystemWrapper{Inplace}, x, p = nothing)
p = sys.p
sys.f(u, x, p)
sys.jac(U, x, p)
return u, U
end

function HC.ModelKit.evaluate_and_jacobian!(
u, U, sys::HomotopySystemWrapper{OutOfPlace}, x, p = nothing)
p = sys.p
u_tmp = sys.f(x, p)
copyto!(u, u_tmp)
j_tmp = sys.jac(x, p)
copyto!(U, j_tmp)
u, U, sys::HomotopySystemWrapper, x, p = nothing)
sys.jac(u, U, x, sys.p)
return u, U
end

function HC.ModelKit.evaluate_and_jacobian!(
u, U, sys::HomotopySystemWrapper{Scalar}, x, p = nothing)
p = sys.p
u[1] = sys.f(x[1], p)
U[1] = sys.jac(x[1], p)
return u, U
end

for V in (Inplace, OutOfPlace, Scalar)
@eval function HC.ModelKit.evaluate_and_jacobian!(
u, U, sys::HomotopySystemWrapper{$V, F, J}, x,
p = nothing) where {F, J <: ComplexJacobianWrapper}
p = sys.p
U_tmp = sys.jacobian_buffers
x = reinterpret(Float64, x)
u = reinterpret(Float64, u)
DI.value_and_jacobian!(sys.jac, u, U_tmp, sys.prep, sys.autodiff, x, DI.Constant(p))
U = reinterpret(Float64, U)
@inbounds for j in axes(U, 2)
jj = 2j - 1
for i in axes(U, 1)
U[i, j] = U_tmp[i, jj]
end
end
u = parent(u)
U = parent(U)

return u, U
end
end

function update_taylorvars_from_taylorvector!(
vars, x::HC.ModelKit.TaylorVector)
for i in eachindex(x)
Expand All @@ -185,14 +91,14 @@

function check_taylor_equality(vars, x::HC.ModelKit.TaylorVector)
for i in eachindex(x)
TaylorDiff.flatten(vars[2i-1]) == map(real, x[i]) || return false
TaylorDiff.flatten(vars[2i - 1]) == map(real, x[i]) || return false
TaylorDiff.flatten(vars[2i]) == map(imag, x[i]) || return false
end
return true
end
function check_taylor_equality(vars, x::AbstractVector)
for i in eachindex(x)
TaylorDiff.value(vars[2i-1]) != real(x[i]) && return false
TaylorDiff.value(vars[2i - 1]) != real(x[i]) && return false
TaylorDiff.value(vars[2i]) != imag(x[i]) && return false
end
return true
Expand All @@ -212,7 +118,7 @@
for i in eachindex(vars)
rval = TaylorDiff.flatten(real(buffer[i]))
ival = TaylorDiff.flatten(imag(buffer[i]))
u[i] = ntuple(i -> rval[i] + im * ival[i], Val(length(rval)))
u[i] = ntuple(i -> rval[i] + im * ival[i], Val(length(rval)))
end
end

Expand Down Expand Up @@ -270,7 +176,7 @@
"""
$(TYPEDEF)

A `HomotopyContinuation.AbstractHomotopy` which uses an inital guess ``x_0`` to construct

Check warning on line 179 in lib/NonlinearSolveHomotopyContinuation/src/interface_types.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"inital" should be "initial".
the start system for the homotopy. The homotopy is

```math
Expand Down
Loading
Loading