Skip to content

Commit

Permalink
Add initial IR for alloc enqueue
Browse files Browse the repository at this point in the history
  • Loading branch information
jhalakpatel committed Oct 8, 2024
1 parent 1f10b30 commit 9760adc
Show file tree
Hide file tree
Showing 10 changed files with 521 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def ConvertCUDAToExecutorPass : Pass<"convert-cuda-to-executor",
//===----------------------------------------------------------------------===//
// ConvertTensorRTRuntimeToExecutorPass
//===----------------------------------------------------------------------===//
// TODO: Modify this pass to generate non-DPS style enqueue functions.
def ConvertTensorRTRuntimeToExecutorPass : Pass<"convert-tensorrt-runtime-to-executor",
"::mlir::ModuleOp"> {
let summary = "Converts TensorRTRuntime dialect ops to executor dialect operations";
Expand Down
10 changes: 6 additions & 4 deletions mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,11 +487,13 @@ StableHloToExecutableTask::compileStableHLOToExecutable(
runner = pm.get();
}

runner->printAsTextualPipeline(llvm::dbgs());

// Setup pass manager
if (failed(runner->run(module)))
return getInternalErrorStatus(
"failed to run compilation on module with symbol name: {0}",
module.getName() ? *module.getName() : "no-symbol-name");
// if (failed(runner->run(module)))
// return getInternalErrorStatus(
// "failed to run compilation on module with symbol name: {0}",
// module.getName() ? *module.getName() : "no-symbol-name");

// Translate to Runtime Executable
FailureOr<std::unique_ptr<runtime::ExecutableStorage>> exeStorage =
Expand Down
16 changes: 16 additions & 0 deletions mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,12 @@ static inline bool mtrtRuntimeValueIsNull(MTRT_RuntimeValue value) {
return !value.ptr;
}

// Returns whether the RuntimeValue is a MemRef.
MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value);

// Returns whether the RuntimeValue is a Scalar.
MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value);

/// Cast a MTRT_MemRefValue to a generic MTRT_RuntimeValue.
MLIR_CAPI_EXPORTED MTRT_RuntimeValue
mtrtMemRefCastToRuntimeValue(MTRT_MemRefValue memref);
Expand Down Expand Up @@ -383,6 +389,16 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionExecuteFunction(
const MTRT_RuntimeValue *inArgs, size_t numInArgs,
const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream);

/// Variant of the above function which returns the function's results instead
/// of writing into caller-provided destination arguments.
MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionExecuteFunctionWithResult(
MTRT_RuntimeSession session, MTRT_RuntimeClient client,
MTRT_StringView name, const MTRT_RuntimeValue *inArgs, size_t numInArgs,
MTRT_RuntimeValue *resultArgs, size_t numResultArgs, MTRT_Stream stream);

MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionGetNbResults(MTRT_RuntimeSession session,
MTRT_StringView name,
int64_t *numResults);

//===----------------------------------------------------------------------===//
// DLPack
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ executeFunctionWithLuaBackend(LuaRuntimeSession &session, std::string_view name,
llvm::ArrayRef<RuntimeValue *> outputArgs,
std::optional<CudaStream> stream = {});

/// Execute a named function in the session with the specified input args,
/// returning the function's results as newly allocated RuntimeValues.
StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>>
executeFunctionWithResultWithLuaBackend(
LuaRuntimeSession &session, RuntimeClient &client, std::string_view name,
llvm::ArrayRef<RuntimeValue *> inputArgs,
std::optional<CudaStream> stream = {});

} // namespace mlirtrt::runtime

#endif // MLIR_TENSORRT_RUNTIME_BACKEND_LUA_LUARUNTIME_H
59 changes: 57 additions & 2 deletions mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,16 @@ MTRT_ScalarValue mtrtRuntimeValueDynCastToScalar(MTRT_RuntimeValue v) {
return wrap(static_cast<ScalarValue *>(x));
}

/// Return true if the generic RuntimeValue wraps a MemRef value.
bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value) {
  return unwrap(value)->getKind() == RuntimeValue::Kind::MemRef;
}

/// Return true if the generic RuntimeValue wraps a Scalar value.
bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value) {
  return unwrap(value)->getKind() == RuntimeValue::Kind::Scalar;
}

//===----------------------------------------------------------------------===//
// MTRT_RuntimeSessionOptions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -697,19 +707,64 @@ MTRT_Status mtrtRuntimeSessionExecuteFunction(
llvm::SmallVector<RuntimeValue *> outArgValues =
llvm::map_to_vector(llvm::ArrayRef(outArgs, numOutArgs),
[](MTRT_RuntimeValue arg) { return unwrap(arg); });

StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>> result =
executeFunctionWithLuaBackend(
*cppSession, std::string_view(name.data, name.length), inArgValues,
outArgValues,
!mtrtStreamIsNull(stream)
? std::optional(unwrap(stream)->getRawStream())
: std::nullopt);
if (!result.isOk())
if (!result.isOk()) {
return wrap(result.getStatus());
}
return mtrtStatusGetOk();
}

/// Execute the named function in non-DPS style: the function allocates its own
/// result buffers, and ownership of each result is transferred to the caller
/// via `resultArgs` (which must have room for `numResultArgs` entries).
MTRT_Status mtrtRuntimeSessionExecuteFunctionWithResult(
    MTRT_RuntimeSession session, MTRT_RuntimeClient client,
    MTRT_StringView name, const MTRT_RuntimeValue *inArgs, size_t numInArgs,
    MTRT_RuntimeValue *resultArgs, size_t numResultArgs, MTRT_Stream stream) {
  LuaRuntimeSession *cppSession =
      static_cast<LuaRuntimeSession *>(unwrap(session));
  RuntimeClient *cppClient = unwrap(client);

  llvm::SmallVector<RuntimeValue *> inArgValues =
      llvm::map_to_vector(llvm::ArrayRef(inArgs, numInArgs),
                          [](MTRT_RuntimeValue arg) { return unwrap(arg); });

  StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>> results =
      executeFunctionWithResultWithLuaBackend(
          *cppSession, *cppClient, std::string_view(name.data, name.length),
          inArgValues,
          !mtrtStreamIsNull(stream)
              ? std::optional(unwrap(stream)->getRawStream())
              : std::nullopt);
  if (!results.isOk())
    return wrap(results.getStatus());

  // A count mismatch would cause an out-of-bounds write into `resultArgs`
  // below. Report it as a real error instead of relying on an `assert`, which
  // compiles away in release builds.
  if (results->size() != numResultArgs)
    return wrap(getInvalidArgStatus(
        "function produced {0} results but caller provided storage for {1}",
        results->size(), numResultArgs));

  // Transfer ownership of each result to the caller.
  for (size_t i = 0; i < numResultArgs; ++i)
    resultArgs[i] = wrap((*results)[i].release());

  return mtrtStatusGetOk();
}

/// Look up the named function in the session's executable and report the
/// number of results declared by its signature via `numResults`.
MTRT_Status mtrtRuntimeSessionGetNbResults(MTRT_RuntimeSession session,
                                           MTRT_StringView name,
                                           int64_t *numResults) {
  LuaRuntimeSession *cppSession =
      static_cast<LuaRuntimeSession *>(unwrap(session));
  const auto signature =
      cppSession->getExecutable()
          .getFunction(std::string_view(name.data, name.length))
          .getSignature();
  *numResults = signature.getNumResults();
  return mtrtStatusGetOk();
}

//===----------------------------------------------------------------------===//
// MTRT_RuntimeClient
//===----------------------------------------------------------------------===//
Expand Down
168 changes: 164 additions & 4 deletions mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,83 @@ static Status pushScalarArgument(sol::state_view &lua,
return getOkStatus();
}

// Extract the shape and stride arrays from a Lua memref descriptor table.
//
// The table carries 3 header entries followed by `rank` shape entries and
// `rank` stride entries, i.e. `3 + 2 * rank` entries total (see
// `solObjectToMemRefValue` for which header slots are read). Lua tables are
// 1-indexed, hence the loops below start at index 4.
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
extractShapeAndStride(const sol::table &table) {
  size_t tableSize = table.size();
  assert(tableSize >= 3 &&
         "Table does not contain shape and stride information");
  // Guard against malformed tables: an odd remainder would silently truncate
  // the shape/stride split below.
  assert((tableSize - 3) % 2 == 0 &&
         "Expected an equal number of shape and stride entries");
  size_t rank = (tableSize - 3) / 2;

  std::vector<int64_t> shape;
  std::vector<int64_t> stride;
  shape.reserve(rank);
  stride.reserve(rank);

  // Shape entries occupy indices [4, 3 + rank].
  for (size_t i = 4; i <= 3 + rank; ++i)
    shape.push_back(table[i].get<int64_t>());

  // Stride entries occupy indices [4 + rank, tableSize].
  for (size_t i = 4 + rank; i <= tableSize; ++i)
    stride.push_back(table[i].get<int64_t>());

  // Move the vectors into the tuple to avoid copying both arrays.
  return std::make_tuple(std::move(shape), std::move(stride));
}

// Convert a sol::object returned from Lua into an owned MemRefValue.
//
// The object is expected to be a memref descriptor table: index 1 holds the
// base pointer and index 3 the offset, with shape/stride entries following
// (see `extractShapeAndStride`). NOTE(review): index 2 is not read here —
// presumably the second half of an allocated/aligned pointer pair; confirm
// against the marshalling code in `pushMemRefTableArg`.
StatusOr<std::unique_ptr<MemRefValue>>
solObjectToMemRefValue(RuntimeClient *client, const sol::object &obj) {
  assert(obj.is<sol::table>() && "Expected a table for MemRefValue");

  // Lua tables are 1-indexed.
  sol::table memrefTable = obj.as<sol::table>();
  uintptr_t ptr = memrefTable[1].get<uintptr_t>();
  int64_t offset = memrefTable[3].get<int64_t>();

  auto [shape, strides] = extractShapeAndStride(memrefTable);

  // TODO: How to extract this information. Should we use function signature to
  // fill in this information later? The placeholders below assume a device
  // pointer to 32-bit f32 elements on an unspecified device; none of this is
  // recoverable from the Lua value alone.
  mlirtrt::runtime::PointerType addressSpace =
      mlirtrt::runtime::PointerType::device;
  int64_t bitsPerElement = 32;
  std::optional<const Device *> device =
      std::nullopt;
  std::optional<ScalarType> scalarType = ScalarTypeCode::f32;

  return MemRefValue::create(client, addressSpace, bitsPerElement, ptr, offset,
                             llvm::ArrayRef<int64_t>(shape),
                             llvm::ArrayRef<int64_t>(strides), device,
                             scalarType);
}

// Convert a sol::object holding a scalar result into an owned ScalarValue.
//
// TODO: the actual scalar type is not known from the Lua value alone; it
// could be filled in later from the function signature. ScalarValue stores
// its payload as int64_t, so the object is read as int64_t for now.
std::unique_ptr<ScalarValue> solObjectToScalarValue(const sol::object &obj) {
  int64_t rawValue = obj.as<int64_t>();
  return std::make_unique<ScalarValue>(rawValue, ScalarTypeCode::unknown);
}

// Convert the Lua return values of an executed function into owned
// RuntimeValues. Table results are converted to MemRefValues; everything else
// is treated as a scalar.
llvm::SmallVector<std::unique_ptr<RuntimeValue>>
solObjectToRuntimeValues(RuntimeClient *client,
                         std::vector<sol::object> const &results) {
  llvm::SmallVector<std::unique_ptr<RuntimeValue>> values;
  values.reserve(results.size());
  for (const sol::object &r : results) {
    if (r.is<sol::table>()) {
      // Memref results are marshalled as descriptor tables. The previous code
      // converted *every* result as a memref, which trips the table assert in
      // `solObjectToMemRefValue` for scalar returns.
      StatusOr<std::unique_ptr<MemRefValue>> memref =
          solObjectToMemRefValue(client, r);
      assert(memref.isOk() && "failed to convert Lua table to MemRefValue");
      values.emplace_back(std::move(*memref));
      continue;
    }
    // Assume it's a ScalarValue for all other cases.
    values.emplace_back(solObjectToScalarValue(r));
  }
  return values;
}

static Status validateArgsTypesAgainstFuncArgs(const RuntimeValue *runArg,
const TypeUnionView &sigArg) {
if (sigArg.isa<MemrefTypeView>()) {
Expand Down Expand Up @@ -520,11 +597,11 @@ runtime::executeFunctionWithLuaBackend(
return getStatusWithMsg(StatusCode::InternalError, "no function named \"",
std::string(name), "\" found");

if (sig.getNumResults() > 0)
return getInvalidArgStatus("functions with {0} results are not supported",
sig.getNumResults());

// Validate the number of arguments against the signature.
if (sig.getNumResults() != 0)
return getInvalidArgStatus(
"function expects 0 result args but received {0}",
sig.getNumResults());
if (sig.getNumOutputArgs() != outputArgs.size())
return getInvalidArgStatus(
"function expects {0} output args (destination args) but received {1}",
Expand Down Expand Up @@ -600,3 +677,86 @@ runtime::executeFunctionWithLuaBackend(

return llvm::SmallVector<std::unique_ptr<RuntimeValue>>{};
}

/// Execute a named function in `session` in non-DPS style: the callee
/// allocates its own result buffers, which are returned to the caller as new
/// RuntimeValues associated with `client`. Unlike
/// `executeFunctionWithLuaBackend`, no destination (output) arguments are
/// accepted.
StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>>
runtime::executeFunctionWithResultWithLuaBackend(
    LuaRuntimeSession &session, RuntimeClient &client, std::string_view name,
    llvm::ArrayRef<RuntimeValue *> inputArgs,
    std::optional<CudaStream> stream) {

  FunctionView meta = session.getExecutable().getFunction(name);
  FunctionSignatureView sig = meta.getSignature();

  // Look up the Lua function to invoke.
  sol::state &lua = session.getLuaState();
  AllocTracker &tracker = session.getAllocTracker();
  sol::protected_function funcObj = lua[name];
  if (funcObj.get_type() != sol::type::function)
    return getStatusWithMsg(StatusCode::InternalError, "no function named \"",
                            std::string(name), "\" found");

  // Validate the number of arguments against the signature.
  if (sig.getNumOutputArgs() != 0)
    return getInvalidArgStatus(
        "function expects 0 output args (destination args) but received {0}",
        sig.getNumOutputArgs());
  if (sig.getNumInputArgs() != inputArgs.size())
    return getInvalidArgStatus("function expects {0} input args "
                               "(non-destination args) but received {1}",
                               sig.getNumInputArgs(), inputArgs.size());

  // Validate the inferred Lua function type here against the signature.
  for (unsigned i = 0; i < inputArgs.size(); ++i) {
    auto status = validateArgsTypesAgainstFuncArgs(inputArgs[i], sig.getArg(i));
    if (!status.isOk())
      return getInvalidArgStatus(
          "Input argument {0} validation failed against "
          "corresponding function signature arg {0}. Reason: {1}",
          i, status.getString());
  }

  // Marshal the arguments: memrefs become descriptor tables, scalars are
  // pushed directly; anything else is rejected.
  llvm::SmallVector<sol::object> args;
  args.reserve(inputArgs.size());
  for (auto [idx, rv] : llvm::enumerate(inputArgs)) {
    if (MemRefValue *memref = llvm::dyn_cast<MemRefValue>(rv)) {
      MTRT_RETURN_IF_ERROR(pushMemRefTableArg(lua, tracker, args, *memref));
      continue;
    }
    if (ScalarValue *scalar = llvm::dyn_cast<ScalarValue>(rv)) {
      MTRT_RETURN_IF_ERROR(pushScalarArgument(lua, args, *scalar));
      continue;
    }
    return getInvalidArgStatus(
        "input argument #{0} to function {1} has an unsupported type; "
        "arguments must be either MemRefs or scalars",
        idx + 1, name);
  }
  if (stream)
    RETURN_STATUS_IF_ERROR(session.setCudaStream(*stream));

  // If the number of arguments exceed a particular threshold, then
  // we pass arguments packed into a table, otherwise we pass as arguments.
  sol::protected_function_result result =
      sig.getCConv() == CallingConvention::unpacked
          ? funcObj(sol::as_args(args))
          : funcObj(args);

  if (!result.valid()) {
    sol::error err(result);
    return getStatusWithMsg(StatusCode::InternalError,
                            "failed to run function \"", std::string(name),
                            "\": ", err.what());
  }

  // Collect the Lua return values. sol2's function_result indexes with
  // 0-based offsets from the first return value (unlike Lua tables, which are
  // 1-based): iterating 1..return_count() would skip the first result and
  // read one slot past the last.
  int returnCount = result.return_count();
  std::vector<sol::object> results;
  results.reserve(returnCount);
  for (int i = 0; i < returnCount; ++i)
    results.push_back(result[i]);

  return solObjectToRuntimeValues(&client, results);
}
Loading

0 comments on commit 9760adc

Please sign in to comment.