Skip to content

Commit c0aa16d

Browse files
Small refactoring of make_mlir_fn + more (#1226)
* factor out first half of `make_mlir_fn` * factor out second half of `make_mlir_fn` * `push_inst!` * move code around in `Ops.call` * factor out `process_linear_args!` from `make_mlir_fn`. * further cleanup using `push_inst!` * formatting * fix * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 1dfda73 commit c0aa16d

File tree

3 files changed

+214
-87
lines changed

3 files changed

+214
-87
lines changed

src/Ops.jl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2353,24 +2353,6 @@ end
23532353
end
23542354

23552355
@noinline function call(f, args...)
2356-
seen_cache = Reactant.OrderedIdDict()
2357-
Reactant.make_tracer(
2358-
seen_cache,
2359-
args,
2360-
(), # we have to insert something here, but we remove it immediately below.
2361-
Reactant.TracedTrack;
2362-
toscalar=false,
2363-
)
2364-
linear_args = []
2365-
mlir_caller_args = Reactant.MLIR.IR.Value[]
2366-
for (k, v) in seen_cache
2367-
v isa Reactant.TracedType || continue
2368-
push!(linear_args, v)
2369-
push!(mlir_caller_args, v.mlir_data)
2370-
# make tracer inserted `()` into the path, here we remove it:
2371-
v.paths = v.paths[1:(end - 1)]
2372-
end
2373-
23742356
seen = Dict()
23752357
cache_key = []
23762358
Reactant.make_tracer(seen, (f, args...), cache_key, Reactant.TracedToTypes)
@@ -2414,6 +2396,24 @@ end
24142396
)
24152397
end
24162398

2399+
seen_cache = Reactant.OrderedIdDict()
2400+
Reactant.make_tracer(
2401+
seen_cache,
2402+
args,
2403+
(), # we have to insert something here, but we remove it immediately below.
2404+
Reactant.TracedTrack;
2405+
toscalar=false,
2406+
)
2407+
linear_args = []
2408+
mlir_caller_args = Reactant.MLIR.IR.Value[]
2409+
for (k, v) in seen_cache
2410+
v isa Reactant.TracedType || continue
2411+
push!(linear_args, v)
2412+
push!(mlir_caller_args, v.mlir_data)
2413+
# make tracer inserted `()` into the path, here we remove it:
2414+
v.paths = v.paths[1:(end - 1)]
2415+
end
2416+
24172417
call_op = MLIR.Dialects.func.call(
24182418
mlir_caller_args;
24192419
result_0=mlir_result_types,

src/TracedUtils.jl

Lines changed: 179 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,125 @@ function make_mlir_fn(
280280
return mlir_fn_res
281281
end
282282

283-
N = length(args)
284283
seen_args = OrderedIdDict()
284+
285+
(; N, traced_args, linear_args, inv_map, in_tys, sym_visibility, mod, traced_args_to_shardings, func, fnbody) = prepare_mlir_fn_args(
286+
args,
287+
name,
288+
seen_args,
289+
concretein,
290+
toscalar,
291+
argprefix,
292+
runtime,
293+
optimize_then_pad,
294+
do_transpose,
295+
input_shardings,
296+
verify_arg_names,
297+
)
298+
299+
Ops.activate_constant_context!(fnbody)
300+
@assert MLIR.IR._has_block()
301+
302+
# Explicitly don't use block! to avoid creating a closure, which creates
303+
# both compile-time and relocatability issues
304+
MLIR.IR.activate!(fnbody)
305+
306+
result = try
307+
process_linear_args!(linear_args, fnbody, do_transpose, optimize_then_pad, inv_map)
308+
309+
if isempty(kwargs)
310+
Reactant.call_with_reactant(f, traced_args...)
311+
else
312+
Reactant.call_with_reactant(Core.kwcall, kwargs, f, traced_args...)
313+
end
314+
finally
315+
MLIR.IR.deactivate!(fnbody)
316+
Ops.deactivate_constant_context!(fnbody)
317+
end
318+
319+
# check which arguments have been mutated
320+
mutated_args = Int[]
321+
if !construct_function_without_args
322+
for (i, arg) in enumerate(linear_args)
323+
if get_mlir_data(arg) != MLIR.IR.argument(fnbody, i)
324+
# mutation occured!
325+
push!(mutated_args, i)
326+
end
327+
end
328+
end
329+
330+
seen_results = OrderedIdDict()
331+
332+
(func2, traced_result, ret, linear_args, in_tys, linear_results, num_partitions, is_sharded, unique_meshes, mutated_args, global_device_ids) = finalize_mlir_fn(
333+
result,
334+
traced_args,
335+
linear_args,
336+
seen_args,
337+
seen_results,
338+
fnbody,
339+
func,
340+
mod,
341+
name,
342+
in_tys,
343+
do_transpose,
344+
optimize_then_pad,
345+
inv_map,
346+
args_in_result,
347+
resprefix,
348+
argprefix,
349+
resargprefix,
350+
verify_arg_names,
351+
return_dialect,
352+
traced_args_to_shardings,
353+
output_shardings,
354+
sym_visibility,
355+
num_replicas,
356+
runtime,
357+
construct_function_without_args,
358+
args,
359+
N,
360+
concretein,
361+
toscalar,
362+
)
363+
364+
return CompiledMlirFnResult(
365+
false,
366+
func2,
367+
traced_result,
368+
result,
369+
seen_args,
370+
ret,
371+
linear_args,
372+
in_tys,
373+
linear_results,
374+
num_partitions,
375+
num_replicas,
376+
is_sharded,
377+
nothing,
378+
nothing,
379+
unique_meshes,
380+
mutated_args,
381+
true,
382+
missing,
383+
global_device_ids,
384+
nothing, # populated later in `compile_mlir!`
385+
)
386+
end
387+
388+
function prepare_mlir_fn_args(
389+
args,
390+
name,
391+
seen_args,
392+
concretein,
393+
toscalar,
394+
argprefix,
395+
runtime,
396+
optimize_then_pad,
397+
do_transpose,
398+
input_shardings,
399+
verify_arg_names,
400+
)
401+
N = length(args)
285402
traced_args = Vector{Any}(undef, N)
286403
inmode = if concretein
287404
@assert !toscalar
@@ -326,7 +443,6 @@ function make_mlir_fn(
326443
sym_visibility = MLIR.IR.Attribute("private")
327444
end
328445

329-
ctx = MLIR.IR.context()
330446
mod = MLIR.IR.mmodule()
331447

332448
# Insert meshes for the sharded arguments
@@ -378,43 +494,72 @@ function make_mlir_fn(
378494
end
379495
fnbody = MLIR.IR.Block(in_tys, arglocs)
380496
push!(MLIR.IR.region(func, 1), fnbody)
381-
Ops.activate_constant_context!(fnbody)
382497

383-
@assert MLIR.IR._has_block()
384-
385-
# Explicitly don't use block! to avoid creating a closure, which creates
386-
# both compile-time and relocatability issues
387-
MLIR.IR.activate!(fnbody)
498+
return (;
499+
N,
500+
traced_args,
501+
linear_args,
502+
inv_map,
503+
in_tys,
504+
sym_visibility,
505+
mod,
506+
traced_args_to_shardings,
507+
func,
508+
fnbody,
509+
)
510+
end
388511

389-
result = try
390-
for (i, arg) in enumerate(linear_args)
391-
raw_arg = MLIR.IR.argument(fnbody, i)
392-
row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg
393-
if !optimize_then_pad
394-
carg = inv_map[arg]
395-
if Reactant.has_padding(carg)
396-
padding = Reactant.get_padding(carg)
397-
sz = size(carg) .+ padding
398-
if !do_transpose
399-
padding = reverse(padding)
400-
sz = reverse(sz)
401-
end
402-
row_maj_arg = MLIR.IR.result(unpad_val_op(row_maj_arg, padding, sz), 1)
512+
function process_linear_args!(linear_args, fnbody, do_transpose, optimize_then_pad, inv_map)
513+
for (i, arg) in enumerate(linear_args)
514+
raw_arg = MLIR.IR.argument(fnbody, i)
515+
row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg
516+
if !optimize_then_pad
517+
carg = inv_map[arg]
518+
if Reactant.has_padding(carg)
519+
padding = Reactant.get_padding(carg)
520+
sz = size(carg) .+ padding
521+
if !do_transpose
522+
padding = reverse(padding)
523+
sz = reverse(sz)
403524
end
525+
row_maj_arg = MLIR.IR.result(unpad_val_op(row_maj_arg, padding, sz), 1)
404526
end
405-
set_mlir_data!(arg, row_maj_arg)
406527
end
407-
408-
if isempty(kwargs)
409-
Reactant.call_with_reactant(f, traced_args...)
410-
else
411-
Reactant.call_with_reactant(Core.kwcall, kwargs, f, traced_args...)
412-
end
413-
finally
414-
MLIR.IR.deactivate!(fnbody)
415-
Ops.deactivate_constant_context!(fnbody)
528+
set_mlir_data!(arg, row_maj_arg)
416529
end
530+
end
417531

532+
function finalize_mlir_fn(
533+
result,
534+
traced_args,
535+
linear_args,
536+
seen_args,
537+
seen_results,
538+
fnbody,
539+
func,
540+
mod,
541+
name,
542+
in_tys,
543+
do_transpose,
544+
optimize_then_pad,
545+
inv_map,
546+
args_in_result,
547+
resprefix,
548+
argprefix,
549+
resargprefix,
550+
verify_arg_names,
551+
return_dialect,
552+
traced_args_to_shardings,
553+
output_shardings,
554+
sym_visibility,
555+
num_replicas,
556+
runtime,
557+
construct_function_without_args,
558+
args,
559+
N,
560+
concretein,
561+
toscalar,
562+
)
418563
# check which arguments have been mutated
419564
mutated_args = Int[]
420565
if !construct_function_without_args
@@ -426,8 +571,6 @@ function make_mlir_fn(
426571
end
427572
end
428573

429-
seen_results = OrderedIdDict()
430-
431574
outmode = if concretein
432575
@assert !toscalar
433576
Reactant.NoStopTracedTrack
@@ -644,6 +787,7 @@ function make_mlir_fn(
644787
end
645788
end
646789

790+
ctx = MLIR.IR.context()
647791
# Attach `sdy.sharding` attribute to the argument
648792
for (i, arg) in enumerate(linear_args)
649793
if haskey(traced_args_to_shardings, arg)
@@ -742,27 +886,18 @@ function make_mlir_fn(
742886
MLIR.API.mlirOperationDestroy(func.operation)
743887
func.operation = MLIR.API.MlirOperation(C_NULL)
744888

745-
return CompiledMlirFnResult(
746-
false,
889+
return (
747890
func2,
748891
traced_result,
749-
result,
750-
seen_args,
751892
ret,
752893
linear_args,
753894
in_tys,
754895
linear_results,
755896
num_partitions,
756-
num_replicas,
757897
is_sharded,
758-
nothing,
759-
nothing,
760898
unique_meshes,
761899
mutated_args,
762-
true,
763-
missing,
764900
global_device_ids,
765-
nothing, # populated later in `compile_mlir!`
766901
)
767902
end
768903

0 commit comments

Comments
 (0)