Skip to content

Commit 1aa2002

Browse files
authored
Add defer_within_autodiff to EnzymeInterpreter (#2254)
* add `defer_within_autodiff` to EnzymeInterpreter in order for `within_autodiff` to no return true during Reactant compilation. When this flag is true, `interp.handler` is responsible for handling within_autodiff, or to toggle defer_within_autodiff to false somewhere down the call chain. * `!defer_within_autodiff` -> `within_autodiff_rewrite`
1 parent 4ba9b71 commit 1aa2002

File tree

1 file changed

+42
-2
lines changed

1 file changed

+42
-2
lines changed

src/compiler/interpreter.jl

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,10 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter
131131
reverse_rules::Bool
132132
inactive_rules::Bool
133133
broadcast_rewrite::Bool
134+
135+
# When false, leave the check for within_autodiff to the handler.
136+
within_autodiff_rewrite::Bool
137+
134138
handler::T
135139
end
136140

@@ -169,6 +173,7 @@ function EnzymeInterpreter(
169173
reverse_rules::Bool,
170174
inactive_rules::Bool,
171175
broadcast_rewrite::Bool = true,
176+
within_autodiff_rewrite::Bool = true,
172177
handler = nothing
173178
)
174179
@assert world <= Base.get_world_counter()
@@ -229,6 +234,7 @@ function EnzymeInterpreter(
229234
reverse_rules::Bool,
230235
inactive_rules::Bool,
231236
broadcast_rewrite::Bool,
237+
within_autodiff_rewrite::Bool,
232238
handler
233239
)
234240
end
@@ -240,8 +246,42 @@ EnzymeInterpreter(
240246
mode::API.CDerivativeMode,
241247
inactive_rules::Bool,
242248
broadcast_rewrite::Bool = true,
249+
within_autodiff_rewrite::Bool = true,
243250
handler = nothing
244-
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, handler)
251+
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, within_autodiff_rewrite, handler)
252+
253+
function EnzymeInterpreter(interp::EnzymeInterpreter;
254+
cache_or_token = (@static if HAS_INTEGRATED_CACHE
255+
interp.token
256+
else
257+
interp.code_cache
258+
end),
259+
mt = interp.method_table,
260+
local_cache = interp.local_cache,
261+
world = interp.world,
262+
inf_params = interp.inf_params,
263+
opt_params = interp.opt_params,
264+
forward_rules = interp.forward_rules,
265+
reverse_rules = interp.reverse_rules,
266+
inactive_rules = interp.inactive_rules,
267+
broadcast_rewrite = interp.broadcast_rewrite,
268+
within_autodiff_rewrite = interp.within_autodiff_rewrite,
269+
handler = interp.handler)
270+
return EnzymeInterpreter(
271+
cache_or_token,
272+
mt,
273+
local_cache,
274+
world,
275+
inf_params,
276+
opt_params,
277+
forward_rules,
278+
reverse_rules,
279+
inactive_rules,
280+
broadcast_rewrite,
281+
within_autodiff_rewrite,
282+
handler
283+
)
284+
end
245285

246286
Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params
247287
Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params
@@ -933,7 +973,7 @@ function abstract_call_known(
933973

934974
(; fargs, argtypes) = arginfo
935975

936-
if f === Enzyme.within_autodiff
976+
if interp.within_autodiff_rewrite && f === Enzyme.within_autodiff
937977
if length(argtypes) != 1
938978
@static if VERSION < v"1.11.0-"
939979
return CallMeta(Union{}, Effects(), NoCallInfo())

0 commit comments

Comments
 (0)