Skip to content
36 changes: 18 additions & 18 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2353,24 +2353,6 @@ end
end

@noinline function call(f, args...)
seen_cache = Reactant.OrderedIdDict()
Reactant.make_tracer(
seen_cache,
args,
(), # we have to insert something here, but we remove it immediately below.
Reactant.TracedTrack;
toscalar=false,
)
linear_args = []
mlir_caller_args = Reactant.MLIR.IR.Value[]
for (k, v) in seen_cache
v isa Reactant.TracedType || continue
push!(linear_args, v)
push!(mlir_caller_args, v.mlir_data)
# make tracer inserted `()` into the path, here we remove it:
v.paths = v.paths[1:(end - 1)]
end

seen = Dict()
cache_key = []
Reactant.make_tracer(seen, (f, args...), cache_key, Reactant.TracedToTypes)
Expand Down Expand Up @@ -2414,6 +2396,24 @@ end
)
end

seen_cache = Reactant.OrderedIdDict()
Reactant.make_tracer(
seen_cache,
args,
(), # we have to insert something here, but we remove it immediately below.
Reactant.TracedTrack;
toscalar=false,
)
linear_args = []
mlir_caller_args = Reactant.MLIR.IR.Value[]
for (k, v) in seen_cache
v isa Reactant.TracedType || continue
push!(linear_args, v)
push!(mlir_caller_args, v.mlir_data)
# make tracer inserted `()` into the path, here we remove it:
v.paths = v.paths[1:(end - 1)]
end

call_op = MLIR.Dialects.func.call(
mlir_caller_args;
result_0=mlir_result_types,
Expand Down
223 changes: 179 additions & 44 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,125 @@ function make_mlir_fn(
return mlir_fn_res
end

N = length(args)
seen_args = OrderedIdDict()

(; N, traced_args, linear_args, inv_map, in_tys, sym_visibility, mod, traced_args_to_shardings, func, fnbody) = prepare_mlir_fn_args(
args,
name,
seen_args,
concretein,
toscalar,
argprefix,
runtime,
optimize_then_pad,
do_transpose,
input_shardings,
verify_arg_names,
)

Ops.activate_constant_context!(fnbody)
@assert MLIR.IR._has_block()

# Explicitly don't use block! to avoid creating a closure, which creates
# both compile-time and relocatability issues
MLIR.IR.activate!(fnbody)

result = try
process_linear_args!(linear_args, fnbody, do_transpose, optimize_then_pad, inv_map)

if isempty(kwargs)
Reactant.call_with_reactant(f, traced_args...)
else
Reactant.call_with_reactant(Core.kwcall, kwargs, f, traced_args...)
end
finally
MLIR.IR.deactivate!(fnbody)
Ops.deactivate_constant_context!(fnbody)
end

# check which arguments have been mutated
mutated_args = Int[]
if !construct_function_without_args
for (i, arg) in enumerate(linear_args)
if get_mlir_data(arg) != MLIR.IR.argument(fnbody, i)
# mutation occured!
push!(mutated_args, i)
end
end
end

seen_results = OrderedIdDict()

(func2, traced_result, ret, linear_args, in_tys, linear_results, num_partitions, is_sharded, unique_meshes, mutated_args, global_device_ids) = finalize_mlir_fn(
result,
traced_args,
linear_args,
seen_args,
seen_results,
fnbody,
func,
mod,
name,
in_tys,
do_transpose,
optimize_then_pad,
inv_map,
args_in_result,
resprefix,
argprefix,
resargprefix,
verify_arg_names,
return_dialect,
traced_args_to_shardings,
output_shardings,
sym_visibility,
num_replicas,
runtime,
construct_function_without_args,
args,
N,
concretein,
toscalar,
)

return CompiledMlirFnResult(
false,
func2,
traced_result,
result,
seen_args,
ret,
linear_args,
in_tys,
linear_results,
num_partitions,
num_replicas,
is_sharded,
nothing,
nothing,
unique_meshes,
mutated_args,
true,
missing,
global_device_ids,
nothing, # populated later in `compile_mlir!`
)
end

function prepare_mlir_fn_args(
args,
name,
seen_args,
concretein,
toscalar,
argprefix,
runtime,
optimize_then_pad,
do_transpose,
input_shardings,
verify_arg_names,
)
N = length(args)
traced_args = Vector{Any}(undef, N)
inmode = if concretein
@assert !toscalar
Expand Down Expand Up @@ -326,7 +443,6 @@ function make_mlir_fn(
sym_visibility = MLIR.IR.Attribute("private")
end

ctx = MLIR.IR.context()
mod = MLIR.IR.mmodule()

# Insert meshes for the sharded arguments
Expand Down Expand Up @@ -378,43 +494,72 @@ function make_mlir_fn(
end
fnbody = MLIR.IR.Block(in_tys, arglocs)
push!(MLIR.IR.region(func, 1), fnbody)
Ops.activate_constant_context!(fnbody)

@assert MLIR.IR._has_block()

# Explicitly don't use block! to avoid creating a closure, which creates
# both compile-time and relocatability issues
MLIR.IR.activate!(fnbody)
return (;
N,
traced_args,
linear_args,
inv_map,
in_tys,
sym_visibility,
mod,
traced_args_to_shardings,
func,
fnbody,
)
end

result = try
for (i, arg) in enumerate(linear_args)
raw_arg = MLIR.IR.argument(fnbody, i)
row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg
if !optimize_then_pad
carg = inv_map[arg]
if Reactant.has_padding(carg)
padding = Reactant.get_padding(carg)
sz = size(carg) .+ padding
if !do_transpose
padding = reverse(padding)
sz = reverse(sz)
end
row_maj_arg = MLIR.IR.result(unpad_val_op(row_maj_arg, padding, sz), 1)
function process_linear_args!(linear_args, fnbody, do_transpose, optimize_then_pad, inv_map)
for (i, arg) in enumerate(linear_args)
raw_arg = MLIR.IR.argument(fnbody, i)
row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg
if !optimize_then_pad
carg = inv_map[arg]
if Reactant.has_padding(carg)
padding = Reactant.get_padding(carg)
sz = size(carg) .+ padding
if !do_transpose
padding = reverse(padding)
sz = reverse(sz)
end
row_maj_arg = MLIR.IR.result(unpad_val_op(row_maj_arg, padding, sz), 1)
end
set_mlir_data!(arg, row_maj_arg)
end

if isempty(kwargs)
Reactant.call_with_reactant(f, traced_args...)
else
Reactant.call_with_reactant(Core.kwcall, kwargs, f, traced_args...)
end
finally
MLIR.IR.deactivate!(fnbody)
Ops.deactivate_constant_context!(fnbody)
set_mlir_data!(arg, row_maj_arg)
end
end

function finalize_mlir_fn(
result,
traced_args,
linear_args,
seen_args,
seen_results,
fnbody,
func,
mod,
name,
in_tys,
do_transpose,
optimize_then_pad,
inv_map,
args_in_result,
resprefix,
argprefix,
resargprefix,
verify_arg_names,
return_dialect,
traced_args_to_shardings,
output_shardings,
sym_visibility,
num_replicas,
runtime,
construct_function_without_args,
args,
N,
concretein,
toscalar,
)
# check which arguments have been mutated
mutated_args = Int[]
if !construct_function_without_args
Expand All @@ -426,8 +571,6 @@ function make_mlir_fn(
end
end

seen_results = OrderedIdDict()

outmode = if concretein
@assert !toscalar
Reactant.NoStopTracedTrack
Expand Down Expand Up @@ -644,6 +787,7 @@ function make_mlir_fn(
end
end

ctx = MLIR.IR.context()
# Attach `sdy.sharding` attribute to the argument
for (i, arg) in enumerate(linear_args)
if haskey(traced_args_to_shardings, arg)
Expand Down Expand Up @@ -742,27 +886,18 @@ function make_mlir_fn(
MLIR.API.mlirOperationDestroy(func.operation)
func.operation = MLIR.API.MlirOperation(C_NULL)

return CompiledMlirFnResult(
false,
return (
func2,
traced_result,
result,
seen_args,
ret,
linear_args,
in_tys,
linear_results,
num_partitions,
num_replicas,
is_sharded,
nothing,
nothing,
unique_meshes,
mutated_args,
true,
missing,
global_device_ids,
nothing, # populated later in `compile_mlir!`
)
end

Expand Down
Loading
Loading