Some type stabilization to get a new gradient from Enzyme #3360

Merged: glwagner merged 52 commits into main from glw/type-stable-with-tracers on Jan 30, 2024


@glwagner (Member) commented Oct 24, 2023

with @wsmoses and @jlk9

Also defines `__groupsize(cm)` for our special `CompilerMetadata` / KernelAbstractions extension. cc @simone-silvestri

@glwagner marked this pull request as draft October 24, 2023 00:08

@glwagner (Member Author) commented:

Latest error:

```
ERROR: LoadError: task switch not allowed from inside staged nor pure functions
Stacktrace:
  [1] try_yieldto(undo::typeof(Base.ensure_rescheduled))
    @ Base ./task.jl:921
  [2] wait()
    @ Base ./task.jl:995
  [3] uv_write(s::Base.TTY, p::Ptr{UInt8}, n::UInt64)
    @ Base ./stream.jl:1048
  [4] unsafe_write(s::Base.TTY, p::Ptr{UInt8}, n::UInt64)
    @ Base ./stream.jl:1120
  [5] write
    @ Base ./strings/io.jl:248 [inlined]
  [6] print
    @ Base ./strings/io.jl:250 [inlined]
  [7] print(::Base.TTY, ::String, ::String, ::Vararg{String})
    @ Base ./strings/io.jl:46
  [8] println(::Base.TTY, ::String, ::Vararg{String})
    @ Base ./strings/io.jl:75
  [9] println(::String, ::String)
    @ Base ./coreio.jl:4
 [10] calling_conv_fixup(builder::LLVM.IRBuilder, val::LLVM.AddrSpaceCastInst, tape::LLVM.PointerType, prev::LLVM.UndefValue, lidxs::Vector{…}, ridxs::Vector{…})
    @ Enzyme.Compiler ~/Projects/Enzyme.jl/src/compiler/utils.jl:270
 [11] calling_conv_fixup (repeats 2 times)
    @ Enzyme.Compiler ~/Projects/Enzyme.jl/src/compiler/utils.jl:183 [inlined]
 [12] calling_conv_fixup(builder::LLVM.IRBuilder, val::LLVM.AddrSpaceCastInst, tape::LLVM.PointerType)
    @ Enzyme.Compiler ~/Projects/Enzyme.jl/src/compiler/utils.jl:183
 [13] enzyme_custom_common_rev(forward::Bool, B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, normalR::Ptr{…}, shadowR::Ptr{…}, tape::LLVM.ExtractValueInst)
    @ Enzyme.Compiler ~/Projects/Enzyme.jl/src/compiler.jl:4610
 [14] enzyme_custom_rev(B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, tape::LLVM.ExtractValueInst)
    @ Enzyme.Compiler ~/Projects/Enzyme.jl/src/compiler.jl:4770
```

@wsmoses (Collaborator) commented Oct 30, 2023

That is ironically an error in the error printer. Can you convert that to an assertion so we can see the actual error message?

@glwagner (Member Author) commented Nov 2, 2023

you mean like this?

```
ERROR: LoadError: AssertionError: false
Stacktrace:
  [1] calling_conv_fixup(builder::LLVM.IRBuilder, val::LLVM.AddrSpaceCastInst, tape::LLVM.PointerType, prev::LLVM.UndefValue, lidxs::Vector{UInt32}, ridxs::Vector{UInt32})
    @ Enzyme.Compiler ~/Projects/Enzymantics/Enzyme.jl/src/compiler/utils.jl:271
  [2] calling_conv_fixup (repeats 2 times)
    @ Enzyme.Compiler ~/Projects/Enzymantics/Enzyme.jl/src/compiler/utils.jl:183 [inlined]
  [3] calling_conv_fixup(builder::LLVM.IRBuilder, val::LLVM.AddrSpaceCastInst, tape::LLVM.PointerType)
    @ Enzyme.Compiler ~/Projects/Enzymantics/Enzyme.jl/src/compiler/utils.jl:183
  [4] enzyme_custom_common_rev(forward::Bool, B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, normalR::Ptr{Nothing}, shadowR::Ptr{Nothing}, tape::LLVM.ExtractValueInst)
    @ Enzyme.Compiler ~/Projects/Enzymantics/Enzyme.jl/src/compiler.jl:4610
  [5] enzyme_custom_rev(B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, tape::LLVM.ExtractValueInst)
    @ Enzyme.Compiler ~/Projects/Enzymantics/Enzyme.jl/src/compiler.jl:4770
  [6] (::Enzyme.Compiler.var"#201#202")(B::Ptr{LLVM.API.LLVMOpaqueBuilder}, OrigCI::Ptr{LLVM.API.LLVMOpaqueValue}, gutils::Ptr{Nothing}, tape::Ptr{LLVM.API.LLVMOpaqueValue})
    @ Enzyme.Compiler ~/Projects/Enzymantics/Enzyme.jl/src/compiler.jl:6657
  [7] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API ~/Projects/Enzymantics/Enzyme.jl/src/api.jl:141
  [8] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…}, returnPrimal::Bool, jlrules::Vector{…}, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/Projects/Enzymantics/Enzyme.jl/src/compiler.jl:7715
  [9] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/Projects/Enzymantics/Enzyme.jl/src/compiler.jl:9278
 [10] codegen
    @ Enzyme.Compiler ~/Projects/Enzymantics/Enzyme.jl/src/compiler.jl:8886 [inlined]
 [11] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool) (repeats 2 times)
    @ Enzyme.Compiler ~/Projects/Enzymantics/Enzyme.jl/src/compiler.jl:9830
 [12] cached_compilation
    @ Enzyme.Compiler ~/Projects/Enzymantics/Enzyme.jl/src/compiler.jl:9864 [inlined]
 [13] (::Enzyme.Compiler.var"#474#475"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/Projects/Enzymantics/Enzyme.jl/src/compiler.jl:9921
 [14] JuliaContext(f::Enzyme.Compiler.var"#474#475"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:47
 [15] #s324#473
    @ Enzyme.Compiler ~/Projects/Enzymantics/Enzyme.jl/src/compiler.jl:9882 [inlined]
 [16]
    @ Enzyme.Compiler ./none:0
 [17] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:600
 [18] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::Type{…}, df::Nothing, primal_1::var"#cᵢ#1"{…}, shadow_1_1::var"#cᵢ#1"{…}, primal_2::RectilinearGrid{…}, shadow_2_1::RectilinearGrid{…})
    @ Enzyme.Compiler ~/Projects/Enzymantics/Enzyme.jl/src/compiler.jl:1386
 [19] FunctionField
    @ ~/Projects/Enzymantics/Oceananigans.jl/src/Fields/function_field.jl:54 [inlined]
```

I guess we don't hit any of the `if` statements in `calling_conv_fixup`, so we fall through to the `@assert false` at the end. So the problem is that we don't match any of those conditions?
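A minimal sketch of the pattern wsmoses suggested, using a hypothetical stand-in function (`classify` and its branch conditions are illustrative only, not Enzyme's `calling_conv_fixup`). The first trace shows `println` reaching `uv_write` and `wait()`, i.e. terminal I/O that can yield to the scheduler, which is what triggered "task switch not allowed from inside staged nor pure functions". Interpolating the offending value into the `@assert` message should avoid that, since the string is built in memory and only printed once the error has propagated, and the fallthrough then reports what it failed to match instead of a bare `AssertionError: false`.

```julia
# Hypothetical stand-in for a function whose final branch is a bare `@assert false`.
# The branch conditions are illustrative only.
function classify(val)
    if val isa Integer
        return :integer
    elseif val isa AbstractFloat
        return :float
    end
    # Fallthrough: none of the branches matched. The message is interpolated into a
    # String in memory (no terminal write) and is carried by the AssertionError.
    @assert false "unhandled value $(repr(val)) of type $(typeof(val))"
end

classify(1)      # :integer
# classify("x") throws: AssertionError: unhandled value "x" of type String
```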

@glwagner (Member Author) commented Nov 2, 2023

@jlk9 might take up this work 😏 We'll see if I convinced him.

@jlk9 (Collaborator) commented Nov 2, 2023

The above is now fixed, and we're back to our favorite tuples:

```
 [10] flattened_unique_values
    @ ~/.julia/packages/Oceananigans/i9N3H/src/Fields/field_tuples.jl:19 [inlined]
```

```
@inline function flattened_unique_values(a::Union{NamedTuple, Tuple})
    tupled = Tuple(tuplify(ai) for ai in a)
    flattened = flatten_tuple(tupled)

    # Alternative implementation of `unique` for tuples that uses === comparison, rather than ==
    seen = []
    return Tuple(last(push!(seen, f)) for f in flattened if !any(f === s for s in seen))
end
```
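The comment in `flattened_unique_values` says it deduplicates with `===` rather than `==`. A minimal self-contained sketch of just that last step (`unique_by_identity` is a hypothetical name; the `tuplify`/`flatten_tuple` flattening is omitted) shows the difference: two arrays with equal contents are distinct objects, so identity-based deduplication keeps both, while `Base.unique`, which compares with `isequal`, keeps only one.

```julia
# Hypothetical helper isolating the deduplication step of flattened_unique_values:
# keep the first occurrence of each element, comparing by object identity (===).
function unique_by_identity(t::Tuple)
    seen = []  # a Vector{Any}, so element types are not tracked by the compiler
    return Tuple(last(push!(seen, f)) for f in t if !any(f === s for s in seen))
end

a = [1.0, 2.0]
b = [1.0, 2.0]   # same contents as `a`, but a different array object

unique_by_identity((a, b, a))  # (a, b): `a` and `b` are not ===
unique([a, b, a])              # [a]: `a` and `b` are isequal
```

Because the number of surviving elements depends on runtime identity checks, the length of the returned tuple generally cannot be determined from the argument types alone, which is likely one reason this step resists type stabilization.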

@glwagner (Member Author) commented Nov 2, 2023

Do you know how to implement the "ntuple trick", @jlk9? I can help.
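The thread does not spell out what the "ntuple trick" is. One common reading (an assumption here) is to replace a generator such as `Tuple(tuplify(ai) for ai in a)` with `ntuple(i -> f(t[i]), Val(N))`, where the length `N` comes from the argument's type. A minimal sketch with hypothetical helper names:

```julia
# Generator form, as in `Tuple(tuplify(ai) for ai in a)`: the tuple is assembled by
# iterating, so its length is generally not visible to type inference.
tuple_via_generator(f, t::Tuple) = Tuple(f(x) for x in t)

# ntuple form: N is a type parameter of the input, so `Val(N)` is a compile-time
# constant and the output is built with exactly N elements.
tuple_via_ntuple(f, t::NTuple{N, Any}) where N = ntuple(i -> f(t[i]), Val(N))

tuple_via_ntuple(x -> 2x, (1, 2, 3))   # (2, 4, 6)
```

This helps when building a tuple whose length matches its input, but on its own it does not address the deduplication step in `flattened_unique_values`, whose output length depends on runtime values.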

@wsmoses (Collaborator) commented Jan 29, 2024

@glwagner the Enzyme CI appears to pass here, and the remaining failures appear unrelated.

@glwagner merged commit dc34b80 into main Jan 30, 2024
48 checks passed
@glwagner deleted the glw/type-stable-with-tracers branch January 30, 2024 21:09

```
[[deps.Enzyme]]
deps = ["CEnum", "EnzymeCore", "Enzyme_jll", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "ObjectFile", "Preferences", "Printf", "Random"]
git-tree-sha1 = "ae881b2f107e3c01444edbaa0bf3b73f461539f6"
repo-rev = "main"
```

Collaborator:

it's better to depend on a tagged version of a package

Member Author:

Hmm, yeah, let's change that to the latest tagged version.

```
@@ -11,6 +11,7 @@ CubedSphere = "7445602f-e544-4518-8976-18f8e8ae6cdb"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
```

Collaborator:

So now Enzyme is both a normal and a weak dependency?

Member Author:

I don't really understand how that works. Did we do it wrong?

Collaborator:

Well, I'm not sure whether it was accidental or on purpose. But adding it there implies that it's a core dependency of Oceananigans. I'll have a look.

Collaborator:

x-ref #3452


4 participants