Small refactoring of make_mlir_fn + more #1226

Merged: wsmoses merged 12 commits into main from jm/make_mlir_fn on May 3, 2025

jumerckx
Collaborator

@jumerckx jumerckx commented Apr 30, 2025

This PR splits the implementation of make_mlir_fn into four parts:

  1. prepare_mlir_fn_args
  2. process_linear_args
  3. trace the function by calling it
  4. finalize_mlir_fn

This is a prerequisite for automatic function call insertion, since steps 1, 2, and 4 can be reused there. I'd like to merge this now to avoid recurring merge conflicts as I make progress on call insertion.
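To illustrate the shape of this split, here is a minimal, self-contained sketch. Every name and signature in it (`prepare_args`, `process_linear`, `trace_fn`, `finalize_fn`, `make_fn`, `reuse_without_tracing`) is a hypothetical simplification invented for illustration; it is not Reactant's actual code, and the real functions take many more arguments. It only shows how a staged pipeline lets a second caller reuse steps 1, 2, and 4 while replacing step 3.

```julia
# Hypothetical sketch of a four-stage pipeline; not Reactant's implementation.

# 1. Wrap each argument in a record that remembers its position.
prepare_args(args) = [(index = i, value = a) for (i, a) in enumerate(args)]

# 2. Keep only the arguments that would become function inputs
#    (here, as a stand-in rule: the numeric ones).
process_linear(prepared) = [p for p in prepared if p.value isa Number]

# 3. "Trace" the function by simply calling it on the argument values.
trace_fn(f, prepared) = f((p.value for p in prepared)...)

# 4. Package the result together with some bookkeeping.
finalize_fn(result, linear, name) =
    (name = name, nargs = length(linear), result = result)

# The original entry point just chains the four stages.
function make_fn(f, args, name)
    prepared = prepare_args(args)
    linear = process_linear(prepared)
    result = trace_fn(f, prepared)
    return finalize_fn(result, linear, name)
end

# A second caller can reuse stages 1, 2, and 4 and supply its own result
# instead of tracing (an assumption about how such reuse might look).
function reuse_without_tracing(result, args, name)
    prepared = prepare_args(args)
    linear = process_linear(prepared)
    return finalize_fn(result, linear, name)
end

out = make_fn(+, (1, 2), "add")
reused = reuse_without_tracing(out.result, (1, 2), "add_call")
```

Because each stage takes its inputs explicitly, the second caller does not need to copy any logic out of the original entry point.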

Separately, I moved some code around in Ops.call and did a slight refactoring in utils.jl by introducing push_inst!; both changes are also useful for call insertion. If any of these changes turns out to be more contentious than the others, I can split this PR.

@jumerckx jumerckx marked this pull request as draft April 30, 2025 15:17
@jumerckx jumerckx force-pushed the jm/make_mlir_fn branch from 4b379d3 to c1670af Compare May 1, 2025 15:12
@jumerckx jumerckx requested a review from wsmoses May 1, 2025 18:37
@jumerckx jumerckx marked this pull request as ready for review May 1, 2025 18:37
@jumerckx
Collaborator Author

jumerckx commented May 1, 2025

I'm not sure whether the CI failures are related to this PR.

@wsmoses
Member

wsmoses commented May 1, 2025

DUS: Error During Test at /home/runner/work/Reactant.jl/Reactant.jl/test/optimize_comm.jl:64
  Got exception outside of a @test
  UndefVarError: `ctx` not defined
  Stacktrace:
    [1] finalize_mlir_fn(result::Nothing, traced_args::Vector{Any}, linear_args::Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, seen_args::Reactant.OrderedIdDict{Any, Any}, seen_results::Reactant.OrderedIdDict{Any, Any}, fnbody::Reactant.MLIR.IR.Block, func::Reactant.MLIR.IR.Operation, mod::Reactant.MLIR.IR.Module, name::String, in_tys::Vector{Reactant.MLIR.IR.Type}, do_transpose::Bool, optimize_then_pad::Bool, inv_map::IdDict{Any, Any}, args_in_result::Symbol, resprefix::Symbol, argprefix::Symbol, resargprefix::Symbol, verify_arg_names::Nothing, return_dialect::Symbol, traced_args_to_shardings::Reactant.OrderedIdDict{Any, Any}, output_shardings::Nothing, sym_visibility::Nothing, num_replicas::Int64, runtime::Val{:PJRT}, construct_function_without_args::Bool, args::Tuple{ConcretePJRTArray{Int64, 2, 8, Reactant.Sharding.ShardInfo{Reactant.Sharding.NamedSharding{2, Reactant.Sharding.Mesh{2, Vector{Int64}}}, Vector{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}, ConcretePJRTArray{Int64, 2, 8, Reactant.Sharding.ShardInfo{Reactant.Sharding.NamedSharding{2, Reactant.Sharding.Mesh{2, Vector{Int64}}}, Vector{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}}, N::Int64, concretein::Bool, toscalar::Bool)
      @ Reactant.TracedUtils ~/work/Reactant.jl/Reactant.jl/src/TracedUtils.jl:801
    [2] make_mlir_fn(f::typeof(Main.var"##Comm Optimization#243".dus), args::Tuple{ConcretePJRTArray{Int64, 2, 8, Reactant.Sharding.ShardInfo{Reactant.Sharding.NamedSharding{2, Reactant.Sharding.Mesh{2, Vector{Int64}}}, Vector{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}, ConcretePJRTArray{Int64, 2, 8, Reactant.Sharding.ShardInfo{Reactant.Sharding.NamedSharding{2, Reactant.Sharding.Mesh{2, Vector{Int64}}}, Vector{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{:PJRT}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
      @ Reactant.TracedUtils ~/work/Reactant.jl/Reactant.jl/src/TracedUtils.jl:332

seems related?
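One common way an `UndefVarError` like the one above can appear after splitting a large function is that a helper still refers to a local variable of the original function. The sketch below reproduces that pattern in isolation. All function names are hypothetical, and this is only an assumption about the kind of mistake involved, not a reconstruction of the actual bug or its fix.

```julia
# Hypothetical sketch; not Reactant's code.

# Before the split: `ctx` is a local variable of one large function.
function build_all(name)
    ctx = "context-for-" * name
    body = "body(" * name * ")"
    return ctx * "/" * body
end

# After a careless split: the helper still mentions `ctx`, which is no
# longer a local. Julia treats it as a global and fails at run time.
function finalize_broken(body)
    return ctx * "/" * body
end

# One possible fix: pass the value explicitly.
finalize_fixed(ctx, body) = ctx * "/" * body

function build_split(name)
    ctx = "context-for-" * name
    body = "body(" * name * ")"
    return finalize_fixed(ctx, body)
end

# Capture the error instead of letting it propagate.
err = try
    finalize_broken("body(f)")
    nothing
catch e
    e
end
```

Here `err` is an `UndefVarError` for `:ctx`, while `build_split` returns the same string as `build_all`. Such errors only surface when the affected code path actually runs, which is why a test can catch them even though the code loads without complaint.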

jumerckx and others added 2 commits May 1, 2025 21:22
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@jumerckx
Collaborator Author

jumerckx commented May 2, 2025

> seems related?

Right, sorry. This is now fixed.

There's also a Buildkite failure; I'm not sure whether it's related, and I can't seem to restart that job myself.

Error During Test at /var/lib/buildkite-agent/builds/gpuci-14/julialang/reactant-dot-jl/test/integration/cuda.jl:129
  Got exception outside of a @test
  Out of GPU memory trying to allocate 8 bytes
  Effective GPU memory usage: 77.14% (3.664 GiB/4.750 GiB)
  Memory pool usage: 0 bytes (0 bytes reserved)
  Stacktrace:
...

@wsmoses wsmoses merged commit c0aa16d into main May 3, 2025
56 checks passed
@wsmoses wsmoses deleted the jm/make_mlir_fn branch May 3, 2025 03:43