Skip to content

Commit

Permalink
feat(HomotopyContinuation): enable more performant jacobians with Enzyme
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 31, 2025
1 parent 7bb635d commit a82989d
Show file tree
Hide file tree
Showing 7 changed files with 321 additions and 145 deletions.
4 changes: 3 additions & 1 deletion lib/NonlinearSolveHomotopyContinuation/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ CommonSolve = "0.2.4"
ConcreteStructs = "0.2.3"
DifferentiationInterface = "0.6.27"
DocStringExtensions = "0.9.3"
Enzyme = "0.13"
HomotopyContinuation = "2.12.0"
LinearAlgebra = "1.10"
NonlinearSolve = "4"
Expand All @@ -35,8 +36,9 @@ julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "NonlinearSolve"]
test = ["Aqua", "Test", "NonlinearSolve", "Enzyme"]
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,12 @@ end

HomotopyContinuationJL(; kwargs...) = HomotopyContinuationJL{false}(; kwargs...)

function HomotopyContinuationJL(alg::HomotopyContinuationJL{R}; kwargs...) where {R}
HomotopyContinuationJL{R}(; autodiff = alg.autodiff, alg.kwargs..., kwargs...)
end

include("interface_types.jl")
include("jacobian_handling.jl")
include("solve.jl")

end
112 changes: 9 additions & 103 deletions lib/NonlinearSolveHomotopyContinuation/src/interface_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,47 +4,6 @@ struct Inplace <: HomotopySystemVariant end
struct OutOfPlace <: HomotopySystemVariant end
struct Scalar <: HomotopySystemVariant end

"""
$(TYPEDEF)
A simple struct that wraps a polynomial function which takes complex input and returns
complex output in a form that supports automatic differentiation. If the wrapped
function if ``f: \\mathbb{C}^n \\rightarrow \\mathbb{C}^n`` then it is assumed
the input arrays are real-valued and have length ``2n``. They are `reinterpret`ed
into complex arrays and passed into the function. This struct has an in-place signature
regardless of the signature of ``f``.
"""
@concrete struct ComplexJacobianWrapper{variant <: HomotopySystemVariant}
f
end

function (cjw::ComplexJacobianWrapper{Inplace})(
u::AbstractVector{T}, x::AbstractVector{T}, p) where {T}
x = reinterpret(Complex{T}, x)
u = reinterpret(Complex{T}, u)
cjw.f(u, x, p)
u = parent(u)
return u
end

function (cjw::ComplexJacobianWrapper{OutOfPlace})(
u::AbstractVector{T}, x::AbstractVector{T}, p) where {T}
x = reinterpret(Complex{T}, x)
u_tmp = cjw.f(x, p)
u_tmp = reinterpret(T, u_tmp)
copyto!(u, u_tmp)
return u
end

function (cjw::ComplexJacobianWrapper{Scalar})(
u::AbstractVector{T}, x::AbstractVector{T}, p) where {T}
x = reinterpret(Complex{T}, x)
u_tmp = cjw.f(x[1], p)
u[1] = real(u_tmp)
u[2] = imag(u_tmp)
return u
end

"""
$(TYPEDEF)
Expand All @@ -62,34 +21,24 @@ $(FIELDS)
"""
f
"""
The jacobian function, if provided to the `NonlinearProblem` being solved. Otherwise,
a `ComplexJacobianWrapper` wrapping `f` used for automatic differentiation.
A function which calculates both the polynomial and the jacobian. Must be a function
of the form `f(u, U, x, p)` where `x` is the current unknowns and `p` is the parameter
object, writing the value of the polynomial to `u` and the jacobian to `U`. Must be able
to handle complex `x`.
"""
jac
"""
The parameter object.
"""
p
"""
The ADType for automatic differentiation.
"""
autodiff
"""
The result from `DifferentiationInterface.prepare_jacobian`.
"""
prep
"""
HomotopyContinuation.jl's symbolic variables for the system.
"""
vars
"""
The `TaylorDiff.TaylorScalar` objects used to compute the taylor series of `f`.
"""
taylorvars
"""
Preallocated intermediate buffers used for calculating the jacobian.
"""
jacobian_buffers
end

Base.size(sys::HomotopySystemWrapper) = (length(sys.vars), length(sys.vars))
Expand All @@ -112,54 +61,11 @@ function HC.ModelKit.evaluate!(u, sys::HomotopySystemWrapper{Scalar}, x, p = not
end

function HC.ModelKit.evaluate_and_jacobian!(
u, U, sys::HomotopySystemWrapper{Inplace}, x, p = nothing)
p = sys.p
sys.f(u, x, p)
sys.jac(U, x, p)
return u, U
end

function HC.ModelKit.evaluate_and_jacobian!(
u, U, sys::HomotopySystemWrapper{OutOfPlace}, x, p = nothing)
p = sys.p
u_tmp = sys.f(x, p)
copyto!(u, u_tmp)
j_tmp = sys.jac(x, p)
copyto!(U, j_tmp)
u, U, sys::HomotopySystemWrapper, x, p = nothing)
sys.jac(u, U, x, sys.p)
return u, U
end

function HC.ModelKit.evaluate_and_jacobian!(
u, U, sys::HomotopySystemWrapper{Scalar}, x, p = nothing)
p = sys.p
u[1] = sys.f(x[1], p)
U[1] = sys.jac(x[1], p)
return u, U
end

for V in (Inplace, OutOfPlace, Scalar)
@eval function HC.ModelKit.evaluate_and_jacobian!(
u, U, sys::HomotopySystemWrapper{$V, F, J}, x,
p = nothing) where {F, J <: ComplexJacobianWrapper}
p = sys.p
U_tmp = sys.jacobian_buffers
x = reinterpret(Float64, x)
u = reinterpret(Float64, u)
DI.value_and_jacobian!(sys.jac, u, U_tmp, sys.prep, sys.autodiff, x, DI.Constant(p))
U = reinterpret(Float64, U)
@inbounds for j in axes(U, 2)
jj = 2j - 1
for i in axes(U, 1)
U[i, j] = U_tmp[i, jj]
end
end
u = parent(u)
U = parent(U)

return u, U
end
end

function update_taylorvars_from_taylorvector!(
vars, x::HC.ModelKit.TaylorVector)
for i in eachindex(x)
Expand All @@ -185,14 +91,14 @@ end

function check_taylor_equality(vars, x::HC.ModelKit.TaylorVector)
for i in eachindex(x)
TaylorDiff.flatten(vars[2i-1]) == map(real, x[i]) || return false
TaylorDiff.flatten(vars[2i - 1]) == map(real, x[i]) || return false
TaylorDiff.flatten(vars[2i]) == map(imag, x[i]) || return false
end
return true
end
function check_taylor_equality(vars, x::AbstractVector)
for i in eachindex(x)
TaylorDiff.value(vars[2i-1]) != real(x[i]) && return false
TaylorDiff.value(vars[2i - 1]) != real(x[i]) && return false
TaylorDiff.value(vars[2i]) != imag(x[i]) && return false
end
return true
Expand All @@ -212,7 +118,7 @@ function update_maybe_taylorvector_from_taylorvars!(
for i in eachindex(vars)
rval = TaylorDiff.flatten(real(buffer[i]))
ival = TaylorDiff.flatten(imag(buffer[i]))
u[i] = ntuple(i -> rval[i] + im * ival[i], Val(length(rval)))
u[i] = ntuple(i -> rval[i] + im * ival[i], Val(length(rval)))
end
end

Expand Down
Loading

0 comments on commit a82989d

Please sign in to comment.