Skip to content

Commit 63c8339

Browse files
committed
Add initial IR for alloc enqueue
1 parent 1f10b30 commit 63c8339

File tree

9 files changed

+508
-13
lines changed

9 files changed

+508
-13
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/Passes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ def ConvertCUDAToExecutorPass : Pass<"convert-cuda-to-executor",
192192
//===----------------------------------------------------------------------===//
193193
// ConvertTensorRTRuntimeToExecutorPass
194194
//===----------------------------------------------------------------------===//
195+
// TODO: Modify this pass to generate non-DPS style enqueue functions.
195196
def ConvertTensorRTRuntimeToExecutorPass : Pass<"convert-tensorrt-runtime-to-executor",
196197
"::mlir::ModuleOp"> {
197198
let summary = "Converts TensorRTRuntime dialect ops to executor dialect operations";

mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -487,11 +487,13 @@ StableHloToExecutableTask::compileStableHLOToExecutable(
487487
runner = pm.get();
488488
}
489489

490+
runner->printAsTextualPipeline(llvm::dbgs());
491+
490492
// Setup pass manager
491-
if (failed(runner->run(module)))
492-
return getInternalErrorStatus(
493-
"failed to run compilation on module with symbol name: {0}",
494-
module.getName() ? *module.getName() : "no-symbol-name");
493+
// if (failed(runner->run(module)))
494+
// return getInternalErrorStatus(
495+
// "failed to run compilation on module with symbol name: {0}",
496+
// module.getName() ? *module.getName() : "no-symbol-name");
495497

496498
// Translate to Runtime Executable
497499
FailureOr<std::unique_ptr<runtime::ExecutableStorage>> exeStorage =

mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,12 @@ static inline bool mtrtRuntimeValueIsNull(MTRT_RuntimeValue value) {
289289
return !value.ptr;
290290
}
291291

292+
// Returns whether the RuntimeValue is a MemRef.
293+
MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value);
294+
295+
// Returns whether the RuntimeValue is a Scalar.
296+
MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value);
297+
292298
/// Cast a MTRT_MemRefValue to a generic MTRT_RuntimeValue.
293299
MLIR_CAPI_EXPORTED MTRT_RuntimeValue
294300
mtrtMemRefCastToRuntimeValue(MTRT_MemRefValue memref);
@@ -383,6 +389,16 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionExecuteFunction(
383389
const MTRT_RuntimeValue *inArgs, size_t numInArgs,
384390
const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream);
385391

392+
/// Variant of the above function which returns results.
393+
MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionExecuteFunctionWithResult(
394+
MTRT_RuntimeSession session, MTRT_RuntimeClient client,
395+
MTRT_StringView name, const MTRT_RuntimeValue *inArgs, size_t numInArgs,
396+
MTRT_RuntimeValue *resultArgs, size_t numResultArgs, MTRT_Stream stream);
397+
398+
MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionGetNbResults(MTRT_RuntimeSession session,
399+
MTRT_StringView name,
400+
int64_t *numResults);
401+
386402
//===----------------------------------------------------------------------===//
387403
// DLPack
388404
//===----------------------------------------------------------------------===//

mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,13 @@ executeFunctionWithLuaBackend(LuaRuntimeSession &session, std::string_view name,
100100
llvm::ArrayRef<RuntimeValue *> outputArgs,
101101
std::optional<CudaStream> stream = {});
102102

103+
/// Execute a named function in the session with the specified input args and return results.
104+
StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>>
105+
executeFunctionWithResultWithLuaBackend(
106+
LuaRuntimeSession &session, RuntimeClient &client, std::string_view name,
107+
llvm::ArrayRef<RuntimeValue *> inputArgs,
108+
std::optional<CudaStream> stream = {});
109+
103110
} // namespace mlirtrt::runtime
104111

105112
#endif // MLIR_TENSORRT_RUNTIME_BACKEND_LUA_LUARUNTIME_H

mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,16 @@ MTRT_ScalarValue mtrtRuntimeValueDynCastToScalar(MTRT_RuntimeValue v) {
641641
return wrap(static_cast<ScalarValue *>(x));
642642
}
643643

644+
bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value) {
645+
RuntimeValue *x = unwrap(value);
646+
return x->getKind() == RuntimeValue::Kind::MemRef;
647+
}
648+
649+
bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value) {
650+
RuntimeValue *x = unwrap(value);
651+
return x->getKind() == RuntimeValue::Kind::Scalar;
652+
}
653+
644654
//===----------------------------------------------------------------------===//
645655
// MTRT_RuntimeSessionOptions
646656
//===----------------------------------------------------------------------===//
@@ -697,19 +707,64 @@ MTRT_Status mtrtRuntimeSessionExecuteFunction(
697707
llvm::SmallVector<RuntimeValue *> outArgValues =
698708
llvm::map_to_vector(llvm::ArrayRef(outArgs, numOutArgs),
699709
[](MTRT_RuntimeValue arg) { return unwrap(arg); });
700-
701710
StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>> result =
702711
executeFunctionWithLuaBackend(
703712
*cppSession, std::string_view(name.data, name.length), inArgValues,
704713
outArgValues,
705714
!mtrtStreamIsNull(stream)
706715
? std::optional(unwrap(stream)->getRawStream())
707716
: std::nullopt);
708-
if (!result.isOk())
717+
if (!result.isOk()) {
709718
return wrap(result.getStatus());
719+
}
720+
return mtrtStatusGetOk();
721+
}
722+
723+
MTRT_Status mtrtRuntimeSessionExecuteFunctionWithResult(
724+
MTRT_RuntimeSession session, MTRT_RuntimeClient client,
725+
MTRT_StringView name, const MTRT_RuntimeValue *inArgs, size_t numInArgs,
726+
MTRT_RuntimeValue *resultArgs, size_t numResultArgs,
727+
MTRT_Stream stream) {
728+
LuaRuntimeSession *cppSession =
729+
static_cast<LuaRuntimeSession *>(unwrap(session));
730+
731+
RuntimeClient *cppClient = unwrap(client);
732+
733+
llvm::SmallVector<RuntimeValue *> inArgValues =
734+
llvm::map_to_vector(llvm::ArrayRef(inArgs, numInArgs),
735+
[](MTRT_RuntimeValue arg) { return unwrap(arg); });
736+
StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>> results =
737+
executeFunctionWithResultWithLuaBackend(
738+
*cppSession, *cppClient, std::string_view(name.data, name.length),
739+
inArgValues,
740+
!mtrtStreamIsNull(stream)
741+
? std::optional(unwrap(stream)->getRawStream())
742+
: std::nullopt);
743+
if (!results.isOk()) {
744+
return wrap(results.getStatus());
745+
}
746+
747+
assert(results->size() == numResultArgs);
748+
749+
for (size_t i = 0; i < numResultArgs; ++i) {
750+
resultArgs[i] = wrap((*results)[i].release());
751+
}
710752

711753
return mtrtStatusGetOk();
712754
}
755+
756+
MTRT_Status mtrtRuntimeSessionGetNbResults(MTRT_RuntimeSession session,
757+
MTRT_StringView name,
758+
int64_t *numResults) {
759+
LuaRuntimeSession *cppSession =
760+
static_cast<LuaRuntimeSession *>(unwrap(session));
761+
*numResults = cppSession->getExecutable()
762+
.getFunction(std::string_view(name.data, name.length))
763+
.getSignature()
764+
.getNumResults();
765+
return mtrtStatusGetOk();
766+
}
767+
713768
//===----------------------------------------------------------------------===//
714769
// MTRT_RuntimeClient
715770
//===----------------------------------------------------------------------===//

mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp

Lines changed: 164 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,83 @@ static Status pushScalarArgument(sol::state_view &lua,
423423
return getOkStatus();
424424
}
425425

426+
// Function to extract shape and stride from sol::object table
427+
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
428+
extractShapeAndStride(const sol::table &table) {
429+
size_t tableSize = table.size();
430+
assert(tableSize >= 3 &&
431+
"Table does not contain shape and stride information");
432+
size_t shapeStrideSize = (tableSize - 3) / 2;
433+
std::vector<int64_t> shape;
434+
std::vector<int64_t> stride;
435+
436+
shape.reserve(shapeStrideSize);
437+
stride.reserve(shapeStrideSize);
438+
439+
// Extract shape
440+
for (size_t i = 4; i <= 3 + shapeStrideSize; ++i) {
441+
shape.push_back(table[i].get<int64_t>());
442+
}
443+
444+
// Extract stride
445+
for (size_t i = 4 + shapeStrideSize; i <= tableSize; ++i) {
446+
stride.push_back(table[i].get<int64_t>());
447+
}
448+
449+
return std::make_tuple(shape, stride);
450+
}
451+
452+
// Convert sol::object to MemRefValue
453+
StatusOr<std::unique_ptr<MemRefValue>>
454+
solObjectToMemRefValue(RuntimeClient *client, const sol::object &obj) {
455+
assert(obj.is<sol::table>() && "Expected a table for MemRefValue");
456+
457+
sol::table memrefTable = obj.as<sol::table>();
458+
uintptr_t ptr = memrefTable[1].get<uintptr_t>();
459+
int64_t offset = memrefTable[3].get<int64_t>();
460+
461+
auto [shape, strides] = extractShapeAndStride(memrefTable);
462+
463+
// TODO: Determine how to extract this information. Should we use the function signature to fill it in later?
464+
mlirtrt::runtime::PointerType addressSpace =
465+
mlirtrt::runtime::PointerType::device;
466+
int64_t bitsPerElement = 32;
467+
std::optional<const Device *> device =
468+
std::nullopt;
469+
std::optional<ScalarType> scalarType = ScalarTypeCode::f32;
470+
471+
return MemRefValue::create(client, addressSpace, bitsPerElement, ptr, offset,
472+
llvm::ArrayRef<int64_t>(shape),
473+
llvm::ArrayRef<int64_t>(strides), device,
474+
scalarType);
475+
}
476+
477+
// Convert sol::object to ScalarValue
478+
std::unique_ptr<ScalarValue> solObjectToScalarValue(const sol::object &obj) {
479+
480+
// TODO: ScalarType is not known. Should we use function signature to fill in
481+
// this information later? Since ScalarValue data type is int64_t. Let's cast
482+
// the object value to int64_t for now.
483+
return std::make_unique<ScalarValue>(obj.as<int64_t>(), ScalarTypeCode::unknown);
484+
}
485+
486+
// Convert sol::object to RuntimeValue's
487+
llvm::SmallVector<std::unique_ptr<RuntimeValue>>
488+
solObjectToRuntimeValues(RuntimeClient *client,
489+
std::vector<sol::object> const &results) {
490+
llvm::SmallVector<std::unique_ptr<RuntimeValue>> values;
491+
for (sol::object r : results) {
492+
// if (r.is<sol::table>()) {
493+
// Assume it's a MemRefValue if it's a table
494+
values.emplace_back(std::move(*solObjectToMemRefValue(client, r)));
495+
// } else {
496+
// // Assume it's a ScalarValue for all other cases
497+
// values.emplace_back(solObjectToScalarValue(r));
498+
// }
499+
}
500+
return values;
501+
}
502+
426503
static Status validateArgsTypesAgainstFuncArgs(const RuntimeValue *runArg,
427504
const TypeUnionView &sigArg) {
428505
if (sigArg.isa<MemrefTypeView>()) {
@@ -520,11 +597,11 @@ runtime::executeFunctionWithLuaBackend(
520597
return getStatusWithMsg(StatusCode::InternalError, "no function named \"",
521598
std::string(name), "\" found");
522599

523-
if (sig.getNumResults() > 0)
524-
return getInvalidArgStatus("functions with {0} results are not supported",
525-
sig.getNumResults());
526-
527600
// Validate the number of arguments against the signature.
601+
if (sig.getNumResults() != 0)
602+
return getInvalidArgStatus(
603+
"function expects 0 result args but received {0}",
604+
sig.getNumResults());
528605
if (sig.getNumOutputArgs() != outputArgs.size())
529606
return getInvalidArgStatus(
530607
"function expects {0} output args (destination args) but received {1}",
@@ -600,3 +677,86 @@ runtime::executeFunctionWithLuaBackend(
600677

601678
return llvm::SmallVector<std::unique_ptr<RuntimeValue>>{};
602679
}
680+
681+
StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>>
682+
runtime::executeFunctionWithResultWithLuaBackend(
683+
LuaRuntimeSession &session,
684+
RuntimeClient &client,
685+
std::string_view name,
686+
llvm::ArrayRef<RuntimeValue *> inputArgs,
687+
std::optional<CudaStream> stream) {
688+
689+
FunctionView meta = session.getExecutable().getFunction(name);
690+
FunctionSignatureView sig = meta.getSignature();
691+
692+
// Call the main function, if present.
693+
sol::state &lua = session.getLuaState();
694+
AllocTracker &tracker = session.getAllocTracker();
695+
sol::protected_function funcObj = lua[name];
696+
if (funcObj.get_type() != sol::type::function)
697+
return getStatusWithMsg(StatusCode::InternalError, "no function named \"",
698+
std::string(name), "\" found");
699+
700+
// Validate the number of arguments against the signature.
701+
if (sig.getNumOutputArgs() != 0)
702+
return getInvalidArgStatus(
703+
"function expects 0 output args (destination args) but received {0}",
704+
sig.getNumOutputArgs());
705+
if (sig.getNumInputArgs() != inputArgs.size())
706+
return getInvalidArgStatus("function expects {0} input args "
707+
"(non-destination args) but received {1}",
708+
sig.getNumInputArgs(), inputArgs.size());
709+
710+
// Validate the inferred Lua function type here against the signature.
711+
for (unsigned i = 0; i < inputArgs.size(); ++i) {
712+
auto status = validateArgsTypesAgainstFuncArgs(inputArgs[i], sig.getArg(i));
713+
if (!status.isOk())
714+
return getInvalidArgStatus(
715+
"Input argument {0} validation failed against "
716+
"corresponding function signature arg {0}. Reason: {1}",
717+
i, status.getString());
718+
}
719+
720+
// Create the arguments.
721+
llvm::SmallVector<sol::object> args;
722+
args.reserve(inputArgs.size());
723+
for (auto [idx, rv] : llvm::enumerate(inputArgs)) {
724+
if (MemRefValue *memref = llvm::dyn_cast<MemRefValue>(rv)) {
725+
MTRT_RETURN_IF_ERROR(pushMemRefTableArg(lua, tracker, args, *memref));
726+
continue;
727+
}
728+
if (ScalarValue *scalar = llvm::dyn_cast<ScalarValue>(rv)) {
729+
MTRT_RETURN_IF_ERROR(pushScalarArgument(lua, args, *scalar));
730+
continue;
731+
}
732+
return getInvalidArgStatus(
733+
"input argument #{0} to function {1} has an unsupported type; "
734+
"arguments must be either MemRefs or scalars",
735+
idx + 1, name);
736+
}
737+
if (stream)
738+
RETURN_STATUS_IF_ERROR(session.setCudaStream(*stream));
739+
740+
// If the number of arguments exceed a particular threshold, then
741+
// we pass arguments packed into a table, otherwise we pass as arguments.
742+
sol::protected_function_result result =
743+
sig.getCConv() == CallingConvention::unpacked
744+
? funcObj(sol::as_args(args))
745+
: funcObj(args);
746+
747+
if (!result.valid()) {
748+
sol::error err(result);
749+
return getStatusWithMsg(StatusCode::InternalError,
750+
"failed to run function \"", std::string(name),
751+
"\": ", err.what());
752+
}
753+
754+
int returnCount = result.return_count();
755+
std::vector<sol::object> results;
756+
// Lua indices start from 1
757+
for (int i = 1; i <= returnCount; ++i) {
758+
results.push_back(result[i]);
759+
}
760+
761+
return solObjectToRuntimeValues(&client, results);
762+
}

0 commit comments

Comments
 (0)