Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] Basic Support for Enzyme AD #27

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ jobs:
strategy:
fail-fast: false
matrix:
julia_version: ["1.7", "1"]
julia_version: ["1"]
name: julia ${{ matrix.julia_version }}
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.julia_version }}
Expand Down
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.1.9"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PATHSolver = "f5f7c340-0bb3-5c69-969a-41884d311d1b"
Expand Down
132 changes: 132 additions & 0 deletions src/AutoDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using ..ParametricMCPs: ParametricMCPs, get_problem_size, get_result_buffer, get_parameter_dimension
using ChainRulesCore: ChainRulesCore
using EnzymeCore: EnzymeCore, EnzymeRules
using ForwardDiff: ForwardDiff
using SparseArrays: SparseArrays
using LinearAlgebra: LinearAlgebra
Expand Down Expand Up @@ -45,6 +46,137 @@
∂z∂θ
end

const EnzymeBatchedAnnotation = Union{EnzymeCore.BatchDuplicated,EnzymeCore.BatchDuplicatedNoNeed}
const EnzymeNoneedAnnotation = Union{EnzymeCore.DuplicatedNoNeed,EnzymeCore.BatchDuplicatedNoNeed}

function EnzymeRules.forward(

Check warning on line 52 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L52

Added line #L52 was not covered by tests
func::EnzymeCore.Const{typeof(ParametricMCPs.solve)},
::Type{ReturnType},
problem::EnzymeCore.Annotation{<:ParametricMCPs.ParametricMCP},
θ::EnzymeCore.Annotation;
kwargs...,
) where {ReturnType<:EnzymeCore.Annotation}
# TODO: Enzyme sometimes passes us the problem as non-const (why?). For now, skip this check.
#if !(problem isa EnzymeCore.Const)
# throw(ArgumentError("""
# `problem` must be annotated `Enzyme.Const`.
# If you did not pass the non-const problem annotation yourself,
# consider filing an issue with ParametricMCPs.jl.
# """))
#end

if θ isa EnzymeCore.Const
throw(

Check warning on line 69 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L68-L69

Added lines #L68 - L69 were not covered by tests
ArgumentError(
"""
`θ` was annotated `Enzyme.Const` which defeats the purpose of running AD.
If you did not pass the const θ annotation yourself,
consider filing an issue with ParametricMCPs.jl.
""",
),
)
end

# forward pass
solution_val = func.val(problem.val, θ.val; kwargs...)

Check warning on line 81 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L81

Added line #L81 was not covered by tests

if ReturnType <: EnzymeCore.Const
return solution_val

Check warning on line 84 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L83-L84

Added lines #L83 - L84 were not covered by tests
end

# backward pass
∂z∂θ = _solve_jacobian_θ(problem.val, solution_val, θ.val)

Check warning on line 88 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L88

Added line #L88 was not covered by tests

if ReturnType <: EnzymeBatchedAnnotation
solution_dval = map(θ.dval) do θdval
_dval = deepcopy(solution_val)
_dval.z .= ∂z∂θ * θdval
_dval

Check warning on line 94 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L90-L94

Added lines #L90 - L94 were not covered by tests
end
else
# downstream gradient
dz = ∂z∂θ * θ.dval
solution_dval = deepcopy(solution_val)
solution_dval.z .= dz

Check warning on line 100 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L98-L100

Added lines #L98 - L100 were not covered by tests
end

if ReturnType <: EnzymeNoneedAnnotation
return solution_dval

Check warning on line 104 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L103-L104

Added lines #L103 - L104 were not covered by tests
end

if ReturnType <: EnzymeCore.Duplicated
return EnzymeCore.Duplicated(solution_val, solution_dval)

Check warning on line 108 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L107-L108

Added lines #L107 - L108 were not covered by tests
end

if ReturnType <: EnzymeCore.BatchDuplicated
return EnzymeCore.BatchDuplicated(solution_val, solution_dval)

Check warning on line 112 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L111-L112

Added lines #L111 - L112 were not covered by tests
end

throw(ArgumentError("""

Check warning on line 115 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L115

Added line #L115 was not covered by tests
Forward rule for ReturnType with annotation $(ReturnType) not implemented.
Please file an issue with ParametricMCPs.jl.
"""))
end

function EnzymeRules.augmented_primal(

Check warning on line 121 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L121

Added line #L121 was not covered by tests
config::EnzymeRules.ConfigWidth{1},
func::EnzymeCore.Const{typeof(ParametricMCPs.solve)},
::Type{<:EnzymeRules.Annotation},
problem::EnzymeCore.Annotation{<:ParametricMCPs.ParametricMCP},
θ::EnzymeCore.Annotation;
kwargs...,
)
function copy_or_reuse(val, idx)
if EnzymeRules.overwritten(config)[idx] && ismutable(val)
return deepcopy(val)

Check warning on line 131 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L129-L131

Added lines #L129 - L131 were not covered by tests
end
val

Check warning on line 133 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L133

Added line #L133 was not covered by tests
end

θval = copy_or_reuse(θ.val, 3)
res = func.val(problem.val, θval; kwargs...)

Check warning on line 137 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L136-L137

Added lines #L136 - L137 were not covered by tests
# backward pass
∂z∂θ_thunk = () -> _solve_jacobian_θ(problem.val, res, θval)

Check warning on line 139 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L139

Added line #L139 was not covered by tests

dres = deepcopy(res)
dres.z .= 0.0

Check warning on line 142 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L141-L142

Added lines #L141 - L142 were not covered by tests

tape = (; ∂z∂θ_thunk, dres)

Check warning on line 144 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L144

Added line #L144 was not covered by tests

EnzymeRules.AugmentedReturn(res, dres, tape)

Check warning on line 146 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L146

Added line #L146 was not covered by tests
end

function EnzymeRules.reverse(

Check warning on line 149 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L149

Added line #L149 was not covered by tests
config,
func::EnzymeCore.Const{typeof(ParametricMCPs.solve)},
rt::Type{ReturnType}, # TODO: tighter type constraint
tape,
problem::EnzymeCore.Annotation{<:ParametricMCPs.ParametricMCP},
θ::EnzymeCore.Annotation;
kwargs...,
) where {ReturnType}
if θ isa EnzymeCore.Duplicated
∂z∂θ = tape.∂z∂θ_thunk()
∂l∂z = tape.dres.z
θ.dval .+= ∂z∂θ' * ∂l∂z
elseif !(θ isa EnzymeCore.Const)
throw(ArgumentError("""

Check warning on line 163 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L158-L163

Added lines #L158 - L163 were not covered by tests
Reverse rule for θ with annotation $(typeof(θ)) not implemented.
Please file an issue with ParametricMCPs.jl.
"""))
end

if !(problem isa EnzymeCore.Const)
throw(ArgumentError("""

Check warning on line 170 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L169-L170

Added lines #L169 - L170 were not covered by tests
`problem` must be annotated `Enzyme.Const`.
If you did not pass the non-const problem annotation yourself,
consider filing an issue with ParametricMCPs.jl.
"""))
end

(nothing, nothing)

Check warning on line 177 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L177

Added line #L177 was not covered by tests
end

function ChainRulesCore.rrule(::typeof(ParametricMCPs.solve), problem, θ; kwargs...)
solution = ParametricMCPs.solve(problem, θ; kwargs...)
project_to_θ = ChainRulesCore.ProjectTo(θ)
Expand Down
2 changes: 1 addition & 1 deletion src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function solve(

jacobian_linear_elements =
enable_presolve ? jacobian_z!.constant_entries : empty(jacobian_z!.constant_entries)
status, z, info = PATHSolver.solve_mcp(
status, z::Vector, info::PATHSolver.Information = PATHSolver.solve_mcp(
F,
J,
lower_bounds,
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
74 changes: 57 additions & 17 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Random: Random
using LinearAlgebra: norm
using Zygote: Zygote
using FiniteDiff: FiniteDiff
using Enzyme: Enzyme

@testset "ParametricMCPs.jl" begin
rng = Random.MersenneTwister(1)
Expand All @@ -15,9 +16,15 @@ using FiniteDiff: FiniteDiff
lower_bounds = [-Inf, -Inf, 0, 0]
upper_bounds = [Inf, Inf, Inf, Inf]
problem = ParametricMCPs.ParametricMCP(f, lower_bounds, upper_bounds, parameter_dimension)
problem_no_jacobian = ParametricMCPs.ParametricMCP(f, lower_bounds, upper_bounds, parameter_dimension; compute_sensitivities=false)
problem_no_jacobian = ParametricMCPs.ParametricMCP(
f,
lower_bounds,
upper_bounds,
parameter_dimension;
compute_sensitivities = false,
)

feasible_parameters = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [rand(rng, 2) for _ in 1:10]...]
feasible_parameters = [[0.0, 0.0], [rand(rng, 2) for _ in 1:4]...]
infeasible_parameters = -feasible_parameters

@testset "forward pass" begin
Expand All @@ -32,27 +39,60 @@ using FiniteDiff: FiniteDiff
end
end

@testset "backward pass" begin
function dummy_pipeline(θ)
solution = ParametricMCPs.solve(problem, θ)
sum(solution.z .^ 2)
end
function dummy_pipeline(problem, θ)
solution = ParametricMCPs.solve(problem, θ)
sum(solution.z .^ 2)
end

@testset "backward pass" begin
for θ in [feasible_parameters; infeasible_parameters]
∇_autodiff_reverse = only(Zygote.gradient(dummy_pipeline, θ))
∇_autodiff_forward = only(Zygote.gradient(θ -> Zygote.forwarddiff(dummy_pipeline, θ), θ))
∇_finitediff = FiniteDiff.finite_difference_gradient(dummy_pipeline, θ)
@test isapprox(∇_autodiff_reverse, ∇_finitediff; atol=1e-4)
@test isapprox(∇_autodiff_reverse, ∇_autodiff_forward; atol=1e-4)
∇_finitediff = FiniteDiff.finite_difference_gradient(θ -> dummy_pipeline(problem, θ), θ)

@testset "Zygote Reverse" begin
∇_zygote_reverse = Zygote.gradient(θ) do θ
dummy_pipeline(problem, θ)
end |> only
@test isapprox(∇_zygote_reverse, ∇_finitediff; atol = 1e-4)
end

@testset "Zygote Forward" begin
∇_zygote_forward = Zygote.gradient(θ) do θ
Zygote.forwarddiff(θ) do θ
dummy_pipeline(problem, θ)
end
end |> only
@test isapprox(∇_zygote_forward, ∇_finitediff; atol = 1e-4)
end

@testset "Enzyme Forward" begin
∇_enzyme_forward =
Enzyme.autodiff(
Enzyme.Forward,
dummy_pipeline,
problem,
Enzyme.BatchDuplicated(θ, Enzyme.onehot(θ)),
) |>
only |>
collect
@test isapprox(∇_enzyme_forward, ∇_finitediff; atol = 1e-4)
end

@testset "Enzyme Reverse" begin
∇_enzyme_reverse = zero(θ)
Enzyme.autodiff(
Enzyme.Reverse,
dummy_pipeline,
problem,
Enzyme.Duplicated(θ, ∇_enzyme_reverse),
)
@test isapprox(∇_enzyme_reverse, ∇_finitediff; atol = 1e-4)
end
end
end

@testset "missing jacobian" begin
function dummy_pipeline(θ, problem)
solution = ParametricMCPs.solve(problem, θ)
sum(solution.z .^ 2)
@test_throws ArgumentError Zygote.gradient(feasible_parameters[1]) do θ
dummy_pipeline(problem_no_jacobian, θ)
end

@test_throws ArgumentError Zygote.gradient(θ -> dummy_pipeline(θ, problem_no_jacobian), feasible_parameters[1])
end
end
Loading