Skip to content

Small refactoring of make_mlir_fn + more #1226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
36 changes: 18 additions & 18 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2353,24 +2353,6 @@ end
end

@noinline function call(f, args...)
seen_cache = Reactant.OrderedIdDict()
Reactant.make_tracer(
seen_cache,
args,
(), # we have to insert something here, but we remove it immediately below.
Reactant.TracedTrack;
toscalar=false,
)
linear_args = []
mlir_caller_args = Reactant.MLIR.IR.Value[]
for (k, v) in seen_cache
v isa Reactant.TracedType || continue
push!(linear_args, v)
push!(mlir_caller_args, v.mlir_data)
# `make_tracer` inserted `()` into the path; here we remove it:
v.paths = v.paths[1:(end - 1)]
end

seen = Dict()
cache_key = []
Reactant.make_tracer(seen, (f, args...), cache_key, Reactant.TracedToTypes)
Expand Down Expand Up @@ -2414,6 +2396,24 @@ end
)
end

seen_cache = Reactant.OrderedIdDict()
Reactant.make_tracer(
seen_cache,
args,
(), # we have to insert something here, but we remove it immediately below.
Reactant.TracedTrack;
toscalar=false,
)
linear_args = []
mlir_caller_args = Reactant.MLIR.IR.Value[]
for (k, v) in seen_cache
v isa Reactant.TracedType || continue
push!(linear_args, v)
push!(mlir_caller_args, v.mlir_data)
# `make_tracer` inserted `()` into the path; here we remove it:
v.paths = v.paths[1:(end - 1)]
end

call_op = MLIR.Dialects.func.call(
mlir_caller_args;
result_0=mlir_result_types,
Expand Down
223 changes: 179 additions & 44 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,125 @@ function make_mlir_fn(
return mlir_fn_res
end

N = length(args)
seen_args = OrderedIdDict()

(; N, traced_args, linear_args, inv_map, in_tys, sym_visibility, mod, traced_args_to_shardings, func, fnbody) = prepare_mlir_fn_args(
args,
name,
seen_args,
concretein,
toscalar,
argprefix,
runtime,
optimize_then_pad,
do_transpose,
input_shardings,
verify_arg_names,
)

Ops.activate_constant_context!(fnbody)
@assert MLIR.IR._has_block()

# Explicitly don't use block! to avoid creating a closure, which creates
# both compile-time and relocatability issues
MLIR.IR.activate!(fnbody)

result = try
process_linear_args!(linear_args, fnbody, do_transpose, optimize_then_pad, inv_map)

if isempty(kwargs)
Reactant.call_with_reactant(f, traced_args...)
else
Reactant.call_with_reactant(Core.kwcall, kwargs, f, traced_args...)
end
finally
MLIR.IR.deactivate!(fnbody)
Ops.deactivate_constant_context!(fnbody)
end

# check which arguments have been mutated
mutated_args = Int[]
if !construct_function_without_args
for (i, arg) in enumerate(linear_args)
if get_mlir_data(arg) != MLIR.IR.argument(fnbody, i)
# mutation occurred!
push!(mutated_args, i)
end
end
end

seen_results = OrderedIdDict()

(func2, traced_result, ret, linear_args, in_tys, linear_results, num_partitions, is_sharded, unique_meshes, mutated_args, global_device_ids) = finalize_mlir_fn(
result,
traced_args,
linear_args,
seen_args,
seen_results,
fnbody,
func,
mod,
name,
in_tys,
do_transpose,
optimize_then_pad,
inv_map,
args_in_result,
resprefix,
argprefix,
resargprefix,
verify_arg_names,
return_dialect,
traced_args_to_shardings,
output_shardings,
sym_visibility,
num_replicas,
runtime,
construct_function_without_args,
args,
N,
concretein,
toscalar,
)

return CompiledMlirFnResult(
false,
func2,
traced_result,
result,
seen_args,
ret,
linear_args,
in_tys,
linear_results,
num_partitions,
num_replicas,
is_sharded,
nothing,
nothing,
unique_meshes,
mutated_args,
true,
missing,
global_device_ids,
nothing, # populated later in `compile_mlir!`
)
end

function prepare_mlir_fn_args(
args,
name,
seen_args,
concretein,
toscalar,
argprefix,
runtime,
optimize_then_pad,
do_transpose,
input_shardings,
verify_arg_names,
)
N = length(args)
traced_args = Vector{Any}(undef, N)
inmode = if concretein
@assert !toscalar
Expand Down Expand Up @@ -326,7 +443,6 @@ function make_mlir_fn(
sym_visibility = MLIR.IR.Attribute("private")
end

ctx = MLIR.IR.context()
mod = MLIR.IR.mmodule()

# Insert meshes for the sharded arguments
Expand Down Expand Up @@ -378,43 +494,72 @@ function make_mlir_fn(
end
fnbody = MLIR.IR.Block(in_tys, arglocs)
push!(MLIR.IR.region(func, 1), fnbody)
Ops.activate_constant_context!(fnbody)

@assert MLIR.IR._has_block()

# Explicitly don't use block! to avoid creating a closure, which creates
# both compile-time and relocatability issues
MLIR.IR.activate!(fnbody)
return (;
N,
traced_args,
linear_args,
inv_map,
in_tys,
sym_visibility,
mod,
traced_args_to_shardings,
func,
fnbody,
)
end

result = try
for (i, arg) in enumerate(linear_args)
raw_arg = MLIR.IR.argument(fnbody, i)
row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg
if !optimize_then_pad
carg = inv_map[arg]
if Reactant.has_padding(carg)
padding = Reactant.get_padding(carg)
sz = size(carg) .+ padding
if !do_transpose
padding = reverse(padding)
sz = reverse(sz)
end
row_maj_arg = MLIR.IR.result(unpad_val_op(row_maj_arg, padding, sz), 1)
function process_linear_args!(linear_args, fnbody, do_transpose, optimize_then_pad, inv_map)
for (i, arg) in enumerate(linear_args)
raw_arg = MLIR.IR.argument(fnbody, i)
row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg
if !optimize_then_pad
carg = inv_map[arg]
if Reactant.has_padding(carg)
padding = Reactant.get_padding(carg)
sz = size(carg) .+ padding
if !do_transpose
padding = reverse(padding)
sz = reverse(sz)
end
row_maj_arg = MLIR.IR.result(unpad_val_op(row_maj_arg, padding, sz), 1)
end
set_mlir_data!(arg, row_maj_arg)
end

if isempty(kwargs)
Reactant.call_with_reactant(f, traced_args...)
else
Reactant.call_with_reactant(Core.kwcall, kwargs, f, traced_args...)
end
finally
MLIR.IR.deactivate!(fnbody)
Ops.deactivate_constant_context!(fnbody)
set_mlir_data!(arg, row_maj_arg)
end
end

function finalize_mlir_fn(
result,
traced_args,
linear_args,
seen_args,
seen_results,
fnbody,
func,
mod,
name,
in_tys,
do_transpose,
optimize_then_pad,
inv_map,
args_in_result,
resprefix,
argprefix,
resargprefix,
verify_arg_names,
return_dialect,
traced_args_to_shardings,
output_shardings,
sym_visibility,
num_replicas,
runtime,
construct_function_without_args,
args,
N,
concretein,
toscalar,
)
# check which arguments have been mutated
mutated_args = Int[]
if !construct_function_without_args
Expand All @@ -426,8 +571,6 @@ function make_mlir_fn(
end
end

seen_results = OrderedIdDict()

outmode = if concretein
@assert !toscalar
Reactant.NoStopTracedTrack
Expand Down Expand Up @@ -644,6 +787,7 @@ function make_mlir_fn(
end
end

ctx = MLIR.IR.context()
# Attach `sdy.sharding` attribute to the argument
for (i, arg) in enumerate(linear_args)
if haskey(traced_args_to_shardings, arg)
Expand Down Expand Up @@ -742,27 +886,18 @@ function make_mlir_fn(
MLIR.API.mlirOperationDestroy(func.operation)
func.operation = MLIR.API.MlirOperation(C_NULL)

return CompiledMlirFnResult(
false,
return (
func2,
traced_result,
result,
seen_args,
ret,
linear_args,
in_tys,
linear_results,
num_partitions,
num_replicas,
is_sharded,
nothing,
nothing,
unique_meshes,
mutated_args,
true,
missing,
global_device_ids,
nothing, # populated later in `compile_mlir!`
)
end

Expand Down
Loading
Loading