Skip to content

Commit a87f1e8

Browse files
committed
Make sure that we pick up new definition of typetree
1 parent c889e43 commit a87f1e8

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

src/compiler.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,7 +1078,7 @@ function set_module_types!(mod::LLVM.Module, primalf::Union{Nothing, LLVM.Functi
10781078

10791079
byref = arg.cc
10801080

1081-
rest = copy(typetree(arg.typ, ctx, dl))
1081+
# typetree_total(job, T, ctx, dl) already invokes `typetree` itself; pass only the type here.
rest = copy(typetree_total(job, arg.typ, ctx, dl))
10821082

10831083
if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF
10841084
# adjust first path to size of type since if arg.typ is {[-1]:Int}, that doesn't mean the broader
@@ -1103,7 +1103,7 @@ function set_module_types!(mod::LLVM.Module, primalf::Union{Nothing, LLVM.Functi
11031103
if sret !== nothing
11041104
idx = 0
11051105
if !in(0, parmsRemoved)
1106-
rest = typetree(sret, ctx, dl)
1106+
rest = typetree_total(job, sret, ctx, dl)
11071107
push!(
11081108
parameter_attributes(f, idx + 1),
11091109
StringAttribute("enzyme_type", string(rest)),
@@ -1125,12 +1125,12 @@ function set_module_types!(mod::LLVM.Module, primalf::Union{Nothing, LLVM.Functi
11251125
LLVM.return_type(LLVM.function_type(f)) != LLVM.VoidType()
11261126
@assert !retRemoved
11271127
rest = if llRT == Ptr{RT}
1128-
typeTree = copy(typetree(RT, ctx, dl))
1128+
typeTree = copy(typetree_total(job, RT, ctx, dl))
11291129
merge!(typeTree, TypeTree(API.DT_Pointer, ctx))
11301130
only!(typeTree, -1)
11311131
typeTree
11321132
else
1133-
typetree(RT, ctx, dl)
1133+
typetree_total(job, RT, ctx, dl)
11341134
end
11351135
push!(return_attributes(f), StringAttribute("enzyme_type", string(rest)))
11361136
end

src/typetree.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,16 @@ end
195195

196196
const TypeTreeTable = IdDict{Any,Union{Nothing,TypeTree}}
197197

198+
"""
199+
typetree_total(job, T, ctx, dl, seen=TypeTreeTable())
200+
201+
A wrapper around `typetree` that ensures the call happens in the correct world for GPUCompiler.
202+
Useful when using typetree from a generated function since typetree is user-extendable.
203+
"""
204+
function typetree_total(@nospecialize(job::GPUCompiler.CompilerJob), @nospecialize(T), ctx, dl, seen=TypeTreeTable())
205+
# Forward `seen` so a caller-supplied table is not silently dropped.
return Core._call_in_world_total(job.world, typetree, T, ctx, dl, seen)
206+
end
207+
198208
"""
199209
function typetree(T, ctx, dl, seen=TypeTreeTable())
200210

0 commit comments

Comments
 (0)