Skip to content

Conversation

@wsmoses
Copy link
Member

@wsmoses wsmoses commented Feb 8, 2025

@gdalle this will enable the DI ext to more properly touch internals if need be

include("reverse_onearg.jl")
include("reverse_twoarg.jl")

end # module No newline at end of file
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
end # module
end # module

Comment on lines +4 to +16
f::F,
::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
contexts::Vararg{DI.Context,C},
) where {F,C}
return DI.NoPushforwardPrep()
end

function DI.value_and_pushforward(
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f::F,
::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
contexts::Vararg{DI.Context,C},
) where {F,C}
return DI.NoPushforwardPrep()
end
function DI.value_and_pushforward(
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
f::F,
::AutoEnzyme{<:Union{ForwardMode, Nothing}},
x,
tx::NTuple,
contexts::Vararg{DI.Context, C},
) where {F, C}
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode, Nothing}},
x,
tx::NTuple{1},
contexts::Vararg{DI.Context, C},
) where {F, C}

Comment on lines +31 to +37
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode, Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{DI.Context, C},
) where {F, B, C}

Comment on lines +48 to +54
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{1},
contexts::Vararg{DI.Context,C},
) where {F,C}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{1},
contexts::Vararg{DI.Context,C},
) where {F,C}
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode, Nothing}},
x,
tx::NTuple{1},
contexts::Vararg{DI.Context, C},
) where {F, C}

Comment on lines +65 to +71
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode, Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{DI.Context, C},
) where {F, B, C}

Comment on lines +12 to +29
f::F, ::AutoEnzyme{M,Nothing}, mode::Mode, ::Val{B}=Val(1)
) where {F,M,B}
return f
end

@inline function get_f_and_df(
f::F, ::AutoEnzyme{M,<:Const}, mode::Mode, ::Val{B}=Val(1)
) where {F,M,B}
return Const(f)
end

@inline function get_f_and_df(
f::F,
::AutoEnzyme{
M,
<:Union{
Duplicated,
MixedDuplicated,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f::F, ::AutoEnzyme{M,Nothing}, mode::Mode, ::Val{B}=Val(1)
) where {F,M,B}
return f
end
@inline function get_f_and_df(
f::F, ::AutoEnzyme{M,<:Const}, mode::Mode, ::Val{B}=Val(1)
) where {F,M,B}
return Const(f)
end
@inline function get_f_and_df(
f::F,
::AutoEnzyme{
M,
<:Union{
Duplicated,
MixedDuplicated,
f::F, ::AutoEnzyme{M, Nothing}, mode::Mode, ::Val{B} = Val(1)
) where {F, M, B}
f::F, ::AutoEnzyme{M, <:Const}, mode::Mode, ::Val{B} = Val(1)
) where {F, M, B}
f::F,
::AutoEnzyme{
M,
<:Union{
Duplicated,
MixedDuplicated,
BatchDuplicated,
BatchMixedDuplicated,
DuplicatedNoNeed,
BatchDuplicatedNoNeed,
},
mode::Mode,
::Val{B} = Val(1),
) where {F, M, B}

Comment on lines +47 to +51
force_annotation(f::F) where {F<:Annotation} = f
force_annotation(f::F) where {F} = Const(f)

@inline function _translate(
::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Constant,DI.BackendContext}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
force_annotation(f::F) where {F<:Annotation} = f
force_annotation(f::F) where {F} = Const(f)
@inline function _translate(
::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Constant,DI.BackendContext}
force_annotation(f::F) where {F <: Annotation} = f
::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Constant, DI.BackendContext}
) where {B}
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache
) where {B}

Comment on lines +67 to +70
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext
) where {B}
return force_annotation(get_f_and_df(DI.unwrap(c), backend, mode, Val(B)))
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext
) where {B}
return force_annotation(get_f_and_df(DI.unwrap(c), backend, mode, Val(B)))
end
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext
) where {B}
backend::AutoEnzyme, mode::Mode, ::Val{B}, contexts::Vararg{DI.Context, C}
) where {B, C}

Comment on lines +103 to +104
set_err(mode::Mode, ::AutoEnzyme{<:Any,Nothing}) = EnzymeCore.set_err_if_func_written(mode)
set_err(mode::Mode, ::AutoEnzyme{<:Any,<:Annotation}) = mode
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
set_err(mode::Mode, ::AutoEnzyme{<:Any,Nothing}) = EnzymeCore.set_err_if_func_written(mode)
set_err(mode::Mode, ::AutoEnzyme{<:Any,<:Annotation}) = mode
set_err(mode::Mode, ::AutoEnzyme{<:Any, Nothing}) = EnzymeCore.set_err_if_func_written(mode)
set_err(mode::Mode, ::AutoEnzyme{<:Any, <:Annotation}) = mode

Comment on lines +118 to +120
function annotate(::Type{BatchDuplicated{T,B}}, x, tx::NTuple{B}) where {T,B}
return BatchDuplicated(x, tx)
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function annotate(::Type{BatchDuplicated{T,B}}, x, tx::NTuple{B}) where {T,B}
return BatchDuplicated(x, tx)
end
function annotate(::Type{BatchDuplicated{T, B}}, x, tx::NTuple{B}) where {T, B}
batchify_activity(::Type{Active{T}}, ::Val{B}) where {T, B} = Active{T}
batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T, B} = BatchDuplicated{T, B}

@github-actions
Copy link
Contributor

github-actions bot commented Feb 8, 2025

Benchmark Results

main 00803ee... main/00803ee98e24c7...
basics/overhead 5.26 ± 0.01 ns 4.64 ± 0.01 ns 1.13
time_to_load 1.1 ± 0.012 s 1.13 ± 0.034 s 0.967

Benchmark Plots

A plot of the benchmark results have been uploaded as an artifact to the workflow run for this PR.
Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).

@wsmoses wsmoses marked this pull request as draft February 8, 2025 18:34
return set_err(ReverseSplitWithPrimal, backend)
end

set_err(mode::Mode, ::AutoEnzyme{<:Any,Nothing}) = EnzymeCore.set_err_if_func_written(mode)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gdalle as discussed on slack this probably should be an extension to the set_err_if_func_written function to take an ADmode, so likely we have this in an EnzymeCoreADTypes ext?

forward_withprimal(backend::AutoEnzyme{<:ForwardMode}) = WithPrimal(backend.mode)
forward_withprimal(::AutoEnzyme{Nothing}) = ForwardWithPrimal

reverse_noprimal(backend::AutoEnzyme{<:ReverseMode}) = NoPrimal(backend.mode)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gdalle similarly here we can make an ADTypes ext func for get_mode_or_default(AutoEnzyme, defaultMode)

dy_sametype = convert(typeof(y), only(prep.ty_copy))
x_and_dx = Duplicated(x, dx_sametype)
y_and_dy = Duplicated(y, dy_sametype)
annotated_contexts = translate(backend, mode, Val(1), contexts...)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gdalle I presume this can be moved into Enzyme.gradient! And have DI call that?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants