Update allocate/deallocate interface
jhalakpatel committed Aug 22, 2024
1 parent 87ca776 commit 539fb56
Showing 7 changed files with 83 additions and 75 deletions.
23 changes: 9 additions & 14 deletions mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h
@@ -32,6 +32,8 @@
 #include <stddef.h>
 #include <stdint.h>
 
+#include "cuda_runtime.h"
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -323,24 +325,17 @@ mtrtScalarValueGetType(MTRT_ScalarValue scalar, MTRT_ScalarTypeCode *code);
 //===----------------------------------------------------------------------===//
 // MTRT_GpuAllocator
 //===----------------------------------------------------------------------===//
 
-// Function pointer types for the allocate and deallocate callbacks
-typedef void* (*AllocateFunc)(void* self, uint64_t size);
-typedef bool (*DeallocateFunc)(void* self, void* memory);
+// Function pointer types for the allocate and deallocate callbacks.
+typedef void *(*AllocateFunc)(void *self, uint64_t size, uint64_t alignment,
+                              uint32_t flags, cudaStream_t *stream);
+typedef bool (*DeallocateFunc)(void *self, void *memory, cudaStream_t *stream);
 
 // The MTRT_GpuAllocator struct
 typedef struct MTRT_GpuAllocator {
-  void* ptr; // Pointer to the implementation (PyGpuAllocatorTrampoline in our case)
-  AllocateFunc allocate; // Function pointer for allocation
-  DeallocateFunc deallocate; // Function pointer for deallocation
+  void *ptr; // Pointer to the implementation (PyGpuAllocatorTrampoline in our
+             // case.)
+  AllocateFunc allocate;     // Function pointer for allocation
+  DeallocateFunc deallocate; // Function pointer for deallocation
 } MTRT_GpuAllocator;
 
-/// Checks nullity of `GpuAllocator`.
-MTRT_CAPI_EXPORTED bool GpuAllocatorIsNull(MTRT_GpuAllocator gpuAllocator);
-
-MTRT_CAPI_EXPORTED MTRT_Status
-GpuAllocatorDestroy(MTRT_GpuAllocator executable);
-
 //===----------------------------------------------------------------------===//
 // MTRT_RuntimeSessionOptions
 //===----------------------------------------------------------------------===//
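Note: for orientation, here is a minimal client-side sketch of the updated C API contract. Only the MTRT_GpuAllocator layout and the two callback signatures come from the header above; the state struct, the function names, and the cudaMallocAsync/cudaFreeAsync strategy are illustrative assumptions.

#include <cstdint>
#include "cuda_runtime.h"
#include "mlir-executor-c/Runtime/Runtime.h"

// Hypothetical per-allocator state; it only has to round-trip through `self`.
struct MyAllocatorState {
  uint64_t bytesLive = 0;
};

// Matches the new AllocateFunc shape: a null stream (or null pointee) requests
// a synchronous allocation. cudaMalloc/cudaMallocAsync already return suitably
// aligned memory, so `alignment` is not consulted here.
static void *myAllocate(void *self, uint64_t size, uint64_t /*alignment*/,
                        uint32_t /*flags*/, cudaStream_t *stream) {
  void *ptr = nullptr;
  cudaError_t status = (stream && *stream)
                           ? cudaMallocAsync(&ptr, size, *stream)
                           : cudaMalloc(&ptr, size);
  if (status != cudaSuccess)
    return nullptr;
  static_cast<MyAllocatorState *>(self)->bytesLive += size;
  return ptr;
}

// Matches the new DeallocateFunc shape.
static bool myDeallocate(void * /*self*/, void *memory, cudaStream_t *stream) {
  cudaError_t status =
      (stream && *stream) ? cudaFreeAsync(memory, *stream) : cudaFree(memory);
  return status == cudaSuccess;
}

// Package the state and the two callbacks into the C-compatible struct.
MTRT_GpuAllocator makeMyAllocator(MyAllocatorState *state) {
  return MTRT_GpuAllocator{state, myAllocate, myDeallocate};
}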
16 changes: 12 additions & 4 deletions mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h
@@ -36,16 +36,24 @@ class GpuAllocator {
 public:
   GpuAllocator() = default;
   virtual ~GpuAllocator() = default;
-  virtual void* allocate(uint64_t const size) { return nullptr; }
-  virtual bool deallocate(void *const memory) { return false; }
+  virtual void *allocate(uint64_t const size, uint64_t const alignment,
+                         uint32_t flags, cudaStream_t *stream) {
+    return nullptr;
+  }
+  virtual bool deallocate(void *const memory,
+                          cudaStream_t *stream) {
+    return false;
+  }
 };
 
 class CustomTensorRTAllocator : public GpuAllocator {
 public:
   CustomTensorRTAllocator() = default;
   ~CustomTensorRTAllocator() = default;
-  void* allocate(uint64_t const size) override;
-  bool deallocate(void *const memory) override;
+  void *allocate(uint64_t const size, uint64_t const alignment, uint32_t flags,
+                 cudaStream_t *stream) override;
+  bool deallocate(void *const memory,
+                  cudaStream_t *stream) override;
 };
 
 //===----------------------------------------------------------------------===//
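Note: because the base class keeps no-op virtual defaults, a backend only has to override these two entry points. A minimal sketch of a subclass against the new interface; TrackingAllocator and its bookkeeping are hypothetical, not part of the commit.

// A hypothetical subclass of the new interface: count live allocations and
// delegate the real work to the CUDA-backed allocator declared above.
class TrackingAllocator : public GpuAllocator {
public:
  void *allocate(uint64_t const size, uint64_t const alignment, uint32_t flags,
                 cudaStream_t *stream) override {
    void *memory = mInner.allocate(size, alignment, flags, stream);
    if (memory)
      ++mLiveAllocations;
    return memory;
  }
  bool deallocate(void *const memory, cudaStream_t *stream) override {
    if (mInner.deallocate(memory, stream)) {
      --mLiveAllocations;
      return true;
    }
    return false;
  }
  uint64_t liveAllocations() const { return mLiveAllocations; }

private:
  CustomTensorRTAllocator mInner; // Delegate to the CUDA-backed allocator.
  uint64_t mLiveAllocations = 0;
};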
31 changes: 4 additions & 27 deletions mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp
@@ -601,29 +601,6 @@ MTRT_ScalarValue mtrtRuntimeValueDynCastToScalar(MTRT_RuntimeValue v) {
   return wrap(static_cast<ScalarValue *>(x));
 }
 
-//===----------------------------------------------------------------------===//
-// MTRT_GpuAllocator
-//===----------------------------------------------------------------------===//
-
-bool GpuAllocatorIsNull(MTRT_GpuAllocator gpuAllocator) {
-  return !gpuAllocator.ptr;
-}
-
-MTRT_Status GpuAllocatorDestroy(MTRT_GpuAllocator gpuAllocator) {
-  // delete unwrap(gpuAllocator);
-  return mtrtStatusGetOk();
-}
-
-// TODO: Implement destroy method to release resources.
-// void mtrtGpuAllocatorDestroy(MTRT_GpuAllocator* allocator) {
-//   if (allocator && allocator->ptr) {
-//     delete static_cast<PyGpuAllocatorTrampoline*>(allocator->ptr);
-//     allocator->ptr = nullptr;
-//     allocator->allocate = nullptr;
-//     allocator->deallocate = nullptr;
-//   }
-// }
-
 //===----------------------------------------------------------------------===//
 // MTRT_RuntimeSessionOptions
 //===----------------------------------------------------------------------===//

@@ -660,12 +637,12 @@ class GpuAllocatorWrapper : public GpuAllocator {
   GpuAllocatorWrapper(MTRT_GpuAllocator gpuAllocator)
       : mPyGpuAllocator(gpuAllocator) {}
 
-  void *allocate(uint64_t size) override {
-    return mPyGpuAllocator.allocate(mPyGpuAllocator.ptr, size);
+  void *allocate(uint64_t size, uint64_t alignment, uint32_t flags, cudaStream_t *stream) override {
+    return mPyGpuAllocator.allocate(mPyGpuAllocator.ptr, size, alignment, flags, stream);
   }
 
-  bool deallocate(void *ptr) override {
-    return mPyGpuAllocator.deallocate(mPyGpuAllocator.ptr, ptr);
+  bool deallocate(void *ptr, cudaStream_t *stream) override {
+    return mPyGpuAllocator.deallocate(mPyGpuAllocator.ptr, ptr, stream);
   }
 
   // Static method to create a GpuAllocator from MTRT_GpuAllocator
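Note: the static create factory mentioned in the comment above is truncated in the diff. A plausible shape, assuming it does nothing more than validate the callback table before wrapping it (hypothetical body, not the commit's code):

  // Hypothetical sketch of the truncated factory: reject an incomplete
  // callback table, otherwise adapt the C struct to the C++ GpuAllocator
  // interface.
  static std::unique_ptr<GpuAllocator> create(MTRT_GpuAllocator gpuAllocator) {
    if (!gpuAllocator.ptr || !gpuAllocator.allocate || !gpuAllocator.deallocate)
      return nullptr;
    return std::make_unique<GpuAllocatorWrapper>(gpuAllocator);
  }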
@@ -68,29 +68,21 @@ class StdioLogger : public nvinfer1::ILogger {
 // TensorRTCallBackAllocator
 //===----------------------------------------------------------------------===//
 
-class TensorRTCallBackAllocator final : public nvinfer1::IGpuAllocator {
+class TensorRTCallBackAllocator final : public nvinfer1::IGpuAsyncAllocator {
 public:
   TensorRTCallBackAllocator(GpuAllocator *gpuAllocator)
-      : nvinfer1::IGpuAllocator(), mGpuAllocatorCallBack(gpuAllocator) {}
+      : nvinfer1::IGpuAsyncAllocator(), mGpuAllocatorCallBack(gpuAllocator) {}
 
-  void *allocate(uint64_t size, uint64_t alignment,
-                 nvinfer1::AllocatorFlags flags) noexcept final {
-    return allocateAsync(size, alignment, flags, nullptr);
-  }
-
-  bool deallocate(void *memory) noexcept final {
-    return deallocateAsync(memory, nullptr);
-  }
-
-  void *allocateAsync(uint64_t const size, uint64_t const /*alignment*/,
-                      uint32_t /*flags*/, cudaStream_t /*stream*/) noexcept final {
-    void* result = mGpuAllocatorCallBack->allocate(size);
+  void *allocateAsync(uint64_t const size, uint64_t const alignment,
+                      uint32_t flags, cudaStream_t stream) noexcept final {
+    void *result =
+        mGpuAllocatorCallBack->allocate(size, alignment, flags, &stream);
     return result;
   }
 
   bool deallocateAsync(void *const memory,
-                       cudaStream_t /*stream*/) noexcept override {
-    bool result = mGpuAllocatorCallBack->deallocate(memory);
+                       cudaStream_t stream) noexcept override {
+    bool result = mGpuAllocatorCallBack->deallocate(memory, &stream);
     return result;
   }
 
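Note: deleting the synchronous allocate/deallocate overrides is safe because nvinfer1::IGpuAsyncAllocator routes the synchronous IGpuAllocator entry points through the asynchronous ones on a null stream. Conceptually (a sketch of the expected base-class behavior, not TensorRT's actual source):

// Sketch: how an IGpuAsyncAllocator-style base can satisfy the synchronous
// IGpuAllocator interface purely in terms of the async entry points.
void *allocate(uint64_t const size, uint64_t const alignment,
               nvinfer1::AllocatorFlags const flags) noexcept override {
  return allocateAsync(size, alignment, flags, /*stream=*/nullptr);
}

bool deallocate(void *const memory) noexcept override {
  return deallocateAsync(memory, /*stream=*/nullptr);
}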
40 changes: 36 additions & 4 deletions mlir-tensorrt/executor/lib/Support/Allocators.cpp
@@ -46,14 +46,46 @@ using namespace mlirtrt;
 // CustomTensorRTAllocator
 //===----------------------------------------------------------------------===//
 
-void *CustomTensorRTAllocator::allocate(uint64_t const size) {
+void *
+CustomTensorRTAllocator::allocate(uint64_t const size, uint64_t const alignment,
+                                  uint32_t /*flags*/, cudaStream_t *stream) {
   uint8_t *memory;
-  cudaMalloc(reinterpret_cast<void **>(&memory), size);
+  assert(alignment > 0 && (alignment & (alignment - 1)) == 0 &&
+         "Memory alignment has to be power of 2");
+  if (stream && *stream != nullptr) {
+    auto status =
+        cudaMallocAsync(reinterpret_cast<void **>(&memory), size, *stream);
+    assert(status == cudaSuccess);
+    MTRT_DBGF("[CustomTensorRTAllocator][allocate]: Asynchronously allocated "
+              "%lx bytes at 0x%lx on stream %lx",
+              size, reinterpret_cast<uintptr_t>(memory),
+              reinterpret_cast<uintptr_t>(*stream));
+  } else {
+    auto status = cudaMalloc(reinterpret_cast<void **>(&memory), size);
+    assert(status == cudaSuccess);
+    MTRT_DBGF("[CustomTensorRTAllocator][allocate]: Synchronously allocated "
+              "%lx bytes at 0x%lx",
+              size, reinterpret_cast<uintptr_t>(memory));
+  }
+  assert(reinterpret_cast<uintptr_t>(memory) % alignment == 0);
   return memory;
 }
 
-bool CustomTensorRTAllocator::deallocate(void *const memory) {
-  cudaFree(memory);
+bool CustomTensorRTAllocator::deallocate(void *const memory,
+                                         cudaStream_t *stream) {
+  if (stream && *stream != nullptr) {
+    MTRT_DBGF("[CustomTensorRTAllocator][deallocate]: Asynchronously freeing "
+              "CUDA device memory 0x%lx on stream %lx",
+              reinterpret_cast<uintptr_t>(memory),
+              reinterpret_cast<uintptr_t>(*stream));
+    cudaError_t status = cudaFreeAsync(memory, *stream);
+    assert(status == cudaSuccess);
+  } else {
+    // Note: `stream` may be null on this path, so the log must not
+    // dereference it.
+    MTRT_DBGF("[CustomTensorRTAllocator][deallocate]: Synchronously freeing "
+              "CUDA device/pinned host memory 0x%lx",
+              reinterpret_cast<uintptr_t>(memory));
+    cudaError_t status = cudaFree(memory);
+    assert(status == cudaSuccess);
+  }
   return true;
 }
 
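Note: together with the header changes, the new contract can be exercised end to end. A minimal usage sketch, assuming the include path shown in the diff; the asserts above stand in for real error handling.

#include "cuda_runtime.h"
#include "mlir-executor/Support/Allocators.h"

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  CustomTensorRTAllocator allocator;

  // Stream-ordered path: maps to cudaMallocAsync/cudaFreeAsync.
  void *buf = allocator.allocate(/*size=*/1 << 20, /*alignment=*/256,
                                 /*flags=*/0, &stream);
  allocator.deallocate(buf, &stream);
  cudaStreamSynchronize(stream);

  // Synchronous fallback: a null stream pointer selects cudaMalloc/cudaFree.
  void *buf2 = allocator.allocate(1 << 20, 256, 0, /*stream=*/nullptr);
  allocator.deallocate(buf2, /*stream=*/nullptr);

  cudaStreamDestroy(stream);
  return 0;
}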
22 changes: 13 additions & 9 deletions mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp
@@ -128,6 +128,7 @@ class PyStream : public PyMTRTWrapper<PyStream, MTRT_Stream> {
 public:
   using Base::Base;
   DECLARE_WRAPPER_CONSTRUCTORS(PyStream);
+
   static constexpr auto kMethodTable = CAPITable<MTRT_Stream>{
       mtrtStreamIsNull, mtrtStreamDestroy, mtrtPythonCapsuleToStream,
       mtrtPythonStreamToCapsule};
@@ -195,24 +196,28 @@ class PyGpuAllocator {
   PyGpuAllocator(py::object self) : pySelf(self) {}
 
   virtual ~PyGpuAllocator() = default;
-  virtual std::uintptr_t allocate(uint64_t size) = 0;
+  virtual std::uintptr_t allocate(uint64_t size, uint64_t alignment,
+                                  uint32_t flags) = 0;
   virtual bool deallocate(std::uintptr_t ptr) = 0;
 
   // Creates a C-compatible struct for interfacing with lower-level APIs.
   MTRT_GpuAllocator getCApiObject() { return createWithPythonCallbacks(this); }
 
 private:
   // Trampoline function: Routes C-style allocation calls to C++ virtual method.
-  static void *pyGpuAllocatorAllocate(void *self, uint64_t size) {
+  static void *pyGpuAllocatorAllocate(void *self, uint64_t size,
+                                      uint64_t alignment, uint32_t flags,
+                                      cudaStream_t * /*stream*/) {
     py::gil_scoped_acquire acquire;
     auto *allocator = static_cast<PyGpuAllocator *>(self);
-    std::uintptr_t ptr = allocator->allocate(size);
+    std::uintptr_t ptr = allocator->allocate(size, alignment, flags);
     return reinterpret_cast<void *>(ptr);
   }
 
   // Trampoline function: Routes C-style deallocation calls to C++ virtual
   // method.
-  static bool pyGpuAllocatorDeallocate(void *self, void *memory) {
+  static bool pyGpuAllocatorDeallocate(void *self, void *memory,
+                                       cudaStream_t * /*stream*/) {
     py::gil_scoped_acquire acquire;
     auto *allocator = static_cast<PyGpuAllocator *>(self);
     return allocator->deallocate(reinterpret_cast<std::uintptr_t>(memory));
@@ -237,12 +242,12 @@ class PyGpuAllocatorTrampoline : public PyGpuAllocator {
 
   // Trampoline for allocate: Dispatches call to Python implementation if
   // overridden.
-  uintptr_t allocate(uint64_t size) override {
+  uintptr_t allocate(uint64_t size, uint64_t alignment, uint32_t flags) override {
     PYBIND11_OVERRIDE_PURE(uintptr_t,      // Return type
                            PyGpuAllocator, // Parent class
                            allocate,       // Name of function in C++
-                           size            // Arguments
-    );
+                           size,           // Arguments
+                           alignment, flags);
   }
 
   // Trampoline for deallocate: Dispatches call to Python implementation if
@@ -251,8 +256,7 @@ class PyGpuAllocatorTrampoline : public PyGpuAllocator {
   bool deallocate(std::uintptr_t ptr) override {
     PYBIND11_OVERRIDE_PURE(bool,           // Return type
                            PyGpuAllocator, // Parent class
                            deallocate,     // Name of function in C++
-                           ptr             // Arguments
-    );
+                           ptr);           // Arguments
   }
 };
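Note: createWithPythonCallbacks, used by getCApiObject above, is not shown in the diff. As the trampolines indicate, the stream parameter is accepted but ignored on this path, so Python allocators stay synchronous. A plausible sketch of the factory, consistent with the struct layout (hypothetical body):

  // Hypothetical sketch: package the PyGpuAllocator instance and its static
  // trampolines into the C-compatible callback table.
  static MTRT_GpuAllocator createWithPythonCallbacks(PyGpuAllocator *self) {
    MTRT_GpuAllocator allocator;
    allocator.ptr = self;
    allocator.allocate = pyGpuAllocatorAllocate;
    allocator.deallocate = pyGpuAllocatorDeallocate;
    return allocator;
  }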
@@ -20,7 +20,7 @@ def __init__(self):
         super().__init__(self)
         self.allocations = {}  # Keep track of allocations
 
-    def allocate(self, size):
+    def allocate(self, size, alignment, flags):
         # Allocate memory on the GPU using CuPy
         mem = cp.cuda.alloc(size)
         ptr = int(mem.ptr)  # Convert to integer
