Type inference failure when compiled with custom AbstractInterpreter (e.g. GPUCompiler) #643

Closed
wsmoses opened this issue May 28, 2024 · 9 comments

Comments

@wsmoses
Contributor

wsmoses commented May 28, 2024

julia> Core.Compiler.typeinf_ext_toplevel(interp, mi)
CodeInfo(
     @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/logdensityfunction.jl:93 within `logdensity`
    ┌ @ Base.jl:37 within `getproperty`
1 ──│ %1  = Base.getfield(f, :varinfo)::TypedVarInfo{@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}, Float64}
│   └
│   ┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/abstract_varinfo.jl:747 within `unflatten` @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:134 @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:137
│   │┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:116 within `VarInfo`
│   ││┌ @ Base.jl:37 within `getproperty`
│   │││ %2  = Base.getfield(%1, :metadata)::@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}
│   ││└
│   ││┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:158 within `newmetadata`
│   │││┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl within `macro expansion`
│   ││││┌ @ Base.jl:37 within `getproperty`
│   │││││ %3  = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %4  = Base.getfield(%3, :idcs)::Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}
│   │││││ %5  = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %6  = Base.getfield(%5, :vns)::Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}
│   │││││ %7  = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %8  = Base.getfield(%7, :ranges)::Vector{UnitRange{Int64}}
│   │││││ %9  = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %10 = Base.getfield(%9, :ranges)::Vector{UnitRange{Int64}}
│   ││││└
│   ││││ %11 = DynamicPPL.length::typeof(length)
│   ││││┌ @ reducedim.jl:1011 within `sum`
│   │││││┌ @ reducedim.jl:1011 within `#sum#829`
│   ││││││┌ @ reducedim.jl:1015 within `_sum`
│   │││││││┌ @ reducedim.jl:1015 within `#_sum#831`
│   ││││││││ %12 = Base.add_sum::typeof(Base.add_sum)
│   ││││││││┌ @ reducedim.jl:357 within `mapreduce`
│   │││││││││┌ @ reducedim.jl:357 within `#mapreduce#821`
│   ││││││││││┌ @ reducedim.jl:365 within `_mapreduce_dim`
│   │││││││││││ %13 = invoke Base._mapreduce(%11::typeof(length), %12::typeof(Base.add_sum), $(QuoteNode(IndexLinear()))::IndexLinear, %10::Vector{UnitRange{Int64}})::Int64
│   ││││└└└└└└└
│   ││││┌ @ int.jl:87 within `+`
│   │││││ %14 = Base.add_int(0, %13)::Int64
│   ││││└
│   ││││┌ @ range.jl:5 within `Colon`
│   │││││┌ @ range.jl:403 within `UnitRange`
│   ││││││┌ @ range.jl:414 within `unitrange_last`
│   │││││││┌ @ operators.jl:425 within `>=`
│   ││││││││┌ @ int.jl:514 within `<=`
│   │││││││││ %15 = Base.sle_int(1, %14)::Bool
│   │││││││└└
└───│││││││       goto #3 if not %15
2 ──│││││││       goto #4
3 ──│││││││       goto #4
    ││││││└
4 ┄─││││││ %19 = φ (#2 => %14, #3 => 0)::Int64
│   ││││││ %20 = %new(UnitRange{Int64}, 1, %19)::UnitRange{Int64}
└───││││││       goto #5
5 ──││││││       goto #6
    ││││└└
    ││││┌ @ array.jl:973 within `getindex`
6 ──│││││       goto #11 if not true
    │││││┌ @ abstractarray.jl:700 within `checkbounds`
7 ──││││││ %24 = Core.tuple(%20)::Tuple{UnitRange{Int64}}
│   ││││││ @ abstractarray.jl:702 within `checkbounds` @ abstractarray.jl:687
│   ││││││┌ @ abstractarray.jl:389 within `eachindex`
│   │││││││┌ @ abstractarray.jl:137 within `axes1`
│   ││││││││┌ @ abstractarray.jl:98 within `axes`
│   │││││││││┌ @ array.jl:191 within `size`
│   ││││││││││ %25 = Base.arraysize(θ, 1)::Int64
│   │││││││││└
│   │││││││││┌ @ tuple.jl:291 within `map`
│   ││││││││││┌ @ range.jl:469 within `oneto`
│   │││││││││││┌ @ range.jl:467 within `OneTo` @ range.jl:454
│   ││││││││││││┌ @ promotion.jl:532 within `max`
│   │││││││││││││┌ @ int.jl:83 within `<`
│   ││││││││││││││ %26 = Base.slt_int(%25, 0)::Bool
│   │││││││││││││└
│   │││││││││││││┌ @ essentials.jl:647 within `ifelse`
│   ││││││││││││││ %27 = Core.ifelse(%26, 0, %25)::Int64
│   ││││││└└└└└└└└
│   ││││││┌ @ abstractarray.jl:768 within `checkindex`
│   │││││││┌ @ range.jl:672 within `isempty`
│   ││││││││┌ @ operators.jl:378 within `>`
│   │││││││││┌ @ int.jl:83 within `<`
│   ││││││││││ %28 = Base.slt_int(%19, 1)::Bool
│   │││││││└└└
│   │││││││ @ abstractarray.jl:768 within `checkindex` @ abstractarray.jl:763
│   │││││││┌ @ int.jl:86 within `-`
│   ││││││││ %29 = Base.sub_int(1, 1)::Int64
│   │││││││└
│   │││││││┌ @ essentials.jl:524 within `unsigned`
│   ││││││││┌ @ essentials.jl:581 within `reinterpret`
│   │││││││││ %30 = Base.bitcast(UInt64, %29)::UInt64
│   │││││││││ @ essentials.jl:581 within `reinterpret`
│   │││││││││ %31 = Base.bitcast(UInt64, %27)::UInt64
│   │││││││└└
│   │││││││┌ @ int.jl:513 within `<`
│   ││││││││ %32 = Base.ult_int(%30, %31)::Bool
│   │││││││└
│   │││││││┌ @ int.jl:86 within `-`
│   ││││││││ %33 = Base.sub_int(%19, 1)::Int64
│   │││││││└
│   │││││││┌ @ essentials.jl:524 within `unsigned`
│   ││││││││┌ @ essentials.jl:581 within `reinterpret`
│   │││││││││ %34 = Base.bitcast(UInt64, %33)::UInt64
│   │││││││││ @ essentials.jl:581 within `reinterpret`
│   │││││││││ %35 = Base.bitcast(UInt64, %27)::UInt64
│   │││││││└└
│   │││││││┌ @ int.jl:513 within `<`
│   ││││││││ %36 = Base.ult_int(%34, %35)::Bool
│   │││││││└
│   │││││││ @ abstractarray.jl:768 within `checkindex`
│   │││││││┌ @ bool.jl:38 within `&`
│   ││││││││ %37 = Base.and_int(%32, %36)::Bool
│   │││││││└
│   │││││││┌ @ bool.jl:39 within `|`
│   ││││││││ %38 = Base.or_int(%28, %37)::Bool
│   ││││││└└
│   ││││││ @ abstractarray.jl:702 within `checkbounds`
└───││││││       goto #9 if not %38
8 ──││││││       goto #10
9 ──││││││       invoke Base.throw_boundserror(θ::Vector{Float64}, %24::Tuple{UnitRange{Int64}})::Union{}
└───││││││       unreachable
10 ─││││││       nothing::Nothing
    │││││└
    │││││ @ array.jl:974 within `getindex`
    │││││┌ @ range.jl:761 within `length`
    ││││││┌ @ int.jl:86 within `-`
11 ┄│││││││ %44 = Base.sub_int(%19, 1)::Int64
│   ││││││└
│   ││││││┌ @ int.jl:87 within `+`
│   │││││││ %45 = Base.add_int(1, %44)::Int64
│   │││││└└
│   │││││ @ array.jl:975 within `getindex`
│   │││││┌ @ range.jl:706 within `axes`
│   ││││││┌ @ range.jl:761 within `length`
│   │││││││┌ @ int.jl:86 within `-`
│   ││││││││ %46 = Base.sub_int(%19, 1)::Int64
│   │││││││└
│   │││││││┌ @ int.jl:87 within `+`
│   ││││││││ %47 = Base.add_int(1, %46)::Int64
│   ││││││└└
│   ││││││┌ @ range.jl:469 within `oneto`
│   │││││││┌ @ range.jl:467 within `OneTo` @ range.jl:454
│   ││││││││┌ @ promotion.jl:532 within `max`
│   │││││││││┌ @ int.jl:83 within `<`
│   ││││││││││ %48 = Base.slt_int(%47, 0)::Bool
│   │││││││││└
│   │││││││││┌ @ essentials.jl:647 within `ifelse`
│   ││││││││││ %49 = Core.ifelse(%48, 0, %47)::Int64
│   │││││└└└└└
│   │││││┌ @ abstractarray.jl:831 within `similar` @ array.jl:420
│   ││││││┌ @ boot.jl:486 within `Array` @ boot.jl:477
│   │││││││ %50 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Vector{Float64}, svec(Any, Int64), 0, :(:ccall), Vector{Float64}, :(%49), :(%49)))::Vector{Float64}
│   │││││└└
│   │││││ @ array.jl:976 within `getindex`
│   │││││┌ @ operators.jl:378 within `>`
│   ││││││┌ @ int.jl:83 within `<`
│   │││││││ %51 = Base.slt_int(0, %45)::Bool
│   │││││└└
└───│││││       goto #13 if not %51
    │││││ @ array.jl:977 within `getindex`
    │││││┌ @ array.jl:368 within `copyto!`
12 ─││││││       invoke Base._copyto_impl!(%50::Vector{Float64}, 1::Int64, θ::Vector{Float64}, 1::Int64, %45::Int64)::Vector{Float64}
    │││││└
    │││││ @ array.jl:979 within `getindex`
13 ┄│││││       goto #14
    ││││└
    ││││┌ @ Base.jl:37 within `getproperty`
14 ─│││││ %55 = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %56 = Base.getfield(%55, :dists)::Vector{IsoNormal}
│   │││││ %57 = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %58 = Base.getfield(%57, :gids)::Vector{Set{DynamicPPL.Selector}}
│   │││││ %59 = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %60 = Base.getfield(%59, :orders)::Vector{Int64}
│   │││││ %61 = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %62 = Base.getfield(%61, :flags)::Dict{String, BitVector}
│   ││││└
│   ││││┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:47 within `Metadata`
│   │││││ %63 = %new(DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, %4, %6, %8, %50, %56, %58, %60, %62)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   ││││└
│   ││││┌ @ boot.jl:622 within `NamedTuple`
│   │││││ %64 = %new(@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}, %63)::@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}
│   ││││└
└───││││       goto #15
    ││└└
    ││ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:117 within `VarInfo`
    ││┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:906 within `getlogp`
    │││┌ @ Base.jl:37 within `getproperty`
15 ─││││ %66 = Base.getfield(%1, :logp)::Base.RefValue{Float64}
│   │││└
│   │││┌ @ refvalue.jl:59 within `getindex`
│   ││││┌ @ Base.jl:37 within `getproperty`
│   │││││ %67 = Base.getfield(%66, :x)::Float64
│   ││└└└
│   ││┌ @ refvalue.jl:8 within `RefValue`
│   │││ %68 = %new(Base.RefValue{Float64}, %67)::Base.RefValue{Float64}
│   ││└
│   ││┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:923 within `get_num_produce`
│   │││┌ @ Base.jl:37 within `getproperty`
│   ││││ %69 = Base.getfield(%1, :num_produce)::Base.RefValue{Int64}
│   │││└
│   │││┌ @ refvalue.jl:59 within `getindex`
│   ││││┌ @ Base.jl:37 within `getproperty`
│   │││││ %70 = Base.getfield(%69, :x)::Int64
│   ││└└└
│   ││┌ @ refpointer.jl:137 within `Ref`
│   │││┌ @ refvalue.jl:10 within `RefValue` @ refvalue.jl:8
│   ││││ %71 = %new(Base.RefValue{Int64}, %70)::Base.RefValue{Int64}
│   ││└└
│   ││ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:117 within `VarInfo` @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:100
│   ││ %72 = %new(TypedVarInfo{@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}, Float64}, %64, %68, %71)::TypedVarInfo{@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}, Float64}
│   ││ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:117 within `VarInfo`
└───││       goto #16
16 ─││       goto #17
17 ─││       goto #18
18 ─││       goto #19
    └└
     @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/logdensityfunction.jl:94 within `logdensity`
19 ─       invoke DynamicPPL.evaluate!!($(QuoteNode(Model{typeof(demo2), (Symbol("##arg#225"),), (), (), Tuple{DynamicPPL.TypeWrap{Matrix{Float64}}}, Tuple{}, DefaultContext}(demo2, (var"##arg#225" = DynamicPPL.TypeWrap{Matrix{Float64}}(),), NamedTuple(), DefaultContext())))::Model{typeof(demo2), (Symbol("##arg#225"),), (), (), Tuple{DynamicPPL.TypeWrap{Matrix{Float64}}}, Tuple{}, DefaultContext}, %72::TypedVarInfo{@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}, Float64}, $(QuoteNode(DefaultContext()))::DefaultContext)::Union{}
└───       unreachable
)

julia> Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype
Union{}
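For readers less used to reading inference output: `Union{}` is Julia's bottom type, and a method whose inferred return type is `Union{}` is one that inference believes can never return normally (every path throws). In the dump above, the `DynamicPPL.evaluate!!` call is annotated `::Union{}` and followed by `unreachable`, which is why `rettype` collapses to `Union{}`. A minimal sketch with the native interpreter, using hypothetical functions `always_throws` and `add_one`:

```julia
# Hypothetical function used only for illustration: every call throws.
always_throws(x) = throw(ArgumentError("invalid input: $x"))

# A function that returns normally, for comparison.
add_one(x) = x + 1

# Base.return_types runs inference (native interpreter) and returns one
# inferred return type per matching method; each has exactly one method here.
rt_throw = only(Base.return_types(always_throws, (Int,)))
rt_ok = only(Base.return_types(add_one, (Int,)))

println(rt_throw)  # Union{}
println(rt_ok)     # Int64 (on a 64-bit machine)
```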
@wsmoses
Contributor Author

wsmoses commented May 28, 2024

Some relevant code:

using Turing, Enzyme, LinearAlgebra, LogDensityProblems

using AbstractPPL
using DynamicPPL
using Accessors

using GPUCompiler
Enzyme.API.runtimeActivity!(true);

@model function demo2(::Type{TV}=Matrix{Float64}) where {TV}
    d = 2
    n = 2
    x = TV(undef, d, n)
    x[:, 1] ~ MvNormal(zeros(d), I)
    for i = 2:n
        x[:, i] ~ MvNormal(x[:, i - 1], I)
    end
end

model = demo2()
ℓ = Turing.LogDensityFunction(model)
θ = ℓ.varinfo[:]

x = θ

@show LogDensityProblems.logdensity(ℓ, x)

Enzyme.autodiff(ReverseWithPrimal, LogDensityProblems.logdensity, Active, Const(ℓ), Enzyme.Duplicated(x, zero(x)))
World = Base.get_world_counter()
FA = Const{typeof(LogDensityProblems.logdensity)}
A = Active
width = 1
Mode = Enzyme.API.DEM_ReverseModeCombined
ModifiedBetween = (false, false)
ReturnPrimal = true
ShadowInit = false
ABI = Enzyme.FFIABI
TT = Tuple{Const{LogDensityFunction{DynamicPPL.TypedVarInfo{@NamedTuple{x::DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:x, Accessors.IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{AbstractPPL.VarName{:x, Accessors.IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}, Float64}, DynamicPPL.Model{typeof(demo2), (Symbol("##arg#225"),), (), (), Tuple{DynamicPPL.TypeWrap{Matrix{Float64}}}, Tuple{}, DynamicPPL.DefaultContext}, DynamicPPL.DefaultContext}}, Duplicated{Vector{Float64}}}

mi = Enzyme.Compiler.fspec(eltype(FA), TT, World)

target = Enzyme.Compiler.EnzymeTarget()
params = Enzyme.Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, Enzyme.Compiler.remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, Enzyme.Compiler.UnknownTapeType, ABI)
tmp_job    = Enzyme.Compiler.CompilerJob(mi, Enzyme.Compiler.CompilerConfig(target, params; kernel=false), World)

interp = GPUCompiler.get_interpreter(tmp_job)

spec = Core.Compiler.specialize_method(mi.def, mi.specTypes, mi.sparam_vals)
Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals)

@torfjelde
Member

So is this a Turing.jl issue or a GPUCompiler.jl issue, given that type inference works fine without GPUCompiler?

@wsmoses
Contributor Author

wsmoses commented May 29, 2024 via email

@torfjelde
Member

Just to clarify a bit here: there's no "bug" per se in Turing.jl. The "bug" is just that there's no constructor for MvNormal with eltype Any, but arguably that's not a desirable thing to support. That is, this model is only expected to work when everything is type-stable.

But this constructor is only hit because GPUCompiler somehow causes an inference failure, producing Any where every other approach correctly infers Float64.
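To make the failure mode concrete: a constructor whose type parameter is bounded by `Real` has no method for a container whose element type is `Any`, even when every element is in fact a `Float64`. The sketch below uses a hypothetical `ToyMvNormal` as a stand-in; it is an assumption for illustration and is not Distributions.jl's actual `MvNormal` definition.

```julia
# Hypothetical stand-in for a distribution type; NOT Distributions.jl's MvNormal.
# The type parameter is restricted to real element types.
struct ToyMvNormal{T<:Real}
    μ::Vector{T}
end

# Concretely typed mean: the auto-generated constructor matches with T = Float64.
ok = ToyMvNormal([0.0, 0.0])

# Same values stored in a Vector{Any} (the kind of container a failed
# inference can produce): T would have to be Any, which violates T<:Real,
# so no method matches and a MethodError is thrown.
err = try
    ToyMvNormal(Any[0.0, 0.0])
    nothing
catch e
    e
end

println(typeof(ok))   # ToyMvNormal{Float64}
println(typeof(err))  # MethodError
```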

@wsmoses
Contributor Author

wsmoses commented May 29, 2024

Sure, though I'm not sure which of the subpackages used by Turing is causing the error.

My guess is that somewhere there is a use of typeof(x) [aka inferred type] instead of Core.Typeof(x) [aka runtime type], and switching to the latter would correct the construction.
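A sketch of the underlying distinction, run with the native interpreter. On ordinary values `typeof` and `Core.Typeof` both report the runtime type; they differ only when the value is itself a type. Some Base code, however, does ask the compiler for a type at runtime: `Base.promote_op` (built on `Core.Compiler.return_type`) queries inference, and `map` relies on that to pick an element type when its input is empty. One plausible route to an `Any` eltype is such a query being answered by a less precise custom `AbstractInterpreter` during compilation; that is an assumption here, not a confirmed diagnosis of this issue.

```julia
# typeof and Core.Typeof agree on ordinary values (both give the runtime type)...
@show typeof(1.0)           # Float64
@show Core.Typeof(1.0)      # Float64

# ...and differ only when the value is itself a type.
@show typeof(Float64)       # DataType
@show Core.Typeof(Float64)  # Type{Float64}

# promote_op asks inference what a call would return for these argument types.
@show Base.promote_op(+, Float64, Float64)  # Float64

# With an empty input, map has no element to inspect, so the result's eltype
# comes from inference rather than from any runtime value.
empty_result = map(v -> v + 1.0, Float64[])
@show eltype(empty_result)  # Float64
```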

@wsmoses
Contributor Author

wsmoses commented May 29, 2024

cc @maleadt @vchuravy for visibility

@mhauru
Member

mhauru commented Jul 15, 2024

@yebai
Member

yebai commented Aug 21, 2024

@willtebbutt wrote in a Slack discussion:

I’ve found a fairly simple situation in which the results of inference differ depending on whether you use a Core.Compiler.NativeInterpreter() or one of the various custom AbstractInterpreters in use in the wild (e.g. Cthulhu.jl’s, Enzyme.jl’s, and Tapir.jl’s).
In particular, the native interpreter successfully infers the return type of Base._mapreduce_dim(Base.Fix1(view, [5.0, 4.0]), vcat, Float64[], [1:1, 2:2], :) to be Vector{Float64}, while the other abstract interpreters infer Any.

using Cthulhu, Enzyme, Tapir

# Specify function + args.
fargs = (Base._mapreduce_dim, Base.Fix1(view, [5.0, 4.0]), vcat, Float64[], [1:1, 2:2], :)
tt = typeof(fargs)

# Construct the relevant interpreters.
native_interp = Core.Compiler.NativeInterpreter();
cthulhu_interp = Cthulhu.CthulhuInterpreter();
enzyme_interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(
    Enzyme.Compiler.GLOBAL_REV_CACHE,
    nothing,
    Base.get_world_counter(),
    Enzyme.API.DEM_ReverseModeCombined,
);
tapir_interp = Tapir.TapirInterpreter();

# Both of these correctly infer the return type, Vector{Float64}.
Base.code_typed_by_type(tt; optimize=true, interp=native_interp)
Base.code_ircode_by_type(tt; optimize_until=nothing, interp=native_interp)

# Inference fails.
Base.code_typed_by_type(tt; optimize=true, interp=cthulhu_interp)
Base.code_ircode_by_type(tt; optimize_until=nothing, interp=cthulhu_interp)

# Inference fails.
Base.code_typed_by_type(tt; optimize=true, interp=enzyme_interp)
Base.code_ircode_by_type(tt; optimize_until=nothing, interp=enzyme_interp)

# Inference fails.
Base.code_typed_by_type(tt; optimize=true, interp=tapir_interp)
Base.code_ircode_by_type(tt; optimize_until=nothing, interp=tapir_interp)
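The comparison above can be wrapped in a small helper that extracts only the inferred return type, which makes it easier to check one call signature against several interpreters. A minimal sketch: `inferred_rettype` and `double` are hypothetical names, the helper relies on `Base.code_typed_by_type` accepting an `interp` keyword as in the snippet above, and the example call uses only the native interpreter so it runs without Cthulhu, Enzyme, or Tapir installed.

```julia
# Hypothetical helper: the return type inferred for the call signature `tt`
# (a Tuple type whose first parameter is the function's type). Extra keywords,
# e.g. `interp = some_interpreter`, are forwarded to code_typed_by_type.
function inferred_rettype(tt::Type; kwargs...)
    # code_typed_by_type returns a vector of CodeInfo => return-type pairs,
    # one per matching method; exactly one is expected here.
    return only(Base.code_typed_by_type(tt; optimize = true, kwargs...)).second
end

# Simple function for illustration.
double(x) = 2x

native_rt = inferred_rettype(Tuple{typeof(double), Float64})
println(native_rt)  # Float64

# With the packages from the snippet above loaded, one could compare, e.g.:
#   inferred_rettype(tt; interp = native_interp)
#   inferred_rettype(tt; interp = cthulhu_interp)
```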

@wsmoses pointed out that the above compiler bug might be related to this issue.

@yebai yebai transferred this issue from TuringLang/Turing.jl Aug 21, 2024
@yebai yebai changed the title from "Turing guaranteed to error (when compiled with GPUCompiler)" to "Type inference failure when compiled with GPUCompiler" Aug 21, 2024
@yebai yebai changed the title from "Type inference failure when compiled with GPUCompiler" to "Type inference failure when compiled with custom AbstractInterpreter (e.g. GPUCompiler)" Aug 21, 2024
@yebai
Member

yebai commented Sep 3, 2024

Not a DynamicPPL/Turing issue; closing in favour of JuliaLang/julia#55638

@yebai yebai closed this as not planned Sep 3, 2024
4 participants