Skip to content

Commit dec20fe

Browse files
committed
Add IGpuAllocator to MLIR-TensorRT
1 parent b6836bb commit dec20fe

File tree

13 files changed

+351
-38
lines changed

13 files changed

+351
-38
lines changed

mlir-tensorrt/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ We currently support only building on Linux x86 systems.
2323
We support building several different ways (only via CMake) depending on use-case.
2424

2525
In each case, the LLVM-Project version that we are currently aligned to is
26-
given in `build_tools/cmake/LLVMCommit.txt`.
26+
given in `build_tools/cmake/LLVMCommit.cmake`.
2727

2828
Note that currently we provide an LLVM patch which essentially cherry-picks the
2929
bug fixes from [this open MLIR PR](https://github.com/llvm/llvm-project/pull/91524).
@@ -82,7 +82,7 @@ git clone https://github.com/llvm/llvm-project.git llvm-project
8282
# Checkout the right commit. Of course, you may try
8383
# a newer commit or your own modified LLVM-Project.
8484
cd llvm-project
85-
git checkout $(cat build_tools/cmake/LLVMCommit.cmake | grep -Po '(?<=").*(?=")')
85+
git checkout $(cat ../build_tools/cmake/LLVMCommit.cmake | grep -Po '(?<=").*(?=")')
8686

8787
# Apply patch from llvm-project PR 91524
8888
git apply ../build_tools/llvm-project.patch

mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,30 @@ mtrtScalarValueCastToRuntimeValue(MTRT_ScalarValue v);
312312
MLIR_CAPI_EXPORTED MTRT_Status
313313
mtrtScalarValueGetType(MTRT_ScalarValue scalar, MTRT_ScalarTypeCode *code);
314314

315+
//===----------------------------------------------------------------------===//
316+
// MTRT_GpuAllocator
317+
//===----------------------------------------------------------------------===//
318+
319+
typedef struct MTRT_GpuAllocator {
320+
void *ptr;
321+
} MTRT_GpuAllocator;
322+
323+
MTRT_CAPI_EXPORTED bool mtrtGpuAllocatorIsNull(MTRT_GpuAllocator gpuAllocator);
324+
325+
MTRT_CAPI_EXPORTED MTRT_Status
326+
mtrtGpuAllocatorDestroy(MTRT_GpuAllocator executable);
327+
328+
329+
//===----------------------------------------------------------------------===//
330+
// MTRT_GpuAllocator
331+
//===----------------------------------------------------------------------===//
332+
333+
MTRT_CAPI_EXPORTED MTRT_Status mtrtGpuAllocatorAllocate(MTRT_GpuAllocator gpuAllocator,
334+
uint64_t size, uint64_t alignment,
335+
void **memory);
336+
337+
MTRT_CAPI_EXPORTED MTRT_Status mtrtGpuAllocatorDeallocate(MTRT_GpuAllocator gpuAllocator, void *memory, bool *success);
338+
315339
//===----------------------------------------------------------------------===//
316340
// MTRT_RuntimeSessionOptions
317341
//===----------------------------------------------------------------------===//
@@ -352,7 +376,7 @@ typedef struct MTRT_RuntimeSession {
352376
/// that the session only has a read-only view in to the Executable for code and
353377
/// constant data. Therefore the Executable must outlive the RuntimeSession.
354378
MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionCreate(
355-
MTRT_RuntimeSessionOptions options, MTRT_Executable executable,
379+
MTRT_RuntimeSessionOptions options, MTRT_Executable executable, MTRT_GpuAllocator allocator,
356380
MTRT_RuntimeSession *result);
357381

358382
/// Destory the session. This does not destroy the associated Executable, which

mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,8 @@ class RuntimeSession {
840840
sol::state state,
841841
std::unique_ptr<PinnedMemoryAllocator> pinnedMemoryAllocator,
842842
std::unique_ptr<AllocTracker> allocTracker,
843-
std::unique_ptr<ResourceTracker> resourceTracker);
843+
std::unique_ptr<ResourceTracker> resourceTracker,
844+
GpuAllocator* gpuAllocator);
844845

845846
ExecutableView getExecutable() const { return executable; }
846847

@@ -854,14 +855,16 @@ class RuntimeSession {
854855

855856
ResourceTracker &getResourceTracker() { return *resourceTracker; }
856857

858+
GpuAllocator* getGpuAllocator() { return gpuAllocator; }
859+
857860
private:
858861
RuntimeSessionOptions options;
859862
ExecutableView executable;
860863

861864
std::unique_ptr<PinnedMemoryAllocator> pinnedMemoryAllocator;
862865
std::unique_ptr<AllocTracker> allocTracker;
863866
std::unique_ptr<ResourceTracker> resourceTracker;
864-
867+
GpuAllocator* gpuAllocator;
865868
sol::state state;
866869
};
867870

mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRegistration.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,6 @@ void registerLuaRuntimeMethods(lua_State *state,
3737
const RuntimeSessionOptions &options,
3838
PinnedMemoryAllocator *pinnedMemoryAllocator,
3939
AllocTracker *allocTracker,
40-
ResourceTracker *resourceTracker);
40+
ResourceTracker *resourceTracker, GpuAllocator* allocator);
4141

4242
} // namespace mlirtrt::runtime

mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ namespace mlirtrt::runtime {
3636
/// `main` function. It is assumed that `main` takes no arguments and returns an
3737
/// integer result (which is returned if the execution is successful).
3838
/// TODO: this should take a handle to a function for streaming output/errors.
39-
StatusOr<int64_t> runExecutorLuaScript(std::string_view luaScript);
39+
StatusOr<int64_t> runExecutorLuaScript(std::string_view luaScript, GpuAllocator* allocator);
4040

4141
/// Synchronously run a serialized executor Executable one time. An `Executable`
4242
/// is essentially a Lua script packaged with metadata and serialized constants
@@ -48,12 +48,12 @@ StatusOr<int64_t> runExecutorLuaScript(std::string_view luaScript);
4848
/// execution is successful).
4949
/// TODO: this should take a handle to a function for
5050
/// streaming output/errors.
51-
StatusOr<int64_t> runExecutorExecutable(std::unique_ptr<Executable> executable);
51+
StatusOr<int64_t> runExecutorExecutable(std::unique_ptr<Executable> executable, GpuAllocator* allocator);
5252

5353
/// Create an execution state. This will setup a Lua environment and invoke
5454
/// global initialization.
5555
StatusOr<std::unique_ptr<RuntimeSession>>
56-
createRuntimeSessionWithLuaBackend(ExecutableView executable,
56+
createRuntimeSessionWithLuaBackend(ExecutableView executable, GpuAllocator* allocator,
5757
const RuntimeSessionOptions &options);
5858

5959
/// Set the primary stream for the loaded executable to use.

mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class ResourceTracker;
3737
/// Lua state.
3838
void registerExecutorTensorRTModuleLuaRuntimeMethods(
3939
lua_State *luaState, PinnedMemoryAllocator *pinnedMemoryAllocator,
40-
AllocTracker *allocTracker, ResourceTracker *resourceTracker);
40+
AllocTracker *allocTracker, ResourceTracker *resourceTracker, GpuAllocator* allocator);
4141

4242
} // namespace mlirtrt::runtime
4343

mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,15 @@ namespace mlirtrt {
3232

3333
struct EventPool;
3434

35+
class GpuAllocator {
36+
public:
37+
virtual ~GpuAllocator() = default; // Add a virtual destructor
38+
/// Allocate gpu memory. Needs to be implemented a client.
39+
virtual void* allocate(uint64_t size, uint64_t alignment);
40+
/// Returns true if deallocation succeeds
41+
virtual bool deallocate(void* memory);
42+
};
43+
3544
//===----------------------------------------------------------------------===//
3645
// PoolTrackedCudaEvent
3746
//===----------------------------------------------------------------------===//

mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir-executor/Runtime/API/API.h"
2828
#include "mlir-executor/Runtime/API/ExecutableFlatbuffer.h"
2929
#include "mlir-executor/Runtime/Backend/Lua/LuaRuntime.h"
30+
#include "mlir-executor/Support/Allocators.h"
3031
#include "mlir-executor/Support/Status.h"
3132
#include "mlir/Support/FileUtilities.h"
3233
#include "llvm/ADT/SmallVectorExtras.h"
@@ -48,6 +49,8 @@ DEFINE_C_API_PTR_METHODS(MTRT_RuntimeSession,
4849
::mlirtrt::runtime::RuntimeSession)
4950
DEFINE_C_API_PTR_METHODS(MTRT_RuntimeSessionOptions,
5051
::mlirtrt::runtime::RuntimeSessionOptions)
52+
DEFINE_C_API_PTR_METHODS(MTRT_GpuAllocator,
53+
::mlirtrt::GpuAllocator)
5154
DEFINE_C_API_PTR_METHODS(MTRT_Executable, ::mlirtrt::runtime::Executable)
5255
DEFINE_C_API_PTR_METHODS(MTRT_Stream, MTRT_StreamImpl)
5356
DEFINE_C_API_PTR_METHODS(MTRT_RuntimeValue, ::mlirtrt::runtime::RuntimeValue)
@@ -529,6 +532,24 @@ MTRT_ScalarValue mtrtRuntimeValueDynCastToScalar(MTRT_RuntimeValue v) {
529532
return wrap(static_cast<ScalarValue *>(x));
530533
}
531534

535+
//===----------------------------------------------------------------------===//
536+
// MTRT_GpuAllocator
537+
//===----------------------------------------------------------------------===//
538+
539+
MTRT_Status mtrtGpuAllocatorAllocate(MTRT_GpuAllocator gpuAllocator,
540+
uint64_t size, uint64_t alignment,
541+
void **memory) {
542+
GpuAllocator *cppGpuAllocator = unwrap(gpuAllocator);
543+
*memory = cppGpuAllocator->allocate(size, alignment);
544+
return mtrtStatusGetOk();
545+
}
546+
547+
MTRT_Status mtrtGpuAllocatorDeallocate(MTRT_GpuAllocator gpuAllocator, void *memory, bool *success) {
548+
GpuAllocator *cppGpuAllocator = unwrap(gpuAllocator);
549+
*success = cppGpuAllocator->deallocate(memory);
550+
return mtrtStatusGetOk();
551+
}
552+
532553
//===----------------------------------------------------------------------===//
533554
// MTRT_RuntimeSessionOptions
534555
//===----------------------------------------------------------------------===//
@@ -556,12 +577,14 @@ mtrtRuntimeSessionOptionsDestroy(MTRT_RuntimeSessionOptions options) {
556577

557578
MTRT_Status mtrtRuntimeSessionCreate(MTRT_RuntimeSessionOptions options,
558579
MTRT_Executable executable,
580+
MTRT_GpuAllocator gpuAllocator,
559581
MTRT_RuntimeSession *result) {
560582
RuntimeSessionOptions *cppOptions = unwrap(options);
561583
Executable *cppExecutable = unwrap(executable);
584+
GpuAllocator *cppGpuAllocator = unwrap(gpuAllocator);
562585

563586
StatusOr<std::unique_ptr<RuntimeSession>> session =
564-
createRuntimeSessionWithLuaBackend(cppExecutable->getView(), *cppOptions);
587+
createRuntimeSessionWithLuaBackend(cppExecutable->getView(), cppGpuAllocator, *cppOptions);
565588
if (session.isError())
566589
return wrap(session.getStatus());
567590

mlir-tensorrt/executor/lib/Runtime/API/API.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -348,16 +348,17 @@ RuntimeSessionOptions::createUsingSingleHostMpi() {
348348
//===----------------------------------------------------------------------===//
349349
// RuntimeSession
350350
//===----------------------------------------------------------------------===//
351-
352351
RuntimeSession::RuntimeSession(
353352
RuntimeSessionOptions options, ExecutableView exe, sol::state state,
354353
std::unique_ptr<PinnedMemoryAllocator> pinnedMemoryAllocator,
355354
std::unique_ptr<AllocTracker> allocTracker,
356-
std::unique_ptr<ResourceTracker> resourceTracker)
355+
std::unique_ptr<ResourceTracker> resourceTracker,
356+
GpuAllocator *gpuAllocator)
357357
: options(std::move(options)), executable(exe),
358358
pinnedMemoryAllocator(std::move(pinnedMemoryAllocator)),
359359
allocTracker(std::move(allocTracker)),
360-
resourceTracker(std::move(resourceTracker)), state(std::move(state)) {}
360+
resourceTracker(std::move(resourceTracker)), gpuAllocator(gpuAllocator),
361+
state(std::move(state)) {}
361362

362363
//===----------------------------------------------------------------------===//
363364
// AllocTracker

mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ static void registerDefaultDeviceDependentMethods(lua_State *state,
7272

7373
static void registerLuaRuntimeMethodsCommon(
7474
lua_State *state, PinnedMemoryAllocator *pinnedMemoryAllocator,
75-
AllocTracker *allocTracker, ResourceTracker *resourceTracker) {
75+
AllocTracker *allocTracker, ResourceTracker *resourceTracker, GpuAllocator* allocator) {
7676
registerExecutorCoreModuleLuaRuntimeMethods(state, pinnedMemoryAllocator,
7777
allocTracker);
7878
registerExecutorCUDAModuleLuaRuntimeMethods(
@@ -84,15 +84,15 @@ static void registerLuaRuntimeMethodsCommon(
8484
#endif
8585

8686
registerExecutorTensorRTModuleLuaRuntimeMethods(
87-
state, pinnedMemoryAllocator, allocTracker, resourceTracker);
87+
state, pinnedMemoryAllocator, allocTracker, resourceTracker, allocator);
8888
}
8989

9090
void mlirtrt::runtime::registerLuaRuntimeMethods(
9191
lua_State *state, const RuntimeSessionOptions &options,
9292
PinnedMemoryAllocator *pinnedMemoryAllocator, AllocTracker *allocTracker,
93-
ResourceTracker *resourceTracker) {
93+
ResourceTracker *resourceTracker, GpuAllocator* allocator) {
9494
registerLuaRuntimeMethodsCommon(state, pinnedMemoryAllocator, allocTracker,
95-
resourceTracker);
95+
resourceTracker, allocator);
9696
#ifdef MLIR_EXECUTOR_ENABLE_NCCL
9797
registerExecutorNCCLModuleLuaRuntimeMethods(state, resourceTracker);
9898
registerDeviceDependentNCCLMethods(state, options.getNumDevices(),
@@ -108,7 +108,7 @@ void mlirtrt::runtime::registerLuaRuntimeMethods(
108108
}
109109

110110
StatusOr<int64_t>
111-
mlirtrt::runtime::runExecutorLuaScript(std::string_view luaScript) {
111+
mlirtrt::runtime::runExecutorLuaScript(std::string_view luaScript, GpuAllocator* allocator) {
112112
ADD_RUNTIME_MODULE_RANGE("runtime_runExecutorLuaScript");
113113

114114
StatusOr<std::unique_ptr<RuntimeClient>> client = RuntimeClient::create();
@@ -120,7 +120,7 @@ mlirtrt::runtime::runExecutorLuaScript(std::string_view luaScript) {
120120
registerLuaRuntimeMethods(lua.lua_state(), RuntimeSessionOptions(),
121121
&(*client)->getPinnedMemorAllocator(),
122122
&(*client)->getAllocTracker(),
123-
&(*client)->getResourceTracker());
123+
&(*client)->getResourceTracker(), allocator);
124124

125125
sol::protected_function_result result = lua.script(luaScript);
126126
if (!result.valid()) {
@@ -171,7 +171,7 @@ static Status maybeCheckForValidNcclUuid(const RuntimeSessionOptions &options) {
171171
/// global initialization.
172172
StatusOr<std::unique_ptr<RuntimeSession>>
173173
mlirtrt::runtime::createRuntimeSessionWithLuaBackend(
174-
ExecutableView executable, const RuntimeSessionOptions &options) {
174+
ExecutableView executable, GpuAllocator* allocator, const RuntimeSessionOptions &options) {
175175
ADD_RUNTIME_MODULE_RANGE("runtime_loadExecutable");
176176

177177
MTRT_RETURN_IF_ERROR(maybeCheckForValidNcclUuid(options));
@@ -184,7 +184,7 @@ mlirtrt::runtime::createRuntimeSessionWithLuaBackend(
184184
lua.open_libraries(sol::lib::base, sol::lib::string);
185185
registerLuaRuntimeMethods(lua.lua_state(), options,
186186
pinnedMemoryAllocator.get(), allocTracker.get(),
187-
resourceTracker.get());
187+
resourceTracker.get(), allocator);
188188

189189
// Load globals into the context.
190190
// TODO: eliminate this copy, we already own the executable.
@@ -225,11 +225,11 @@ mlirtrt::runtime::createRuntimeSessionWithLuaBackend(
225225
}
226226
return std::make_unique<RuntimeSession>(
227227
options, executable, std::move(lua), std::move(pinnedMemoryAllocator),
228-
std::move(allocTracker), std::move(resourceTracker));
228+
std::move(allocTracker), std::move(resourceTracker), allocator);
229229
}
230230

231231
StatusOr<int64_t> mlirtrt::runtime::runExecutorExecutable(
232-
std::unique_ptr<Executable> executable) {
232+
std::unique_ptr<Executable> executable, GpuAllocator* allocator) {
233233

234234
StatusOr<std::unique_ptr<RuntimeClient>> client = RuntimeClient::create();
235235
if (!client.isOk())
@@ -245,7 +245,7 @@ StatusOr<int64_t> mlirtrt::runtime::runExecutorExecutable(
245245
return options.getStatus();
246246

247247
StatusOr<std::unique_ptr<RuntimeSession>> session =
248-
createRuntimeSessionWithLuaBackend(executable->getView(), *options);
248+
createRuntimeSessionWithLuaBackend(executable->getView(), allocator, *options);
249249
if (!session.isOk())
250250
return session.getStatus();
251251

0 commit comments

Comments
 (0)