From 5ace26fd838a8cf0a2dbe7d9798c263083c1d81e Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 29 Apr 2025 09:52:18 +0200 Subject: [PATCH 1/9] factor out first half of `make_mlir_fn` --- src/TracedUtils.jl | 247 +++++++++++++++++++++++++++------------------ 1 file changed, 151 insertions(+), 96 deletions(-) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index b14cba95f6..620dd2f91e 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -280,106 +280,35 @@ function make_mlir_fn( return mlir_fn_res end - N = length(args) seen_args = OrderedIdDict() - traced_args = Vector{Any}(undef, N) - inmode = if concretein - @assert !toscalar - Reactant.ConcreteToTraced - else - Reactant.TracedSetPath - end - for i in 1:N - @inbounds traced_args[i] = Reactant.make_tracer( - seen_args, args[i], (argprefix, i), inmode; toscalar, runtime - ) - end - - linear_args = Reactant.TracedType[] - inv_map = IdDict() - for (k, v) in seen_args - v isa Reactant.TracedType || continue - push!(linear_args, v) - inv_map[v] = k - end - - in_tys = Vector{MLIR.IR.Type}(undef, length(linear_args)) - for (i, arg) in enumerate(linear_args) - elT = MLIR.IR.Type(Reactant.unwrapped_eltype(arg)) - if toscalar - in_tys[i] = MLIR.IR.TensorType(Int[], elT) - else - sz = collect(Int, size(arg)) - if !optimize_then_pad - carg = inv_map[arg] - Reactant.has_padding(carg) && (sz .+= Reactant.get_padding(carg)) - end - - typ = MLIR.IR.TensorType(sz, elT) - do_transpose && (typ = transpose_ty(typ)) - in_tys[i] = typ - end - end - - sym_visibility = nothing - if !concretein - sym_visibility = MLIR.IR.Attribute("private") - end - - ctx = MLIR.IR.context() - mod = MLIR.IR.mmodule() - - # Insert meshes for the sharded arguments - traced_args_to_shardings = OrderedIdDict() - for (k, v) in seen_args - if k isa Reactant.AbstractConcreteNumber || k isa Reactant.AbstractConcreteArray - if Reactant.Sharding.is_sharded(k) - Reactant.Ops.mesh(k.sharding.mesh) - traced_args_to_shardings[v] = k.sharding - elseif input_shardings !== nothing && haskey(input_shardings, k) - Reactant.Ops.mesh(input_shardings[k].mesh) - traced_args_to_shardings[v] = input_shardings[k] - end - end - end - func = MLIR.IR.block!(MLIR.IR.body(mod)) do - return MLIR.Dialects.func.func_(; - sym_name=name * "_tmp", - function_type=MLIR.IR.FunctionType(in_tys, Vector{MLIR.IR.Type}(undef, 0)), - body=MLIR.IR.Region(), - ) - end + ( + N, + traced_args, + linear_args, + inv_map, + in_tys, + sym_visibility, + ctx, + mod, + traced_args_to_shardings, + func, + fnbody + ) = prepare_mlir_fn_args( + args, + name, + seen_args, + concretein, + toscalar, + argprefix, + runtime, + optimize_then_pad, + do_transpose, + input_shardings, + verify_arg_names + ) - arglocs = MLIR.IR.Location[] - for arg in linear_args - path = get_idx(arg, argprefix) - stridx = if verify_arg_names isa Nothing - "arg" * string(path[2]) - else - string(verify_arg_names.args[path[2]]) - end - aval = args[path[2]] - for (cidx, idx) in enumerate(path[3:end]) - if aval isa Array || aval isa Dict - aval = getindex(aval, idx) - stridx = stridx * "[" * string(idx) * "]" - else - fldname = if idx isa Integer - string(fieldname(Core.Typeof(aval), idx)) - else - string(idx) - end - stridx *= "." * fldname - aval = getfield(aval, idx) - end - end - push!(arglocs, MLIR.IR.Location(stridx * " (path=$path)", MLIR.IR.Location())) - end - fnbody = MLIR.IR.Block(in_tys, arglocs) - push!(MLIR.IR.region(func, 1), fnbody) Ops.activate_constant_context!(fnbody) - @assert MLIR.IR._has_block() # Explicitly don't use block! to avoid creating a closure, which creates @@ -766,6 +695,132 @@ function make_mlir_fn( ) end +function prepare_mlir_fn_args( + args, + name, + seen_args, + concretein, + toscalar, + argprefix, + runtime, + optimize_then_pad, + do_transpose, + input_shardings, + verify_arg_names +) + N = length(args) + traced_args = Vector{Any}(undef, N) + inmode = if concretein + @assert !toscalar + Reactant.ConcreteToTraced + else + Reactant.TracedSetPath + end + for i in 1:N + @inbounds traced_args[i] = Reactant.make_tracer( + seen_args, args[i], (argprefix, i), inmode; toscalar, runtime + ) + end + + linear_args = Reactant.TracedType[] + inv_map = IdDict() + for (k, v) in seen_args + v isa Reactant.TracedType || continue + push!(linear_args, v) + inv_map[v] = k + end + + in_tys = Vector{MLIR.IR.Type}(undef, length(linear_args)) + for (i, arg) in enumerate(linear_args) + elT = MLIR.IR.Type(Reactant.unwrapped_eltype(arg)) + if toscalar + in_tys[i] = MLIR.IR.TensorType(Int[], elT) + else + sz = collect(Int, size(arg)) + if !optimize_then_pad + carg = inv_map[arg] + Reactant.has_padding(carg) && (sz .+= Reactant.get_padding(carg)) + end + + typ = MLIR.IR.TensorType(sz, elT) + do_transpose && (typ = transpose_ty(typ)) + in_tys[i] = typ + end + end + + sym_visibility = nothing + if !concretein + sym_visibility = MLIR.IR.Attribute("private") + end + + ctx = MLIR.IR.context() + mod = MLIR.IR.mmodule() + + # Insert meshes for the sharded arguments + traced_args_to_shardings = OrderedIdDict() + for (k, v) in seen_args + if k isa Reactant.AbstractConcreteNumber || k isa Reactant.AbstractConcreteArray + if Reactant.Sharding.is_sharded(k) + Reactant.Ops.mesh(k.sharding.mesh) + traced_args_to_shardings[v] = k.sharding + elseif input_shardings !== nothing && haskey(input_shardings, k) + Reactant.Ops.mesh(input_shardings[k].mesh) + traced_args_to_shardings[v] = input_shardings[k] + end + end + end + + func = MLIR.IR.block!(MLIR.IR.body(mod)) do + return MLIR.Dialects.func.func_(; + sym_name=name * "_tmp", + function_type=MLIR.IR.FunctionType(in_tys, Vector{MLIR.IR.Type}(undef, 0)), + body=MLIR.IR.Region(), + ) + end + + arglocs = MLIR.IR.Location[] + for arg in linear_args + path = get_idx(arg, argprefix) + stridx = if verify_arg_names isa Nothing + "arg" * string(path[2]) + else + string(verify_arg_names.args[path[2]]) + end + aval = args[path[2]] + for (cidx, idx) in enumerate(path[3:end]) + if aval isa Array || aval isa Dict + aval = getindex(aval, idx) + stridx = stridx * "[" * string(idx) * "]" + else + fldname = if idx isa Integer + string(fieldname(Core.Typeof(aval), idx)) + else + string(idx) + end + stridx *= "." * fldname + aval = getfield(aval, idx) + end + end + push!(arglocs, MLIR.IR.Location(stridx * " (path=$path)", MLIR.IR.Location())) + end + fnbody = MLIR.IR.Block(in_tys, arglocs) + push!(MLIR.IR.region(func, 1), fnbody) + + return ( + N, + traced_args, + linear_args, + inv_map, + in_tys, + sym_visibility, + ctx, + mod, + traced_args_to_shardings, + func, + fnbody + ) +end + function __lookup_unique_name_in_module(mod, name) new_name = name tab = MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)) From 5e78f8df2445bb8640bcbff1b7d9db762a6c1c6f Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 29 Apr 2025 10:14:31 +0200 Subject: [PATCH 2/9] factor out second half of `make_mlir_fn` --- src/TracedUtils.jl | 376 ++++++++++++++++++++++++++++----------------- 1 file changed, 239 insertions(+), 137 deletions(-) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 620dd2f91e..b431cc6e07 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -356,6 +356,243 @@ function make_mlir_fn( end seen_results = OrderedIdDict() + + # Call the second extracted function to finalize + ( + func2, + traced_result, + ret, + linear_args, + in_tys, + linear_results, + num_partitions, + is_sharded, + unique_meshes, + mutated_args, + global_device_ids + ) = finalize_mlir_fn( + result, + traced_args, + linear_args, + seen_args, + seen_results, + fnbody, + func, + mod, + name, + in_tys, + do_transpose, + optimize_then_pad, + inv_map, + args_in_result, + resprefix, + argprefix, + resargprefix, + verify_arg_names, + return_dialect, + traced_args_to_shardings, + output_shardings, + sym_visibility, + num_replicas, + runtime, + construct_function_without_args, + args, + N, + concretein, + toscalar + ) + + return CompiledMlirFnResult( + false, + func2, + traced_result, + result, + seen_args, + ret, + linear_args, + in_tys, + linear_results, + num_partitions, + num_replicas, + is_sharded, + nothing, + nothing, + unique_meshes, + mutated_args, + true, + missing, + global_device_ids, + nothing, # populated later in `compile_mlir!` + ) +end + +function prepare_mlir_fn_args( + args, + name, + seen_args, + concretein, + toscalar, + argprefix, + runtime, + optimize_then_pad, + do_transpose, + input_shardings, + verify_arg_names +) + N = length(args) + traced_args = Vector{Any}(undef, N) + inmode = if concretein + @assert !toscalar + Reactant.ConcreteToTraced + else + Reactant.TracedSetPath + end + for i in 1:N + @inbounds traced_args[i] = Reactant.make_tracer( + seen_args, args[i], (argprefix, i), inmode; toscalar, runtime + ) + end + + linear_args = Reactant.TracedType[] + inv_map = IdDict() + for (k, v) in seen_args + v isa Reactant.TracedType || continue + push!(linear_args, v) + inv_map[v] = k + end + + in_tys = Vector{MLIR.IR.Type}(undef, length(linear_args)) + for (i, arg) in enumerate(linear_args) + elT = MLIR.IR.Type(Reactant.unwrapped_eltype(arg)) + if toscalar + in_tys[i] = MLIR.IR.TensorType(Int[], elT) + else + sz = collect(Int, size(arg)) + if !optimize_then_pad + carg = inv_map[arg] + Reactant.has_padding(carg) && (sz .+= Reactant.get_padding(carg)) + end + + typ = MLIR.IR.TensorType(sz, elT) + do_transpose && (typ = transpose_ty(typ)) + in_tys[i] = typ + end + end + + sym_visibility = nothing + if !concretein + sym_visibility = MLIR.IR.Attribute("private") + end + + ctx = MLIR.IR.context() + mod = MLIR.IR.mmodule() + + # Insert meshes for the sharded arguments + traced_args_to_shardings = OrderedIdDict() + for (k, v) in seen_args + if k isa Reactant.AbstractConcreteNumber || k isa Reactant.AbstractConcreteArray + if Reactant.Sharding.is_sharded(k) + Reactant.Ops.mesh(k.sharding.mesh) + traced_args_to_shardings[v] = k.sharding + elseif input_shardings !== nothing && haskey(input_shardings, k) + Reactant.Ops.mesh(input_shardings[k].mesh) + traced_args_to_shardings[v] = input_shardings[k] + end + end + end + + func = MLIR.IR.block!(MLIR.IR.body(mod)) do + return MLIR.Dialects.func.func_(; + sym_name=name * "_tmp", + function_type=MLIR.IR.FunctionType(in_tys, Vector{MLIR.IR.Type}(undef, 0)), + body=MLIR.IR.Region(), + ) + end + + arglocs = MLIR.IR.Location[] + for arg in linear_args + path = get_idx(arg, argprefix) + stridx = if verify_arg_names isa Nothing + "arg" * string(path[2]) + else + string(verify_arg_names.args[path[2]]) + end + aval = args[path[2]] + for (cidx, idx) in enumerate(path[3:end]) + if aval isa Array || aval isa Dict + aval = getindex(aval, idx) + stridx = stridx * "[" * string(idx) * "]" + else + fldname = if idx isa Integer + string(fieldname(Core.Typeof(aval), idx)) + else + string(idx) + end + stridx *= "." * fldname + aval = getfield(aval, idx) + end + end + push!(arglocs, MLIR.IR.Location(stridx * " (path=$path)", MLIR.IR.Location())) + end + fnbody = MLIR.IR.Block(in_tys, arglocs) + push!(MLIR.IR.region(func, 1), fnbody) + + return ( + N, + traced_args, + linear_args, + inv_map, + in_tys, + sym_visibility, + ctx, + mod, + traced_args_to_shardings, + func, + fnbody + ) +end + +function finalize_mlir_fn( + result, + traced_args, + linear_args, + seen_args, + seen_results, + fnbody, + func, + mod, + name, + in_tys, + do_transpose, + optimize_then_pad, + inv_map, + args_in_result, + resprefix, + argprefix, + resargprefix, + verify_arg_names, + return_dialect, + traced_args_to_shardings, + output_shardings, + sym_visibility, + num_replicas, + runtime, + construct_function_without_args, + args, + N, + concretein, + toscalar +) + # check which arguments have been mutated + mutated_args = Int[] + if !construct_function_without_args + for (i, arg) in enumerate(linear_args) + if get_mlir_data(arg) != MLIR.IR.argument(fnbody, i) + # mutation occured! + push!(mutated_args, i) + end + end + end outmode = if concretein @assert !toscalar @@ -671,153 +908,18 @@ function make_mlir_fn( MLIR.API.mlirOperationDestroy(func.operation) func.operation = MLIR.API.MlirOperation(C_NULL) - return CompiledMlirFnResult( - false, + return ( func2, traced_result, - result, - seen_args, ret, linear_args, in_tys, linear_results, num_partitions, - num_replicas, is_sharded, - nothing, - nothing, unique_meshes, mutated_args, - true, - missing, - global_device_ids, - nothing, # populated later in `compile_mlir!` - ) -end - -function prepare_mlir_fn_args( - args, - name, - seen_args, - concretein, - toscalar, - argprefix, - runtime, - optimize_then_pad, - do_transpose, - input_shardings, - verify_arg_names -) - N = length(args) - traced_args = Vector{Any}(undef, N) - inmode = if concretein - @assert !toscalar - Reactant.ConcreteToTraced - else - Reactant.TracedSetPath - end - for i in 1:N - @inbounds traced_args[i] = Reactant.make_tracer( - seen_args, args[i], (argprefix, i), inmode; toscalar, runtime - ) - end - - linear_args = Reactant.TracedType[] - inv_map = IdDict() - for (k, v) in seen_args - v isa Reactant.TracedType || continue - push!(linear_args, v) - inv_map[v] = k - end - - in_tys = Vector{MLIR.IR.Type}(undef, length(linear_args)) - for (i, arg) in enumerate(linear_args) - elT = MLIR.IR.Type(Reactant.unwrapped_eltype(arg)) - if toscalar - in_tys[i] = MLIR.IR.TensorType(Int[], elT) - else - sz = collect(Int, size(arg)) - if !optimize_then_pad - carg = inv_map[arg] - Reactant.has_padding(carg) && (sz .+= Reactant.get_padding(carg)) - end - - typ = MLIR.IR.TensorType(sz, elT) - do_transpose && (typ = transpose_ty(typ)) - in_tys[i] = typ - end - end - - sym_visibility = nothing - if !concretein - sym_visibility = MLIR.IR.Attribute("private") - end - - ctx = MLIR.IR.context() - mod = MLIR.IR.mmodule() - - # Insert meshes for the sharded arguments - traced_args_to_shardings = OrderedIdDict() - for (k, v) in seen_args - if k isa Reactant.AbstractConcreteNumber || k isa Reactant.AbstractConcreteArray - if Reactant.Sharding.is_sharded(k) - Reactant.Ops.mesh(k.sharding.mesh) - traced_args_to_shardings[v] = k.sharding - elseif input_shardings !== nothing && haskey(input_shardings, k) - Reactant.Ops.mesh(input_shardings[k].mesh) - traced_args_to_shardings[v] = input_shardings[k] - end - end - end - - func = MLIR.IR.block!(MLIR.IR.body(mod)) do - return MLIR.Dialects.func.func_(; - sym_name=name * "_tmp", - function_type=MLIR.IR.FunctionType(in_tys, Vector{MLIR.IR.Type}(undef, 0)), - body=MLIR.IR.Region(), - ) - end - - arglocs = MLIR.IR.Location[] - for arg in linear_args - path = get_idx(arg, argprefix) - stridx = if verify_arg_names isa Nothing - "arg" * string(path[2]) - else - string(verify_arg_names.args[path[2]]) - end - aval = args[path[2]] - for (cidx, idx) in enumerate(path[3:end]) - if aval isa Array || aval isa Dict - aval = getindex(aval, idx) - stridx = stridx * "[" * string(idx) * "]" - else - fldname = if idx isa Integer - string(fieldname(Core.Typeof(aval), idx)) - else - string(idx) - end - stridx *= "." * fldname - aval = getfield(aval, idx) - end - end - push!(arglocs, MLIR.IR.Location(stridx * " (path=$path)", MLIR.IR.Location())) - end - fnbody = MLIR.IR.Block(in_tys, arglocs) - push!(MLIR.IR.region(func, 1), fnbody) - - return ( - N, - traced_args, - linear_args, - inv_map, - in_tys, - sym_visibility, - ctx, - mod, - traced_args_to_shardings, - func, - fnbody + global_device_ids ) end From ca8a8c061b8e2d6023f913ac6b4b9cc8644886f3 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 29 Apr 2025 10:30:56 +0200 Subject: [PATCH 3/9] `push_inst!` --- src/utils.jl | 34 +++++++++++++--------------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index b7fbc99bff..22ef783df6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -637,6 +637,10 @@ function call_with_reactant_generator( # the end of the pass, we'll reset `code_info` fields accordingly. overdubbed_code = Any[] overdubbed_codelocs = Int32[] + function push_inst!(inst) + push!(overdubbed_code, inst) + push!(overdubbed_codelocs, code_info.codelocs[1]) + end # Rewire the arguments from our tuple input of fn and args, to the corresponding calling convention # required by the base method. @@ -662,15 +666,13 @@ function call_with_reactant_generator( actual_argument = Expr( :call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset ) - push!(overdubbed_code, actual_argument) - push!(overdubbed_codelocs, code_info.codelocs[1]) + push_inst!(actual_argument) offset += 1 push!(fn_args, Core.SSAValue(length(overdubbed_code))) push!(tys, redub_arguments[i + (guaranteed_error ? 1 : 0)]) if DEBUG_INTERP[] - push!( - overdubbed_code, + push_inst!( Expr( :call, safe_print, @@ -678,7 +680,6 @@ function call_with_reactant_generator( fn_args[end], ), ) - push!(overdubbed_codelocs, code_info.codelocs[1]) end end @@ -687,17 +688,14 @@ function call_with_reactant_generator( if method.isva trailing_arguments = Expr(:call, Core.GlobalRef(Core, :tuple)) for i in n_method_args:n_actual_args - push!( - overdubbed_code, + push_inst( Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset), ) - push!(overdubbed_codelocs, code_info.codelocs[1]) push!(trailing_arguments.args, Core.SSAValue(length(overdubbed_code))) offset += 1 end - push!(overdubbed_code, trailing_arguments) - push!(overdubbed_codelocs, code_info.codelocs[1]) + push_inst!(trailing_arguments) push!(fn_args, Core.SSAValue(length(overdubbed_code))) push!( tys, @@ -707,8 +705,7 @@ function call_with_reactant_generator( ) if DEBUG_INTERP[] - push!( - overdubbed_code, + push_inst!( Expr( :call, safe_print, @@ -716,7 +713,6 @@ function call_with_reactant_generator( fn_args[end], ), ) - push!(overdubbed_codelocs, code_info.codelocs[1]) end end @@ -751,23 +747,19 @@ function call_with_reactant_generator( else farg = fn_args[1] rep = Expr(:call, make_oc, dict, octup, rt, src, ocnargs, ocva, farg) - push!(overdubbed_code, rep) - push!(overdubbed_codelocs, code_info.codelocs[1]) + push_inst!(rep) Core.SSAValue(length(overdubbed_code)) end - push!(overdubbed_code, Expr(:call, oc, fn_args[2:end]...)) - push!(overdubbed_codelocs, code_info.codelocs[1]) + push_inst!(Expr(:call, oc, fn_args[2:end]...)) ocres = Core.SSAValue(length(overdubbed_code)) if DEBUG_INTERP[] - push!(overdubbed_code, Expr(:call, safe_print, "ocres", ocres)) - push!(overdubbed_codelocs, code_info.codelocs[1]) + push_inst!(Expr(:call, safe_print, "ocres", ocres)) end - push!(overdubbed_code, Core.ReturnNode(ocres)) - push!(overdubbed_codelocs, code_info.codelocs[1]) + push_inst!(Core.ReturnNode(ocres)) #=== set `code_info`/`reflection` fields accordingly ===# From 85593fe266f94f6272798acd40cffd3fe45533e4 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 29 Apr 2025 11:44:23 +0200 Subject: [PATCH 4/9] move code around in `Ops.call` --- src/Ops.jl | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index dbc85fc716..86b452bb1c 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -2353,24 +2353,6 @@ end end @noinline function call(f, args...) - seen_cache = Reactant.OrderedIdDict() - Reactant.make_tracer( - seen_cache, - args, - (), # we have to insert something here, but we remove it immediately below. - Reactant.TracedTrack; - toscalar=false, - ) - linear_args = [] - mlir_caller_args = Reactant.MLIR.IR.Value[] - for (k, v) in seen_cache - v isa Reactant.TracedType || continue - push!(linear_args, v) - push!(mlir_caller_args, v.mlir_data) - # make tracer inserted `()` into the path, here we remove it: - v.paths = v.paths[1:(end - 1)] - end - seen = Dict() cache_key = [] Reactant.make_tracer(seen, (f, args...), cache_key, Reactant.TracedToTypes) @@ -2414,6 +2396,24 @@ end ) end + seen_cache = Reactant.OrderedIdDict() + Reactant.make_tracer( + seen_cache, + args, + (), # we have to insert something here, but we remove it immediately below. + Reactant.TracedTrack; + toscalar=false, + ) + linear_args = [] + mlir_caller_args = Reactant.MLIR.IR.Value[] + for (k, v) in seen_cache + v isa Reactant.TracedType || continue + push!(linear_args, v) + push!(mlir_caller_args, v.mlir_data) + # make tracer inserted `()` into the path, here we remove it: + v.paths = v.paths[1:(end - 1)] + end + call_op = MLIR.Dialects.func.call( mlir_caller_args; result_0=mlir_result_types, From 32903c2b9a0b5350b987f3f677f2a68686dda22b Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 29 Apr 2025 16:43:16 +0200 Subject: [PATCH 5/9] factor out `process_linear_args!` from `make_mlir_fn`. --- src/TracedUtils.jl | 43 +++++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index b431cc6e07..261267dc6f 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -282,7 +282,7 @@ function make_mlir_fn( seen_args = OrderedIdDict() - ( + (; N, traced_args, linear_args, @@ -316,23 +316,7 @@ function make_mlir_fn( MLIR.IR.activate!(fnbody) result = try - for (i, arg) in enumerate(linear_args) - raw_arg = MLIR.IR.argument(fnbody, i) - row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg - if !optimize_then_pad - carg = inv_map[arg] - if Reactant.has_padding(carg) - padding = Reactant.get_padding(carg) - sz = size(carg) .+ padding - if !do_transpose - padding = reverse(padding) - sz = reverse(sz) - end - row_maj_arg = MLIR.IR.result(unpad_val_op(row_maj_arg, padding, sz), 1) - end - end - set_mlir_data!(arg, row_maj_arg) - end + process_linear_args!(linear_args, fnbody, do_transpose, optimize_then_pad, inv_map) if isempty(kwargs) Reactant.call_with_reactant(f, traced_args...) @@ -357,7 +341,6 @@ function make_mlir_fn( seen_results = OrderedIdDict() - # Call the second extracted function to finalize ( func2, traced_result, @@ -537,7 +520,7 @@ function prepare_mlir_fn_args( fnbody = MLIR.IR.Block(in_tys, arglocs) push!(MLIR.IR.region(func, 1), fnbody) - return ( + return (; N, traced_args, linear_args, @@ -552,6 +535,26 @@ function prepare_mlir_fn_args( ) end +function process_linear_args!(linear_args, fnbody, do_transpose, optimize_then_pad, inv_map) + for (i, arg) in enumerate(linear_args) + raw_arg = MLIR.IR.argument(fnbody, i) + row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg + if !optimize_then_pad + carg = inv_map[arg] + if Reactant.has_padding(carg) + padding = Reactant.get_padding(carg) + sz = size(carg) .+ padding + if !do_transpose + padding = reverse(padding) + sz = reverse(sz) + end + row_maj_arg = MLIR.IR.result(unpad_val_op(row_maj_arg, padding, sz), 1) + end + end + set_mlir_data!(arg, row_maj_arg) + end +end + function finalize_mlir_fn( result, traced_args, From c1670afc237d614db26bc0bdfcbc82e67acc5272 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 29 Apr 2025 16:45:13 +0200 Subject: [PATCH 6/9] further cleanup using `push_inst!` --- src/utils.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 22ef783df6..d715b7cc09 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -640,6 +640,7 @@ function call_with_reactant_generator( function push_inst!(inst) push!(overdubbed_code, inst) push!(overdubbed_codelocs, code_info.codelocs[1]) + return Core.SSAValue(length(overdubbed_code)) end # Rewire the arguments from our tuple input of fn and args, to the corresponding calling convention # required by the base method. @@ -666,9 +667,9 @@ function call_with_reactant_generator( actual_argument = Expr( :call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset ) - push_inst!(actual_argument) + arg = push_inst!(actual_argument) offset += 1 - push!(fn_args, Core.SSAValue(length(overdubbed_code))) + push!(fn_args, arg) push!(tys, redub_arguments[i + (guaranteed_error ? 1 : 0)]) if DEBUG_INTERP[] @@ -688,15 +689,14 @@ function call_with_reactant_generator( if method.isva trailing_arguments = Expr(:call, Core.GlobalRef(Core, :tuple)) for i in n_method_args:n_actual_args - push_inst( + arg = push_inst!( Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset), ) - push!(trailing_arguments.args, Core.SSAValue(length(overdubbed_code))) + push!(trailing_arguments.args, arg) offset += 1 end - push_inst!(trailing_arguments) - push!(fn_args, Core.SSAValue(length(overdubbed_code))) + push!(fn_args, push_inst!(trailing_arguments)) push!( tys, Tuple{ From 2cb8083c0136f6b624440ecfdedd5ff5e289b44a Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Thu, 1 May 2025 17:13:56 +0200 Subject: [PATCH 7/9] formatting --- src/TracedUtils.jl | 44 ++++++++++---------------------------------- src/utils.jl | 2 +- 2 files changed, 11 insertions(+), 35 deletions(-) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 261267dc6f..ee8a36954c 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -282,19 +282,7 @@ function make_mlir_fn( seen_args = OrderedIdDict() - (; - N, - traced_args, - linear_args, - inv_map, - in_tys, - sym_visibility, - ctx, - mod, - traced_args_to_shardings, - func, - fnbody - ) = prepare_mlir_fn_args( + (; N, traced_args, linear_args, inv_map, in_tys, sym_visibility, ctx, mod, traced_args_to_shardings, func, fnbody) = prepare_mlir_fn_args( args, name, seen_args, @@ -305,7 +293,7 @@ function make_mlir_fn( optimize_then_pad, do_transpose, input_shardings, - verify_arg_names + verify_arg_names, ) Ops.activate_constant_context!(fnbody) @@ -340,20 +328,8 @@ function make_mlir_fn( end seen_results = OrderedIdDict() - - ( - func2, - traced_result, - ret, - linear_args, - in_tys, - linear_results, - num_partitions, - is_sharded, - unique_meshes, - mutated_args, - global_device_ids - ) = finalize_mlir_fn( + + (func2, traced_result, ret, linear_args, in_tys, linear_results, num_partitions, is_sharded, unique_meshes, mutated_args, global_device_ids) = finalize_mlir_fn( result, traced_args, linear_args, @@ -382,7 +358,7 @@ function make_mlir_fn( args, N, concretein, - toscalar + toscalar, ) return CompiledMlirFnResult( @@ -420,7 +396,7 @@ function prepare_mlir_fn_args( optimize_then_pad, do_transpose, input_shardings, - verify_arg_names + verify_arg_names, ) N = length(args) traced_args = Vector{Any}(undef, N) @@ -519,7 +495,7 @@ function prepare_mlir_fn_args( end fnbody = MLIR.IR.Block(in_tys, arglocs) push!(MLIR.IR.region(func, 1), fnbody) - + return (; N, traced_args, @@ -531,7 +507,7 @@ function prepare_mlir_fn_args( mod, traced_args_to_shardings, func, - fnbody + fnbody, ) end @@ -584,7 +560,7 @@ function finalize_mlir_fn( args, N, concretein, - toscalar + toscalar, ) # check which arguments have been mutated mutated_args = Int[] @@ -922,7 +898,7 @@ function finalize_mlir_fn( is_sharded, unique_meshes, mutated_args, - global_device_ids + global_device_ids, ) end diff --git a/src/utils.jl b/src/utils.jl index d715b7cc09..245b89efae 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -690,7 +690,7 @@ function call_with_reactant_generator( trailing_arguments = Expr(:call, Core.GlobalRef(Core, :tuple)) for i in n_method_args:n_actual_args arg = push_inst!( - Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset), + Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset) ) push!(trailing_arguments.args, arg) offset += 1 From 0e29c74a043492f299f2e072dc0422416f0194bf Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Thu, 1 May 2025 21:19:26 +0200 Subject: [PATCH 8/9] fix --- src/TracedUtils.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index ee8a36954c..be34acdcd5 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -282,7 +282,7 @@ function make_mlir_fn( seen_args = OrderedIdDict() - (; N, traced_args, linear_args, inv_map, in_tys, sym_visibility, ctx, mod, traced_args_to_shardings, func, fnbody) = prepare_mlir_fn_args( + (; N, traced_args, linear_args, inv_map, in_tys, sym_visibility, mod, traced_args_to_shardings, func, fnbody) = prepare_mlir_fn_args( args, name, seen_args, @@ -443,7 +443,6 @@ function prepare_mlir_fn_args( sym_visibility = MLIR.IR.Attribute("private") end - ctx = MLIR.IR.context() mod = MLIR.IR.mmodule() # Insert meshes for the sharded arguments @@ -503,7 +502,6 @@ function prepare_mlir_fn_args( inv_map, in_tys, sym_visibility, - ctx, mod, traced_args_to_shardings, func, @@ -789,6 +787,8 @@ function finalize_mlir_fn( end end + ctx = MLIR.IR.context() + # Attach `sdy.sharding` attribute to the argument for (i, arg) in enumerate(linear_args) if haskey(traced_args_to_shardings, arg) From 860b5cacc387b40fb9459f2d203e7cb4c5528baa Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Thu, 1 May 2025 21:22:07 +0200 Subject: [PATCH 9/9] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/TracedUtils.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index be34acdcd5..cd431cb5f7 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -788,7 +788,6 @@ function finalize_mlir_fn( end ctx = MLIR.IR.context() - # Attach `sdy.sharding` attribute to the argument for (i, arg) in enumerate(linear_args) if haskey(traced_args_to_shardings, arg)