Skip to content

Commit d15d98d

Browse files
committed
use a local cache for activity reg
1 parent cffa09d commit d15d98d

File tree

3 files changed

+42
-20
lines changed

3 files changed

+42
-20
lines changed

src/analyses/activity.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,3 @@
1-
@enum ActivityState begin
2-
AnyState = 0
3-
ActiveState = 1
4-
DupState = 2
5-
MixedState = 3
6-
end
7-
8-
@inline function Base.:|(a1::ActivityState, a2::ActivityState)
9-
ActivityState(Int(a1) | Int(a2))
10-
end
11-
121
@inline element(::Val{T}) where {T} = T
132

143
@inline ptreltype(::Type{Ptr{T}}) where {T} = T
@@ -393,6 +382,14 @@ Base.@nospecializeinfer @inline function active_reg_inner(
393382
return ty
394383
end
395384

385+
function active_reg_cached(ctx::EnzymeContext, @nospecialize(ST::Type); justActive=false, UnionSret = false, AbstractIsMixed = false)
386+
key = (ST, justActive, UnionSret, AbstractIsMixed)
387+
get!(ctx.activity_cache, key) do
388+
set = Base.IdSet{Type}()
389+
active_reg_inner(ST, set, ctx.world, justActive, UnionSret, AbstractIsMixed)
390+
end
391+
end
392+
396393
Base.@nospecializeinfer @inline function active_reg(@nospecialize(ST::Type), world::UInt; justActive=false, UnionSret = false, AbstractIsMixed = false)
397394
set = Base.IdSet{Type}()
398395
return active_reg_inner(ST, set, world, justActive, UnionSret, AbstractIsMixed)

src/compiler.jl

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,25 @@ import LLVM: Target, TargetMachine
4141
import SparseArrays
4242
using Printf
4343

44+
@enum ActivityState begin
45+
AnyState = 0
46+
ActiveState = 1
47+
DupState = 2
48+
MixedState = 3
49+
end
50+
51+
@inline function Base.:|(a1::ActivityState, a2::ActivityState)
52+
ActivityState(Int(a1) | Int(a2))
53+
end
54+
55+
mutable struct EnzymeContext
56+
world::UInt
57+
activity_cache::Dict{Tuple{Type,Bool,Bool,Bool},ActivityState}
58+
function EnzymeContext(world)
59+
new(world, Dict{Tuple{Type,Bool,Bool,Bool},ActivityState}())
60+
end
61+
end
62+
4463
using Preferences
4564

4665
bitcode_replacement() = parse(Bool, @load_preference("bitcode_replacement", "true"))
@@ -3109,7 +3128,7 @@ function create_abi_wrapper(
31093128
# 3 is index of shadow
31103129
if existed[3] != 0 &&
31113130
sret_union &&
3112-
active_reg(pactualRetType, world; justActive=true, UnionSret=true) == ActiveState
3131+
active_reg_cached(interp.context, pactualRetType; justActive=true, UnionSret=true) == ActiveState
31133132
rewrite_union_returns_as_ref(enzymefn, data[3], world, width)
31143133
end
31153134
returnNum = 0
@@ -4773,7 +4792,7 @@ end
47734792
if params.err_if_func_written
47744793
FT = TT.parameters[1]
47754794
Ty = eltype(FT)
4776-
reg = active_reg(Ty, job.world)
4795+
reg = active_reg_cached(interp.context, Ty)
47774796
if reg == DupState || reg == MixedState
47784797
swiftself = has_swiftself(primalf)
47794798
todo = LLVM.Value[parameters(primalf)[1+swiftself]]
@@ -4797,7 +4816,7 @@ end
47974816
if !mayWriteToMemory(user)
47984817
slegal, foundv, byref = abs_typeof(user)
47994818
if slegal
4800-
reg2 = active_reg(foundv, job.world)
4819+
reg2 = active_reg_cached(interp.context, foundv)
48014820
if reg2 == ActiveState || reg2 == AnyState
48024821
continue
48034822
end
@@ -4825,7 +4844,7 @@ end
48254844
if operands(user)[2] == cur
48264845
slegal, foundv, byref = abs_typeof(operands(user)[1])
48274846
if slegal
4828-
reg2 = active_reg(foundv, job.world)
4847+
reg2 = active_reg_cached(interp.context, foundv)
48294848
if reg2 == AnyState
48304849
continue
48314850
end
@@ -4859,7 +4878,7 @@ end
48594878
if is_readonly(called)
48604879
slegal, foundv, byref = abs_typeof(user)
48614880
if slegal
4862-
reg2 = active_reg(foundv, job.world)
4881+
reg2 = active_reg_cached(interp.context, foundv)
48634882
if reg2 == ActiveState || reg2 == AnyState
48644883
continue
48654884
end
@@ -4877,7 +4896,7 @@ end
48774896
end
48784897
slegal, foundv, byref = abs_typeof(user)
48794898
if slegal
4880-
reg2 = active_reg(foundv, job.world)
4899+
reg2 = active_reg_cached(interp.context, foundv)
48814900
if reg2 == ActiveState || reg2 == AnyState
48824901
continue
48834902
end

src/compiler/interpreter.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter
136136
within_autodiff_rewrite::Bool
137137

138138
handler::T
139+
140+
context::Enzyme.Compiler.EnzymeContext
139141
end
140142

141143
const SigCache = Dict{Tuple, Dict{UInt, Base.IdSet{Type}}}()
@@ -247,7 +249,8 @@ function EnzymeInterpreter(
247249
inactive_rules::Bool,
248250
broadcast_rewrite::Bool,
249251
within_autodiff_rewrite::Bool,
250-
handler
252+
handler,
253+
Enzyme.Compiler.EnzymeContext(world)
251254
)
252255
end
253256

@@ -278,7 +281,9 @@ function EnzymeInterpreter(interp::EnzymeInterpreter;
278281
inactive_rules = interp.inactive_rules,
279282
broadcast_rewrite = interp.broadcast_rewrite,
280283
within_autodiff_rewrite = interp.within_autodiff_rewrite,
281-
handler = interp.handler)
284+
handler = interp.handler,
285+
context = interp.context,)
286+
@assert context.world == world
282287
return EnzymeInterpreter(
283288
cache_or_token,
284289
mt,
@@ -291,7 +296,8 @@ function EnzymeInterpreter(interp::EnzymeInterpreter;
291296
inactive_rules,
292297
broadcast_rewrite,
293298
within_autodiff_rewrite,
294-
handler
299+
handler,
300+
context
295301
)
296302
end
297303

0 commit comments

Comments
 (0)