@@ -131,6 +131,10 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter
131131 reverse_rules:: Bool
132132 inactive_rules:: Bool
133133 broadcast_rewrite:: Bool
134+
135+ # When false, leave the check for within_autodiff to the handler.
136+ within_autodiff_rewrite:: Bool
137+
134138 handler:: T
135139end
136140
@@ -169,6 +173,7 @@ function EnzymeInterpreter(
169173 reverse_rules:: Bool ,
170174 inactive_rules:: Bool ,
171175 broadcast_rewrite:: Bool = true ,
176+ within_autodiff_rewrite:: Bool = true ,
172177 handler = nothing
173178)
174179 @assert world <= Base. get_world_counter ()
@@ -229,6 +234,7 @@ function EnzymeInterpreter(
229234 reverse_rules:: Bool ,
230235 inactive_rules:: Bool ,
231236 broadcast_rewrite:: Bool ,
237+ within_autodiff_rewrite:: Bool ,
232238 handler
233239 )
234240end
@@ -240,8 +246,42 @@ EnzymeInterpreter(
240246 mode:: API.CDerivativeMode ,
241247 inactive_rules:: Bool ,
242248 broadcast_rewrite:: Bool = true ,
249+ within_autodiff_rewrite:: Bool = true ,
243250 handler = nothing
244- ) = EnzymeInterpreter (cache_or_token, mt, world, mode == API. DEM_ForwardMode, mode == API. DEM_ReverseModeCombined || mode == API. DEM_ReverseModePrimal || mode == API. DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, handler)
251+ ) = EnzymeInterpreter (cache_or_token, mt, world, mode == API. DEM_ForwardMode, mode == API. DEM_ReverseModeCombined || mode == API. DEM_ReverseModePrimal || mode == API. DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, within_autodiff_rewrite, handler)
252+
253+ function EnzymeInterpreter (interp:: EnzymeInterpreter ;
254+ cache_or_token = (@static if HAS_INTEGRATED_CACHE
255+ interp. token
256+ else
257+ interp. code_cache
258+ end ),
259+ mt = interp. method_table,
260+ local_cache = interp. local_cache,
261+ world = interp. world,
262+ inf_params = interp. inf_params,
263+ opt_params = interp. opt_params,
264+ forward_rules = interp. forward_rules,
265+ reverse_rules = interp. reverse_rules,
266+ inactive_rules = interp. inactive_rules,
267+ broadcast_rewrite = interp. broadcast_rewrite,
268+ within_autodiff_rewrite = interp. within_autodiff_rewrite,
269+ handler = interp. handler)
270+ return EnzymeInterpreter (
271+ cache_or_token,
272+ mt,
273+ local_cache,
274+ world,
275+ inf_params,
276+ opt_params,
277+ forward_rules,
278+ reverse_rules,
279+ inactive_rules,
280+ broadcast_rewrite,
281+ within_autodiff_rewrite,
282+ handler
283+ )
284+ end
245285
246286Core. Compiler. InferenceParams (@nospecialize (interp:: EnzymeInterpreter )) = interp. inf_params
247287Core. Compiler. OptimizationParams (@nospecialize (interp:: EnzymeInterpreter )) = interp. opt_params
@@ -933,7 +973,7 @@ function abstract_call_known(
933973
934974 (; fargs, argtypes) = arginfo
935975
936- if f === Enzyme. within_autodiff
976+ if interp . within_autodiff_rewrite && f === Enzyme. within_autodiff
937977 if length (argtypes) != 1
938978 @static if VERSION < v " 1.11.0-"
939979 return CallMeta (Union{}, Effects (), NoCallInfo ())
0 commit comments