@@ -423,6 +423,83 @@ static Status pushScalarArgument(sol::state_view &lua,
423
423
return getOkStatus ();
424
424
}
425
425
426
+ // Function to extract shape and stride from sol::object table
427
+ std::tuple<std::vector<int64_t >, std::vector<int64_t >>
428
+ extractShapeAndStride (const sol::table &table) {
429
+ size_t tableSize = table.size ();
430
+ assert (tableSize >= 3 &&
431
+ " Table does not contain shape and stride information" );
432
+ size_t shapeStrideSize = (tableSize - 3 ) / 2 ;
433
+ std::vector<int64_t > shape;
434
+ std::vector<int64_t > stride;
435
+
436
+ shape.reserve (shapeStrideSize);
437
+ stride.reserve (shapeStrideSize);
438
+
439
+ // Extract shape
440
+ for (size_t i = 4 ; i <= 3 + shapeStrideSize; ++i) {
441
+ shape.push_back (table[i].get <int64_t >());
442
+ }
443
+
444
+ // Extract stride
445
+ for (size_t i = 4 + shapeStrideSize; i <= tableSize; ++i) {
446
+ stride.push_back (table[i].get <int64_t >());
447
+ }
448
+
449
+ return std::make_tuple (shape, stride);
450
+ }
451
+
452
+ // Convert sol::object to MemRefValue
453
+ StatusOr<std::unique_ptr<MemRefValue>>
454
+ solObjectToMemRefValue (RuntimeClient *client, const sol::object &obj) {
455
+ assert (obj.is <sol::table>() && " Expected a table for MemRefValue" );
456
+
457
+ sol::table memrefTable = obj.as <sol::table>();
458
+ uintptr_t ptr = memrefTable[1 ].get <uintptr_t >();
459
+ int64_t offset = memrefTable[3 ].get <int64_t >();
460
+
461
+ auto [shape, strides] = extractShapeAndStride (memrefTable);
462
+
463
+ // TODO: How to extract this information. Should we use function signature to fill in this information later?
464
+ mlirtrt::runtime::PointerType addressSpace =
465
+ mlirtrt::runtime::PointerType::device;
466
+ int64_t bitsPerElement = 32 ;
467
+ std::optional<const Device *> device =
468
+ std::nullopt;
469
+ std::optional<ScalarType> scalarType = ScalarTypeCode::f32 ;
470
+
471
+ return MemRefValue::create (client, addressSpace, bitsPerElement, ptr, offset,
472
+ llvm::ArrayRef<int64_t >(shape),
473
+ llvm::ArrayRef<int64_t >(strides), device,
474
+ scalarType);
475
+ }
476
+
477
+ // Convert sol::object to ScalarValue
478
+ std::unique_ptr<ScalarValue> solObjectToScalarValue (const sol::object &obj) {
479
+
480
+ // TODO: ScalarType is not known. Should we use function signature to fill in
481
+ // this information later? Since ScalarValue data type is int64_t. Let's cast
482
+ // the object value to int64_t for now.
483
+ return std::make_unique<ScalarValue>(obj.as <int64_t >(), ScalarTypeCode::unknown);
484
+ }
485
+
486
+ // Convert sol::object to RuntimeValue's
487
+ llvm::SmallVector<std::unique_ptr<RuntimeValue>>
488
+ solObjectToRuntimeValues (RuntimeClient *client,
489
+ std::vector<sol::object> const &results) {
490
+ llvm::SmallVector<std::unique_ptr<RuntimeValue>> values;
491
+ for (sol::object r : results) {
492
+ // if (r.is<sol::table>()) {
493
+ // Assume it's a MemRefValue if it's a table
494
+ values.emplace_back (std::move (*solObjectToMemRefValue (client, r)));
495
+ // } else {
496
+ // // Assume it's a ScalarValue for all other cases
497
+ // values.emplace_back(solObjectToScalarValue(r));
498
+ // }
499
+ }
500
+ return values;
501
+ }
502
+
426
503
static Status validateArgsTypesAgainstFuncArgs (const RuntimeValue *runArg,
427
504
const TypeUnionView &sigArg) {
428
505
if (sigArg.isa <MemrefTypeView>()) {
@@ -520,11 +597,11 @@ runtime::executeFunctionWithLuaBackend(
520
597
return getStatusWithMsg (StatusCode::InternalError, " no function named \" " ,
521
598
std::string (name), " \" found" );
522
599
523
- if (sig.getNumResults () > 0 )
524
- return getInvalidArgStatus (" functions with {0} results are not supported" ,
525
- sig.getNumResults ());
526
-
527
600
// Validate the number of arguments against the signature.
601
+ if (sig.getNumResults () != 0 )
602
+ return getInvalidArgStatus (
603
+ " function expects 0 result args but received {0}" ,
604
+ sig.getNumResults ());
528
605
if (sig.getNumOutputArgs () != outputArgs.size ())
529
606
return getInvalidArgStatus (
530
607
" function expects {0} output args (destination args) but received {1}" ,
@@ -600,3 +677,86 @@ runtime::executeFunctionWithLuaBackend(
600
677
601
678
return llvm::SmallVector<std::unique_ptr<RuntimeValue>>{};
602
679
}
680
+
681
+ StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>>
682
+ runtime::executeFunctionWithResultWithLuaBackend (
683
+ LuaRuntimeSession &session,
684
+ RuntimeClient &client,
685
+ std::string_view name,
686
+ llvm::ArrayRef<RuntimeValue *> inputArgs,
687
+ std::optional<CudaStream> stream) {
688
+
689
+ FunctionView meta = session.getExecutable ().getFunction (name);
690
+ FunctionSignatureView sig = meta.getSignature ();
691
+
692
+ // Call the main function, if present.
693
+ sol::state &lua = session.getLuaState ();
694
+ AllocTracker &tracker = session.getAllocTracker ();
695
+ sol::protected_function funcObj = lua[name];
696
+ if (funcObj.get_type () != sol::type::function)
697
+ return getStatusWithMsg (StatusCode::InternalError, " no function named \" " ,
698
+ std::string (name), " \" found" );
699
+
700
+ // Validate the number of arguments against the signature.
701
+ if (sig.getNumOutputArgs () != 0 )
702
+ return getInvalidArgStatus (
703
+ " function expects 0 output args (destination args) but received {0}" ,
704
+ sig.getNumOutputArgs ());
705
+ if (sig.getNumInputArgs () != inputArgs.size ())
706
+ return getInvalidArgStatus (" function expects {0} input args "
707
+ " (non-destination args) but received {1}" ,
708
+ sig.getNumInputArgs (), inputArgs.size ());
709
+
710
+ // Validate the inferred Lua function type here against the signature.
711
+ for (unsigned i = 0 ; i < inputArgs.size (); ++i) {
712
+ auto status = validateArgsTypesAgainstFuncArgs (inputArgs[i], sig.getArg (i));
713
+ if (!status.isOk ())
714
+ return getInvalidArgStatus (
715
+ " Input argument {0} validation failed against "
716
+ " corresponding function signature arg {0}. Reason: {1}" ,
717
+ i, status.getString ());
718
+ }
719
+
720
+ // Create the arguments.
721
+ llvm::SmallVector<sol::object> args;
722
+ args.reserve (inputArgs.size ());
723
+ for (auto [idx, rv] : llvm::enumerate (inputArgs)) {
724
+ if (MemRefValue *memref = llvm::dyn_cast<MemRefValue>(rv)) {
725
+ MTRT_RETURN_IF_ERROR (pushMemRefTableArg (lua, tracker, args, *memref));
726
+ continue ;
727
+ }
728
+ if (ScalarValue *scalar = llvm::dyn_cast<ScalarValue>(rv)) {
729
+ MTRT_RETURN_IF_ERROR (pushScalarArgument (lua, args, *scalar));
730
+ continue ;
731
+ }
732
+ return getInvalidArgStatus (
733
+ " input argument #{0} to function {1} has an unsupported type; "
734
+ " arguments must be either MemRefs or scalars" ,
735
+ idx + 1 , name);
736
+ }
737
+ if (stream)
738
+ RETURN_STATUS_IF_ERROR (session.setCudaStream (*stream));
739
+
740
+ // If the number of arguments exceed a particular threshold, then
741
+ // we pass arguments packed into a table, otherwise we pass as arguments.
742
+ sol::protected_function_result result =
743
+ sig.getCConv () == CallingConvention::unpacked
744
+ ? funcObj (sol::as_args (args))
745
+ : funcObj (args);
746
+
747
+ if (!result.valid ()) {
748
+ sol::error err (result);
749
+ return getStatusWithMsg (StatusCode::InternalError,
750
+ " failed to run function \" " , std::string (name),
751
+ " \" : " , err.what ());
752
+ }
753
+
754
+ int returnCount = result.return_count ();
755
+ std::vector<sol::object> results;
756
+ // Lua index start from 1
757
+ for (int i = 1 ; i <= returnCount; ++i) {
758
+ results.push_back (result[i]);
759
+ }
760
+
761
+ return solObjectToRuntimeValues (&client, results);
762
+ }
0 commit comments