@@ -41,6 +41,25 @@ import LLVM: Target, TargetMachine
4141import SparseArrays
4242using Printf
4343
44+ @enum ActivityState begin
45+ AnyState = 0
46+ ActiveState = 1
47+ DupState = 2
48+ MixedState = 3
49+ end
50+
51+ @inline function Base.:| (a1:: ActivityState , a2:: ActivityState )
52+ ActivityState (Int (a1) | Int (a2))
53+ end
54+
55+ mutable struct EnzymeContext
56+ world:: UInt
57+ activity_cache:: Dict{Tuple{Type,Bool,Bool,Bool},ActivityState}
58+ function EnzymeContext (world)
59+ new (world, Dict {Tuple{Type,Bool,Bool,Bool},ActivityState} ())
60+ end
61+ end
62+
4463using Preferences
4564
4665bitcode_replacement () = parse (Bool, @load_preference (" bitcode_replacement" , " true" ))
@@ -3109,7 +3128,7 @@ function create_abi_wrapper(
31093128 # 3 is index of shadow
31103129 if existed[3 ] != 0 &&
31113130 sret_union &&
3112- active_reg (pactualRetType, world ; justActive= true , UnionSret= true ) == ActiveState
3131+ active_reg_cached (interp . context, pactualRetType ; justActive= true , UnionSret= true ) == ActiveState
31133132 rewrite_union_returns_as_ref (enzymefn, data[3 ], world, width)
31143133 end
31153134 returnNum = 0
@@ -4773,7 +4792,7 @@ end
47734792 if params. err_if_func_written
47744793 FT = TT. parameters[1 ]
47754794 Ty = eltype (FT)
4776- reg = active_reg (Ty, job . world )
4795+ reg = active_reg_cached (interp . context, Ty )
47774796 if reg == DupState || reg == MixedState
47784797 swiftself = has_swiftself (primalf)
47794798 todo = LLVM. Value[parameters (primalf)[1 + swiftself]]
@@ -4797,7 +4816,7 @@ end
47974816 if ! mayWriteToMemory (user)
47984817 slegal, foundv, byref = abs_typeof (user)
47994818 if slegal
4800- reg2 = active_reg (foundv, job . world )
4819+ reg2 = active_reg_cached (interp . context, foundv )
48014820 if reg2 == ActiveState || reg2 == AnyState
48024821 continue
48034822 end
@@ -4825,7 +4844,7 @@ end
48254844 if operands (user)[2 ] == cur
48264845 slegal, foundv, byref = abs_typeof (operands (user)[1 ])
48274846 if slegal
4828- reg2 = active_reg (foundv, job . world )
4847+ reg2 = active_reg_cached (interp . context, foundv )
48294848 if reg2 == AnyState
48304849 continue
48314850 end
@@ -4859,7 +4878,7 @@ end
48594878 if is_readonly (called)
48604879 slegal, foundv, byref = abs_typeof (user)
48614880 if slegal
4862- reg2 = active_reg (foundv, job . world )
4881+ reg2 = active_reg_cached (interp . context, foundv )
48634882 if reg2 == ActiveState || reg2 == AnyState
48644883 continue
48654884 end
@@ -4877,7 +4896,7 @@ end
48774896 end
48784897 slegal, foundv, byref = abs_typeof (user)
48794898 if slegal
4880- reg2 = active_reg (foundv, job . world )
4899+ reg2 = active_reg_cached (interp . context, foundv )
48814900 if reg2 == ActiveState || reg2 == AnyState
48824901 continue
48834902 end
0 commit comments