@@ -280,8 +280,125 @@ function make_mlir_fn(
280
280
return mlir_fn_res
281
281
end
282
282
283
- N = length (args)
284
283
seen_args = OrderedIdDict ()
284
+
285
+ (; N, traced_args, linear_args, inv_map, in_tys, sym_visibility, mod, traced_args_to_shardings, func, fnbody) = prepare_mlir_fn_args (
286
+ args,
287
+ name,
288
+ seen_args,
289
+ concretein,
290
+ toscalar,
291
+ argprefix,
292
+ runtime,
293
+ optimize_then_pad,
294
+ do_transpose,
295
+ input_shardings,
296
+ verify_arg_names,
297
+ )
298
+
299
+ Ops. activate_constant_context! (fnbody)
300
+ @assert MLIR. IR. _has_block ()
301
+
302
+ # Explicitly don't use block! to avoid creating a closure, which creates
303
+ # both compile-time and relocatability issues
304
+ MLIR. IR. activate! (fnbody)
305
+
306
+ result = try
307
+ process_linear_args! (linear_args, fnbody, do_transpose, optimize_then_pad, inv_map)
308
+
309
+ if isempty (kwargs)
310
+ Reactant. call_with_reactant (f, traced_args... )
311
+ else
312
+ Reactant. call_with_reactant (Core. kwcall, kwargs, f, traced_args... )
313
+ end
314
+ finally
315
+ MLIR. IR. deactivate! (fnbody)
316
+ Ops. deactivate_constant_context! (fnbody)
317
+ end
318
+
319
+ # check which arguments have been mutated
320
+ mutated_args = Int[]
321
+ if ! construct_function_without_args
322
+ for (i, arg) in enumerate (linear_args)
323
+ if get_mlir_data (arg) != MLIR. IR. argument (fnbody, i)
324
+ # mutation occured!
325
+ push! (mutated_args, i)
326
+ end
327
+ end
328
+ end
329
+
330
+ seen_results = OrderedIdDict ()
331
+
332
+ (func2, traced_result, ret, linear_args, in_tys, linear_results, num_partitions, is_sharded, unique_meshes, mutated_args, global_device_ids) = finalize_mlir_fn (
333
+ result,
334
+ traced_args,
335
+ linear_args,
336
+ seen_args,
337
+ seen_results,
338
+ fnbody,
339
+ func,
340
+ mod,
341
+ name,
342
+ in_tys,
343
+ do_transpose,
344
+ optimize_then_pad,
345
+ inv_map,
346
+ args_in_result,
347
+ resprefix,
348
+ argprefix,
349
+ resargprefix,
350
+ verify_arg_names,
351
+ return_dialect,
352
+ traced_args_to_shardings,
353
+ output_shardings,
354
+ sym_visibility,
355
+ num_replicas,
356
+ runtime,
357
+ construct_function_without_args,
358
+ args,
359
+ N,
360
+ concretein,
361
+ toscalar,
362
+ )
363
+
364
+ return CompiledMlirFnResult (
365
+ false ,
366
+ func2,
367
+ traced_result,
368
+ result,
369
+ seen_args,
370
+ ret,
371
+ linear_args,
372
+ in_tys,
373
+ linear_results,
374
+ num_partitions,
375
+ num_replicas,
376
+ is_sharded,
377
+ nothing ,
378
+ nothing ,
379
+ unique_meshes,
380
+ mutated_args,
381
+ true ,
382
+ missing ,
383
+ global_device_ids,
384
+ nothing , # populated later in `compile_mlir!`
385
+ )
386
+ end
387
+
388
+ function prepare_mlir_fn_args (
389
+ args,
390
+ name,
391
+ seen_args,
392
+ concretein,
393
+ toscalar,
394
+ argprefix,
395
+ runtime,
396
+ optimize_then_pad,
397
+ do_transpose,
398
+ input_shardings,
399
+ verify_arg_names,
400
+ )
401
+ N = length (args)
285
402
traced_args = Vector {Any} (undef, N)
286
403
inmode = if concretein
287
404
@assert ! toscalar
@@ -326,7 +443,6 @@ function make_mlir_fn(
326
443
sym_visibility = MLIR. IR. Attribute (" private" )
327
444
end
328
445
329
- ctx = MLIR. IR. context ()
330
446
mod = MLIR. IR. mmodule ()
331
447
332
448
# Insert meshes for the sharded arguments
@@ -378,43 +494,72 @@ function make_mlir_fn(
378
494
end
379
495
fnbody = MLIR. IR. Block (in_tys, arglocs)
380
496
push! (MLIR. IR. region (func, 1 ), fnbody)
381
- Ops. activate_constant_context! (fnbody)
382
497
383
- @assert MLIR. IR. _has_block ()
384
-
385
- # Explicitly don't use block! to avoid creating a closure, which creates
386
- # both compile-time and relocatability issues
387
- MLIR. IR. activate! (fnbody)
498
+ return (;
499
+ N,
500
+ traced_args,
501
+ linear_args,
502
+ inv_map,
503
+ in_tys,
504
+ sym_visibility,
505
+ mod,
506
+ traced_args_to_shardings,
507
+ func,
508
+ fnbody,
509
+ )
510
+ end
388
511
389
- result = try
390
- for (i, arg) in enumerate (linear_args)
391
- raw_arg = MLIR. IR. argument (fnbody, i)
392
- row_maj_arg = do_transpose ? transpose_val (raw_arg) : raw_arg
393
- if ! optimize_then_pad
394
- carg = inv_map[arg]
395
- if Reactant. has_padding (carg)
396
- padding = Reactant. get_padding (carg)
397
- sz = size (carg) .+ padding
398
- if ! do_transpose
399
- padding = reverse (padding)
400
- sz = reverse (sz)
401
- end
402
- row_maj_arg = MLIR. IR. result (unpad_val_op (row_maj_arg, padding, sz), 1 )
512
+ function process_linear_args! (linear_args, fnbody, do_transpose, optimize_then_pad, inv_map)
513
+ for (i, arg) in enumerate (linear_args)
514
+ raw_arg = MLIR. IR. argument (fnbody, i)
515
+ row_maj_arg = do_transpose ? transpose_val (raw_arg) : raw_arg
516
+ if ! optimize_then_pad
517
+ carg = inv_map[arg]
518
+ if Reactant. has_padding (carg)
519
+ padding = Reactant. get_padding (carg)
520
+ sz = size (carg) .+ padding
521
+ if ! do_transpose
522
+ padding = reverse (padding)
523
+ sz = reverse (sz)
403
524
end
525
+ row_maj_arg = MLIR. IR. result (unpad_val_op (row_maj_arg, padding, sz), 1 )
404
526
end
405
- set_mlir_data! (arg, row_maj_arg)
406
527
end
407
-
408
- if isempty (kwargs)
409
- Reactant. call_with_reactant (f, traced_args... )
410
- else
411
- Reactant. call_with_reactant (Core. kwcall, kwargs, f, traced_args... )
412
- end
413
- finally
414
- MLIR. IR. deactivate! (fnbody)
415
- Ops. deactivate_constant_context! (fnbody)
528
+ set_mlir_data! (arg, row_maj_arg)
416
529
end
530
+ end
417
531
532
+ function finalize_mlir_fn (
533
+ result,
534
+ traced_args,
535
+ linear_args,
536
+ seen_args,
537
+ seen_results,
538
+ fnbody,
539
+ func,
540
+ mod,
541
+ name,
542
+ in_tys,
543
+ do_transpose,
544
+ optimize_then_pad,
545
+ inv_map,
546
+ args_in_result,
547
+ resprefix,
548
+ argprefix,
549
+ resargprefix,
550
+ verify_arg_names,
551
+ return_dialect,
552
+ traced_args_to_shardings,
553
+ output_shardings,
554
+ sym_visibility,
555
+ num_replicas,
556
+ runtime,
557
+ construct_function_without_args,
558
+ args,
559
+ N,
560
+ concretein,
561
+ toscalar,
562
+ )
418
563
# check which arguments have been mutated
419
564
mutated_args = Int[]
420
565
if ! construct_function_without_args
@@ -426,8 +571,6 @@ function make_mlir_fn(
426
571
end
427
572
end
428
573
429
- seen_results = OrderedIdDict ()
430
-
431
574
outmode = if concretein
432
575
@assert ! toscalar
433
576
Reactant. NoStopTracedTrack
@@ -644,6 +787,7 @@ function make_mlir_fn(
644
787
end
645
788
end
646
789
790
+ ctx = MLIR. IR. context ()
647
791
# Attach `sdy.sharding` attribute to the argument
648
792
for (i, arg) in enumerate (linear_args)
649
793
if haskey (traced_args_to_shardings, arg)
@@ -742,27 +886,18 @@ function make_mlir_fn(
742
886
MLIR. API. mlirOperationDestroy (func. operation)
743
887
func. operation = MLIR. API. MlirOperation (C_NULL )
744
888
745
- return CompiledMlirFnResult (
746
- false ,
889
+ return (
747
890
func2,
748
891
traced_result,
749
- result,
750
- seen_args,
751
892
ret,
752
893
linear_args,
753
894
in_tys,
754
895
linear_results,
755
896
num_partitions,
756
- num_replicas,
757
897
is_sharded,
758
- nothing ,
759
- nothing ,
760
898
unique_meshes,
761
899
mutated_args,
762
- true ,
763
- missing ,
764
900
global_device_ids,
765
- nothing , # populated later in `compile_mlir!`
766
901
)
767
902
end
768
903
0 commit comments