Skip to content

Small refactoring of make_mlir_fn + more #1226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open

Conversation

jumerckx
Copy link
Collaborator

@jumerckx jumerckx commented Apr 30, 2025

This pr splits the implementation of make_mlir_fn into multiple parts:

  1. prepare_mlir_fn_args
  2. process_linear_args
  3. trace the function by calling it
  4. finalize_mlir_fn

This is a prerequisite for doing automatic function call insertion as 1., 2., and 4. can be repurposed there. I'd like to merge this already in order to prevent constant merge conflicts as I make progress on the call insertions.

Unrelated, I moved code in Ops.call around, and did slight refactoring in utils.jl by introducing push_inst!, both also useful for the call insertion. If any of these changes turn out to be more contentious than the others, I can split this pr as well.

@jumerckx jumerckx marked this pull request as draft April 30, 2025 15:17
@jumerckx jumerckx force-pushed the jm/make_mlir_fn branch from 4b379d3 to c1670af Compare May 1, 2025 15:12
@jumerckx jumerckx requested a review from wsmoses May 1, 2025 18:37
@jumerckx jumerckx marked this pull request as ready for review May 1, 2025 18:37
@jumerckx
Copy link
Collaborator Author

jumerckx commented May 1, 2025

I'm not sure if CI failures are related to this pr?

@wsmoses
Copy link
Member

wsmoses commented May 1, 2025

DUS: Error During Test at /home/runner/work/Reactant.jl/Reactant.jl/test/optimize_comm.jl:64
  Got exception outside of a @test
  UndefVarError: `ctx` not defined
  Stacktrace:
    [1] finalize_mlir_fn(result::Nothing, traced_args::Vector{Any}, linear_args::Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, seen_args::Reactant.OrderedIdDict{Any, Any}, seen_results::Reactant.OrderedIdDict{Any, Any}, fnbody::Reactant.MLIR.IR.Block, func::Reactant.MLIR.IR.Operation, mod::Reactant.MLIR.IR.Module, name::String, in_tys::Vector{Reactant.MLIR.IR.Type}, do_transpose::Bool, optimize_then_pad::Bool, inv_map::IdDict{Any, Any}, args_in_result::Symbol, resprefix::Symbol, argprefix::Symbol, resargprefix::Symbol, verify_arg_names::Nothing, return_dialect::Symbol, traced_args_to_shardings::Reactant.OrderedIdDict{Any, Any}, output_shardings::Nothing, sym_visibility::Nothing, num_replicas::Int64, runtime::Val{:PJRT}, construct_function_without_args::Bool, args::Tuple{ConcretePJRTArray{Int64, 2, 8, Reactant.Sharding.ShardInfo{Reactant.Sharding.NamedSharding{2, Reactant.Sharding.Mesh{2, Vector{Int64}}}, Vector{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}, ConcretePJRTArray{Int64, 2, 8, Reactant.Sharding.ShardInfo{Reactant.Sharding.NamedSharding{2, Reactant.Sharding.Mesh{2, Vector{Int64}}}, Vector{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}}, N::Int64, concretein::Bool, toscalar::Bool)
      @ Reactant.TracedUtils ~/work/Reactant.jl/Reactant.jl/src/TracedUtils.jl:801
    [2] make_mlir_fn(f::typeof(Main.var"##Comm Optimization#243".dus), args::Tuple{ConcretePJRTArray{Int64, 2, 8, Reactant.Sharding.ShardInfo{Reactant.Sharding.NamedSharding{2, Reactant.Sharding.Mesh{2, Vector{Int64}}}, Vector{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}, ConcretePJRTArray{Int64, 2, 8, Reactant.Sharding.ShardInfo{Reactant.Sharding.NamedSharding{2, Reactant.Sharding.Mesh{2, Vector{Int64}}}, Vector{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{:PJRT}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
      @ Reactant.TracedUtils ~/work/Reactant.jl/Reactant.jl/src/TracedUtils.jl:332
      ```
      
      seems related?

jumerckx and others added 2 commits May 1, 2025 21:22
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants