From 94cbb5a4b11a9df22b003183a736f179c6f05ef2 Mon Sep 17 00:00:00 2001 From: Jhalak Patel Date: Wed, 21 Aug 2024 00:22:47 -0700 Subject: [PATCH] Add IOutputAllocator --- .../include/mlir-executor/Runtime/API/API.h | 7 + .../Runtime/Backend/Lua/LuaRegistration.h | 4 +- .../Runtime/Backend/Lua/LuaRuntime.h | 10 +- .../Lua/Modules/TensorRT/TensorRTModule.h | 3 +- .../mlir-executor/Support/Allocators.h | 113 ++++++++++++ .../executor/lib/Runtime/API/API.cpp | 6 +- .../lib/Runtime/Backend/Lua/LuaRuntime.cpp | 35 ++-- .../Lua/Modules/TensorRT/TensorRTModule.cpp | 172 ++++++++++++++++-- .../executor/lib/Support/Allocators.cpp | 55 ++++++ .../executor/lib/Tools/ExecutorRunnerMain.cpp | 5 +- 10 files changed, 375 insertions(+), 35 deletions(-) diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h index 70384c60d..41db8f778 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h @@ -868,6 +868,7 @@ class RuntimeSession { std::unique_ptr pinnedMemoryAllocator, std::unique_ptr allocTracker, std::unique_ptr resourceTracker, + std::unique_ptr outputAllocatorTracker, std::unique_ptr gpuAllocator); ExecutableView getExecutable() const { return executable; } @@ -891,6 +892,7 @@ class RuntimeSession { std::unique_ptr pinnedMemoryAllocator; std::unique_ptr allocTracker; std::unique_ptr resourceTracker; + std::unique_ptr outputAllocatorTracker; std::unique_ptr gpuAllocator; sol::state state; }; @@ -973,6 +975,10 @@ class RuntimeClient { return pinnedMemoryAllocator; } + OutputAllocatorTracker& getOutputAllocatorTracker() { + return outputAllocatorTracker; + } + private: RuntimeClient(llvm::SmallVector> devices) : devices(std::move(devices)) {} @@ -981,6 +987,7 @@ class RuntimeClient { PinnedMemoryAllocator pinnedMemoryAllocator; AllocTracker allocTracker; ResourceTracker resourceTracker; + OutputAllocatorTracker outputAllocatorTracker; }; //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRegistration.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRegistration.h index 922e964d4..9dd689de8 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRegistration.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRegistration.h @@ -37,6 +37,8 @@ void registerLuaRuntimeMethods(lua_State *state, const RuntimeSessionOptions &options, PinnedMemoryAllocator *pinnedMemoryAllocator, AllocTracker *allocTracker, - ResourceTracker *resourceTracker, GpuAllocator* allocator); + ResourceTracker *resourceTracker, + OutputAllocatorTracker *outputAllocatorTracker, + GpuAllocator *allocator); } // namespace mlirtrt::runtime diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h index d4f07f13a..e7251580f 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h @@ -36,7 +36,8 @@ namespace mlirtrt::runtime { /// `main` function. It is assumed that `main` takes no arguments and returns an /// integer result (which is returned if the execution is successful). /// TODO: this should take a handle to a function for streaming output/errors. 
-StatusOr<int64_t> runExecutorLuaScript(std::string_view luaScript, GpuAllocator* allocator);
+StatusOr<int64_t> runExecutorLuaScript(std::string_view luaScript,
+                                       GpuAllocator *allocator);
 
 /// Synchronously run a serialized executor Executable one time. An `Executable`
 /// is essentially a Lua script packaged with metadata and serialized constants
 /// execution is successful).
 /// TODO: this should take a handle to a function for
 /// streaming output/errors.
-StatusOr<int64_t> runExecutorExecutable(std::unique_ptr<Executable> executable, std::unique_ptr<GpuAllocator> allocator);
+StatusOr<int64_t>
+runExecutorExecutable(std::unique_ptr<Executable> executable,
+                      std::unique_ptr<GpuAllocator> allocator);
 
 /// Create an execution state. This will setup a Lua environment and invoke
 /// global initialization.
 StatusOr<std::unique_ptr<RuntimeSession>>
-createRuntimeSessionWithLuaBackend(ExecutableView executable, std::unique_ptr<GpuAllocator> allocator,
+createRuntimeSessionWithLuaBackend(ExecutableView executable,
+                                   std::unique_ptr<GpuAllocator> allocator,
                                    const RuntimeSessionOptions &options);
 
 /// Set the primary stream for the loaded executable to use.
diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.h
index 1ceb91367..54655ddf7 100644
--- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.h
+++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.h
@@ -37,7 +37,8 @@ class ResourceTracker;
 /// Lua state.
 void registerExecutorTensorRTModuleLuaRuntimeMethods(
     lua_State *luaState, PinnedMemoryAllocator *pinnedMemoryAllocator,
-    AllocTracker *allocTracker, ResourceTracker *resourceTracker, GpuAllocator* allocator);
+    AllocTracker *allocTracker, ResourceTracker *resourceTracker,
+    OutputAllocatorTracker *outputAllocatorTracker, GpuAllocator *allocator);
 
 } // namespace mlirtrt::runtime
 
diff --git a/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h b/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h
index 536619ba7..79571f4e9 100644
--- a/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h
+++ b/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h
@@ -32,6 +32,10 @@ namespace mlirtrt {
 
 struct EventPool;
 
+//===----------------------------------------------------------------------===//
+// GpuAllocator and CustomTensorRTAllocator
+//===----------------------------------------------------------------------===//
+
 class GpuAllocator {
 public:
   GpuAllocator() = default;
@@ -56,6 +60,115 @@ class CustomTensorRTAllocator : public GpuAllocator {
                   cudaStream_t* stream) override;
 };
 
+//===----------------------------------------------------------------------===//
+// OutputAllocator and CustomTensorRTOutputAllocator
+//===----------------------------------------------------------------------===//
+
+//!
+//! Interface for allocating memory for outputs with data-dependent shapes.
+//! Their sizes are not known until enqueue time, so they cannot be
+//! pre-allocated.
+//!
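+//! During enqueueV3, TensorRT invokes the registered
+//! nvinfer1::IOutputAllocator to request (or grow) the output buffer via
+//! reallocateOutputAsync once the required byte size is known, and then
+//! reports the resolved dimensions via notifyShape. The interface below
+//! mirrors that callback contract on the runtime side so the TensorRT-facing
+//! adapter can forward both calls here.
+//!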
+class OutputAllocator {
+public:
+  virtual ~OutputAllocator() = default;
+  virtual void setGpuAllocator(GpuAllocator *gpuAllocator) = 0;
+  virtual void setTensorName(const char *tensorName) = 0;
+  virtual void setCurrentMemory(void *currentMemory) = 0;
+  virtual void setOutputSize(const int64_t size) = 0;
+  virtual void *reallocateOutputAsync(char const *tensorName,
+                                      void *currentMemory, uint64_t size,
+                                      uint64_t alignment,
+                                      cudaStream_t * /*stream*/) = 0;
+  virtual void notifyShape(char const *tensorName,
+                           std::vector<int64_t> &dims) = 0;
+};
+
+class CustomTensorRTOutputAllocator : public OutputAllocator {
+public:
+  CustomTensorRTOutputAllocator() = default;
+  ~CustomTensorRTOutputAllocator() {
+    if (mOutputPtr != nullptr) {
+      cudaFree(mOutputPtr);
+    }
+  }
+
+  void setGpuAllocator(GpuAllocator *gpuAllocator) override {
+    mGpuAllocator = gpuAllocator;
+  }
+
+  //! Methods are called just after construction. TODO: can they be called
+  //! during construction?
+  void setTensorName(const char *tensorName) override {
+    mTensorName = tensorName;
+  }
+
+  void setCurrentMemory(void *currentMemory) override {
+    mCurrentMemory = currentMemory;
+  }
+
+  void setOutputSize(int64_t outputSize) override { mOutputSize = outputSize; }
+
+  void *reallocateOutputAsync(char const *tensorName, void *currentMemory,
+                              uint64_t size, uint64_t alignment,
+                              cudaStream_t * /*stream*/) override;
+
+  void notifyShape(char const *tensorName,
+                   std::vector<int64_t> &dims) override {
+    mNotifyShapeCalled = true;
+    outputDims = dims;
+  }
+
+  //! nullptr if memory could not be allocated.
+  void *mOutputPtr{nullptr};
+
+  //! Size of the allocation pointed to by mOutputPtr.
+  uint64_t mOutputSize{0};
+
+  bool mReallocateOutputCalled{false};
+
+  bool mNotifyShapeCalled{false};
+
+  //! Dimensions of the output tensor.
+  std::vector<int64_t> outputDims;
+
+private:
+  GpuAllocator *mGpuAllocator{nullptr};
+  const char *mTensorName{nullptr};
+  void *mCurrentMemory{nullptr};
+};
+
+class OutputAllocatorTracker {
+public:
+  OutputAllocatorTracker() = default;
+  ~OutputAllocatorTracker() = default;
+
+  OutputAllocatorTracker(const OutputAllocatorTracker &) = delete;
+  OutputAllocatorTracker &operator=(const OutputAllocatorTracker &) = delete;
+  OutputAllocatorTracker(OutputAllocatorTracker &&) = default;
+  OutputAllocatorTracker &operator=(OutputAllocatorTracker &&) = default;
+
+  // Add a new OutputAllocator.
+  void addAllocator(std::unique_ptr<OutputAllocator> allocator) {
+    outputAllocators.emplace_back(std::move(allocator));
+  }
+
+  // Get the OutputAllocator at the given index, or nullptr if out of range.
+  OutputAllocator *getAllocator(size_t index) {
+    if (index < outputAllocators.size()) {
+      return outputAllocators[index].get();
+    }
+    return nullptr;
+  }
+
+  // Get the most recently added OutputAllocator, or nullptr if none exist.
+  OutputAllocator *getLastAllocator() {
+    if (!outputAllocators.empty()) {
+      return outputAllocators.back().get();
+    }
+    return nullptr;
+  }
+
+private:
+  std::vector<std::unique_ptr<OutputAllocator>> outputAllocators;
+};
+
 //===----------------------------------------------------------------------===//
 // PoolTrackedCudaEvent
 //===----------------------------------------------------------------------===//
diff --git a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp
index 52b02f72a..583dc344b 100644
--- a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp
+++ b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp
@@ -358,12 +358,14 @@ RuntimeSession::RuntimeSession(
     std::unique_ptr<PinnedMemoryAllocator> pinnedMemoryAllocator,
     std::unique_ptr<AllocTracker> allocTracker,
     std::unique_ptr<ResourceTracker> resourceTracker,
+    std::unique_ptr<OutputAllocatorTracker> outputAllocatorTracker,
     std::unique_ptr<GpuAllocator> gpuAllocator)
     : options(std::move(options)), executable(exe),
pinnedMemoryAllocator(std::move(pinnedMemoryAllocator)), allocTracker(std::move(allocTracker)), - resourceTracker(std::move(resourceTracker)), gpuAllocator(std::move(gpuAllocator)), - state(std::move(state)) {} + resourceTracker(std::move(resourceTracker)), + outputAllocatorTracker(std::move(outputAllocatorTracker)), + gpuAllocator(std::move(gpuAllocator)), state(std::move(state)) {} //===----------------------------------------------------------------------===// // AllocTracker diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp index 17af64a91..78b22190d 100644 --- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp @@ -72,7 +72,8 @@ static void registerDefaultDeviceDependentMethods(lua_State *state, static void registerLuaRuntimeMethodsCommon( lua_State *state, PinnedMemoryAllocator *pinnedMemoryAllocator, - AllocTracker *allocTracker, ResourceTracker *resourceTracker, GpuAllocator* allocator) { + AllocTracker *allocTracker, ResourceTracker *resourceTracker, + GpuAllocator *allocator, OutputAllocatorTracker *outputAllocatorTracker) { registerExecutorCoreModuleLuaRuntimeMethods(state, pinnedMemoryAllocator, allocTracker); registerExecutorCUDAModuleLuaRuntimeMethods( @@ -84,15 +85,15 @@ static void registerLuaRuntimeMethodsCommon( #endif registerExecutorTensorRTModuleLuaRuntimeMethods( - state, pinnedMemoryAllocator, allocTracker, resourceTracker, allocator); + state, pinnedMemoryAllocator, allocTracker, resourceTracker, outputAllocatorTracker, allocator); } void mlirtrt::runtime::registerLuaRuntimeMethods( lua_State *state, const RuntimeSessionOptions &options, PinnedMemoryAllocator *pinnedMemoryAllocator, AllocTracker *allocTracker, - ResourceTracker *resourceTracker, GpuAllocator* allocator) { + ResourceTracker *resourceTracker, OutputAllocatorTracker* outputAllocatorTracker, GpuAllocator* allocator) { registerLuaRuntimeMethodsCommon(state, pinnedMemoryAllocator, allocTracker, - resourceTracker, allocator); + resourceTracker, allocator, outputAllocatorTracker); #ifdef MLIR_EXECUTOR_ENABLE_NCCL registerExecutorNCCLModuleLuaRuntimeMethods(state, resourceTracker); registerDeviceDependentNCCLMethods(state, options.getNumDevices(), @@ -107,8 +108,8 @@ void mlirtrt::runtime::registerLuaRuntimeMethods( #endif } -StatusOr -mlirtrt::runtime::runExecutorLuaScript(std::string_view luaScript, GpuAllocator* allocator) { +StatusOr mlirtrt::runtime::runExecutorLuaScript( + std::string_view luaScript, GpuAllocator *allocator) { ADD_RUNTIME_MODULE_RANGE("runtime_runExecutorLuaScript"); StatusOr> client = RuntimeClient::create(); @@ -120,7 +121,8 @@ mlirtrt::runtime::runExecutorLuaScript(std::string_view luaScript, GpuAllocator* registerLuaRuntimeMethods(lua.lua_state(), RuntimeSessionOptions(), &(*client)->getPinnedMemorAllocator(), &(*client)->getAllocTracker(), - &(*client)->getResourceTracker(), allocator); + &(*client)->getResourceTracker(), + &(*client)->getOutputAllocatorTracker(), allocator); sol::protected_function_result result = lua.script(luaScript); if (!result.valid()) { @@ -171,7 +173,8 @@ static Status maybeCheckForValidNcclUuid(const RuntimeSessionOptions &options) { /// global initialization. 
StatusOr<std::unique_ptr<RuntimeSession>>
mlirtrt::runtime::createRuntimeSessionWithLuaBackend(
-    ExecutableView executable, std::unique_ptr<GpuAllocator> allocator, const RuntimeSessionOptions &options) {
+    ExecutableView executable, std::unique_ptr<GpuAllocator> allocator,
+    const RuntimeSessionOptions &options) {
   ADD_RUNTIME_MODULE_RANGE("runtime_loadExecutable");
 
   MTRT_RETURN_IF_ERROR(maybeCheckForValidNcclUuid(options));
@@ -179,12 +182,13 @@ mlirtrt::runtime::createRuntimeSessionWithLuaBackend(
   auto pinnedMemoryAllocator = std::make_unique<PinnedMemoryAllocator>();
   auto allocTracker = std::make_unique<AllocTracker>();
   auto resourceTracker = std::make_unique<ResourceTracker>();
+  auto outputAllocatorTracker = std::make_unique<OutputAllocatorTracker>();
 
   sol::state lua;
   lua.open_libraries(sol::lib::base, sol::lib::string);
-  registerLuaRuntimeMethods(lua.lua_state(), options,
-                            pinnedMemoryAllocator.get(), allocTracker.get(),
-                            resourceTracker.get(), allocator.get());
+  registerLuaRuntimeMethods(
+      lua.lua_state(), options, pinnedMemoryAllocator.get(), allocTracker.get(),
+      resourceTracker.get(), outputAllocatorTracker.get(), allocator.get());
 
   // Load globals into the context.
   // TODO: eliminate this copy, we already own the executable.
@@ -225,11 +229,13 @@ mlirtrt::runtime::createRuntimeSessionWithLuaBackend(
   }
   return std::make_unique<RuntimeSession>(
       options, executable, std::move(lua), std::move(pinnedMemoryAllocator),
-      std::move(allocTracker), std::move(resourceTracker), std::move(allocator));
+      std::move(allocTracker), std::move(resourceTracker),
+      std::move(outputAllocatorTracker), std::move(allocator));
 }
 
 StatusOr<int64_t> mlirtrt::runtime::runExecutorExecutable(
-    std::unique_ptr<Executable> executable, std::unique_ptr<GpuAllocator> allocator) {
+    std::unique_ptr<Executable> executable,
+    std::unique_ptr<GpuAllocator> allocator) {
 
   StatusOr<std::unique_ptr<RuntimeClient>> client = RuntimeClient::create();
   if (!client.isOk())
@@ -245,7 +251,8 @@ StatusOr<int64_t> mlirtrt::runtime::runExecutorExecutable(
     return options.getStatus();
 
   StatusOr<std::unique_ptr<RuntimeSession>> session =
-      createRuntimeSessionWithLuaBackend(executable->getView(), std::move(allocator), *options);
+      createRuntimeSessionWithLuaBackend(executable->getView(),
+                                         std::move(allocator), *options);
   if (!session.isOk())
     return session.getStatus();
 
diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp
index 1ae3b4440..e822e1e7a 100644
--- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp
+++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp
@@ -64,6 +64,109 @@ class StdioLogger : public nvinfer1::ILogger {
   bool verbose;
 };
 
+//===----------------------------------------------------------------------===//
+// TensorRTCallBackOutputAllocator
+//===----------------------------------------------------------------------===//
+
+static bool isSubByte(nvinfer1::DataType t) {
+  return t == nvinfer1::DataType::kINT4;
+}
+
+static int32_t elementSizeInBits(nvinfer1::DataType t) {
+  switch (t) {
+  case nvinfer1::DataType::kINT64:
+    return 64;
+  case nvinfer1::DataType::kINT32:
+    return 32;
+  case nvinfer1::DataType::kFLOAT:
+    return 32;
+  case nvinfer1::DataType::kHALF:
+    return 16;
+  case nvinfer1::DataType::kBF16:
+    return 16;
+  case nvinfer1::DataType::kINT8:
+    return 8;
+  case nvinfer1::DataType::kBOOL:
+    return 8;
+  case nvinfer1::DataType::kUINT8:
+    return 8;
+  case nvinfer1::DataType::kFP8:
+    return 8;
+  case nvinfer1::DataType::kINT4:
+    return 4;
+  }
+  return 0;
+}
+
+static int32_t elementSizeInBytes(nvinfer1::DataType dtype) {
+  if (!isSubByte(dtype)) {
+    auto bits = elementSizeInBits(dtype);
+    assert(bits % 8 == 0);
+    return bits / 8;
+  }
+  if (dtype == nvinfer1::DataType::kINT4) {
+    return 1;
+  }
+  return -1;
+}
+
+static int64_t volume(nvinfer1::Dims64 const &d) {
+  int64_t v = 1;
+  for (int64_t i = 0; i < d.nbDims; i++) {
+    v *= d.d[i];
+  }
+  return v;
+}
+
+class TensorRTCallBackOutputAllocator final
+    : public nvinfer1::IOutputAllocator {
+public:
+  TensorRTCallBackOutputAllocator(GpuAllocator *gpuAllocator,
+                                  OutputAllocator *outputAllocator,
+                                  const char *tensorName, void *currentMemory,
+                                  nvinfer1::Dims64 dims,
+                                  nvinfer1::DataType dtype)
+      : nvinfer1::IOutputAllocator(),
+        mOutputAllocatorCallBack(outputAllocator) {
+    mOutputAllocatorCallBack->setGpuAllocator(gpuAllocator);
+    mOutputAllocatorCallBack->setTensorName(tensorName);
+    mOutputAllocatorCallBack->setCurrentMemory(currentMemory);
+    mOutputAllocatorCallBack->setOutputSize(volume(dims) *
+                                            elementSizeInBytes(dtype));
+  }
+
+  void *reallocateOutput(char const *tensorName, void *currentMemory,
+                         uint64_t size, uint64_t alignment) noexcept override {
+    return mOutputAllocatorCallBack->reallocateOutputAsync(
+        tensorName, currentMemory, size, alignment, nullptr);
+  }
+
+  //! Stream-aware overload: forwards the TensorRT-provided stream to the
+  //! registered callback (the synchronous overload above passes nullptr).
+  void *reallocateOutputAsync(char const *tensorName, void *currentMemory,
+                              uint64_t size, uint64_t alignment,
+                              cudaStream_t stream) noexcept override {
+    return mOutputAllocatorCallBack->reallocateOutputAsync(
+        tensorName, currentMemory, size, alignment, &stream);
+  }
+
+  void notifyShape(char const *tensorName,
+                   nvinfer1::Dims const &dims) noexcept override {
+    std::vector<int64_t> dimsVec(dims.nbDims);
+    for (int32_t i = 0; i < dims.nbDims; ++i) {
+      dimsVec[i] = dims.d[i];
+    }
+    return mOutputAllocatorCallBack->notifyShape(tensorName, dimsVec);
+  }
+
+  ~TensorRTCallBackOutputAllocator() override {}
+
+private:
+  OutputAllocator *mOutputAllocatorCallBack;
+};
+
 //===----------------------------------------------------------------------===//
 // TensorRTCallBackAllocator
 //===----------------------------------------------------------------------===//
@@ -86,6 +189,10 @@ class TensorRTCallBackAllocator final : public nvinfer1::IGpuAsyncAllocator {
     return result;
   }
 
+  GpuAllocator *getCallBackAllocator() { return mGpuAllocatorCallBack; }
+
 private:
   GpuAllocator *mGpuAllocatorCallBack;
 };
@@ -125,9 +232,11 @@ class NvInferRuntimeWrapper {
         });
     // GpuAllocator is optional.
     if (gpuAllocator) {
-      callbackAllocator = std::shared_ptr<TensorRTCallBackAllocator>(
-          new TensorRTCallBackAllocator(gpuAllocator));
-      runtime->setGpuAllocator(callbackAllocator.get());
+      callbackAllocatorPair =
+          std::make_pair(std::shared_ptr<TensorRTCallBackAllocator>(
+                             new TensorRTCallBackAllocator(gpuAllocator)),
+                         gpuAllocator);
+      runtime->setGpuAllocator(callbackAllocatorPair.first.get());
     }
   }
 
   nvinfer1::IRuntime *operator->() { return runtime.get(); }
 
   std::shared_ptr<nvinfer1::IRuntime> runtime;
-  std::shared_ptr<TensorRTCallBackAllocator> callbackAllocator;
+  std::pair<std::shared_ptr<TensorRTCallBackAllocator>, GpuAllocator *>
+      callbackAllocatorPair;
 };
 
 class NvInferEngineWrapper {
@@ -234,6 +343,22 @@ class NvInferExecContextWrapper {
   /// Returned the pre-allocated host staging buffers.
   std::vector &getHostIOBuffers() { return hostIOBuffers; }
 
+  /// Add a callback output allocator.
+  void addCallBackAllocators(
+      std::unique_ptr<TensorRTCallBackOutputAllocator> allocator) {
+    outputAllocators.emplace_back(std::move(allocator));
+  }
+
+  /// Return the most recently added callback output allocator.
+  TensorRTCallBackOutputAllocator *getLastCallBackAllocatorPtr() {
+    return outputAllocators.back().get();
+  }
+
+  /// Return the registered callback GPU allocator, if any.
+  GpuAllocator *getGpuAllocator() {
+    return engine->runtime->callbackAllocatorPair.second;
+  }
+
 private:
   // We keep a reference to the cuda engine to keep it from going out of scope.
   // The standard TensorRTRuntime-to-Executor lowering only creates globals for
@@ -247,13 +372,14 @@ class NvInferExecContextWrapper {
   /// A set of pinned host buffers one per input host buffer (shape tensor) to
   /// the TRT network.
   std::vector hostIOBuffers;
+  std::vector<std::unique_ptr<TensorRTCallBackOutputAllocator>> outputAllocators;
 };
 } // namespace
 
-static Status setTensorAddressesOrReport(
+static Status setTensorAddressesAndOutputAllocatorsOrReport(
     NvInferExecContextWrapper &context,
     const std::vector<std::tuple<std::string, uintptr_t, nvinfer1::Dims64>>
-        &buffers) {
+        &buffers,
+    OutputAllocatorTracker &outputAllocatorTracker) {
   ADD_TENSORRT_MODULE_RANGE("set_tensor_addresses");
   unsigned idx = 0;
   for (auto &[name, ptr, dims] : buffers) {
@@ -266,9 +392,10 @@ static Status setTensorAddressesOrReport(
     bool result =
         context->setTensorAddress(name.c_str(), reinterpret_cast<void *>(ptr));
 
+    const nvinfer1::ICudaEngine &engine = context->getEngine();
+
     if (!result) {
       std::stringstream ss;
-      const nvinfer1::ICudaEngine &engine = context->getEngine();
       ss << "Failed to set tensor address for IO tensor: " << name
          << " at position " << idx << "; the IO tensors are:\n";
       for (int64_t i = 0; i < engine.getNbIOTensors(); i++) {
@@ -289,6 +416,27 @@ static Status setTensorAddressesOrReport(
         return getInternalErrorStatus("failed to set input shape");
     }
 
+    // Register an output allocator for each device-resident output tensor.
+    if (engine.getTensorIOMode(name.c_str()) ==
+            nvinfer1::TensorIOMode::kOUTPUT &&
+        engine.getTensorLocation(name.c_str()) ==
+            nvinfer1::TensorLocation::kDEVICE) {
+
+      // Create the custom output allocator that actually implements
+      // alloc/dealloc; the tracker keeps the allocators and the output
+      // buffers they own alive.
+      outputAllocatorTracker.addAllocator(
+          std::make_unique<CustomTensorRTOutputAllocator>());
+
+      context.addCallBackAllocators(
+          std::make_unique<TensorRTCallBackOutputAllocator>(
+              context.getGpuAllocator(),
+              outputAllocatorTracker.getLastAllocator(), name.c_str(),
+              reinterpret_cast<void *>(ptr), dims,
+              engine.getTensorDataType(name.c_str())));
+      context->setOutputAllocator(name.c_str(),
+                                  static_cast<nvinfer1::IOutputAllocator *>(
+                                      context.getLastCallBackAllocatorPtr()));
+    }
+
     MTRT_DBGF("Set tensor address [%d] = %lu", idx, ptr);
     idx++;
   }
@@ -390,6 +538,7 @@ prepareBuffers(const AllocTracker &allocTracker,
 
 static Status enqueueV3Wrapper(AllocTracker &tracker,
                                ResourceTracker &resourceTracker,
+                               OutputAllocatorTracker &outputAllocatorTracker,
                                NvInferExecContextWrapper &context,
                                CudaStreamPtr stream, sol::table &va) {
   StatusOr<std::vector<std::tuple<std::string, uintptr_t, nvinfer1::Dims64>>>
       buffers = prepareBuffers(tracker, context, stream, va);
   if (!buffers.isOk())
     return getStatusWithMsg(StatusCode::InternalError,
                             "failed to prepare buffers: ", buffers.getString());
-  MTRT_RETURN_IF_ERROR(setTensorAddressesOrReport(context, *buffers));
+  MTRT_RETURN_IF_ERROR(setTensorAddressesAndOutputAllocatorsOrReport(
+      context, *buffers, outputAllocatorTracker));
 
   // Create an event that we can wait on for releasing any host-pinned staging
   // allocations we made.
  MTRT_ASSIGN_OR_RETURN(CudaEventPtr inputConsumedEvent,
@@ -426,7 +575,8 @@
 //===----------------------------------------------------------------------===//
 void mlirtrt::runtime::registerExecutorTensorRTModuleLuaRuntimeMethods(
     lua_State *luaState, PinnedMemoryAllocator *pinnedMemoryAllocator,
-    AllocTracker *allocTracker, ResourceTracker *resourceTracker, GpuAllocator* allocator) {
+    AllocTracker *allocTracker, ResourceTracker *resourceTracker,
+    OutputAllocatorTracker *outputAllocatorTracker, GpuAllocator *allocator) {
   sol::state_view lua(luaState);
 
   lua["_trtrt_create_runtime"] =
@@ -464,14 +614,14 @@
   lua["_trtrt_enqueue"] = [allocTracker,
-                           resourceTracker](sol::this_state state,
+                           resourceTracker,
+                           outputAllocatorTracker](sol::this_state state,
                                             std::shared_ptr<NvInferExecContextWrapper> context,
                                             CudaStreamPtr stream, sol::table va) {
     ADD_TENSORRT_MODULE_RANGE("trtrt_enqueue");
     sol::state_view luaState(state);
     assert(context != nullptr);
     assert(stream != nullptr && "expected valid stream");
-    Status result = enqueueV3Wrapper(*allocTracker, *resourceTracker,
+    Status result = enqueueV3Wrapper(*allocTracker, *resourceTracker,
+                                     *outputAllocatorTracker,
                                      *context, stream, va);
     SET_LUA_ERROR_IF_ERROR(result, state);
   };
diff --git a/mlir-tensorrt/executor/lib/Support/Allocators.cpp b/mlir-tensorrt/executor/lib/Support/Allocators.cpp
index 100cb0361..ab2eb0787 100644
--- a/mlir-tensorrt/executor/lib/Support/Allocators.cpp
+++ b/mlir-tensorrt/executor/lib/Support/Allocators.cpp
@@ -42,6 +42,61 @@ using namespace mlirtrt;
   DEBUG_WITH_TYPE("allocators", fprintf(stderr, "%s:%d " fmt "\n", __FILE__,  \
                                         __LINE__, __VA_ARGS__))
 
+//===----------------------------------------------------------------------===//
+// CustomTensorRTOutputAllocator
+//===----------------------------------------------------------------------===//
+
+inline uint64_t roundUp(uint64_t m, uint64_t n) {
+  return ((m + n - 1) / n) * n;
+}
+
+void *CustomTensorRTOutputAllocator::reallocateOutputAsync(
+    char const *tensorName, void *currentMemory, uint64_t size,
+    uint64_t alignment, cudaStream_t *stream) {
+
+  assert(currentMemory == mCurrentMemory && "output buffer mismatch");
+  assert(strcmp(tensorName, mTensorName) == 0 && "tensor name mismatch");
+  assert(!mReallocateOutputCalled && "duplicate call to reallocateOutput");
+  mReallocateOutputCalled = true;
+
+  // Some memory allocators return nullptr when allocating zero bytes, but
+  // TensorRT requires a non-null ptr even for empty tensors, so allocate a
+  // dummy byte.
+  size = std::max(size, static_cast<uint64_t>(1));
+
+  // Check if reallocation is required.
+  if (size > mOutputSize) {
+    size = roundUp(size, alignment);
+
+    if (mOutputPtr) {
+      if (mGpuAllocator) {
+        // Use the registered callback GPU allocator for output deallocations.
+        mGpuAllocator->deallocate(mOutputPtr, stream);
+      } else {
+        // Fall back to local memory management.
+        cudaFree(mOutputPtr);
+      }
+    }
+
+    mOutputPtr = nullptr;
+    mOutputSize = 0;
+
+    void *memory;
+    if (mGpuAllocator) {
+      // Use the registered callback GPU allocator for output allocations.
+      memory = mGpuAllocator->allocate(size, alignment, 0 /* flags */, stream);
+    } else {
+      // Fall back to local memory management.
+ cudaMalloc(&memory, size); + } + mOutputPtr = memory; + if (mOutputPtr != nullptr) { + mOutputSize = size; + } + return mOutputPtr; + } + return mCurrentMemory; +} + //===----------------------------------------------------------------------===// // CustomTensorRTAllocator //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/executor/lib/Tools/ExecutorRunnerMain.cpp b/mlir-tensorrt/executor/lib/Tools/ExecutorRunnerMain.cpp index dc14db16f..d06b59618 100644 --- a/mlir-tensorrt/executor/lib/Tools/ExecutorRunnerMain.cpp +++ b/mlir-tensorrt/executor/lib/Tools/ExecutorRunnerMain.cpp @@ -179,8 +179,6 @@ executor::ExecutorRunnerMain(int argc, char **argv, allocator.reset(new CustomTensorRTAllocator()); } - // Read the buffer as a Lua script and execute. - if (options.inputType == Lua) { assert(!options.dumpFunctionSignature && "Can not dump function signature for Lua input type."); @@ -213,7 +211,8 @@ executor::ExecutorRunnerMain(int argc, char **argv, } mlirtrt::StatusOr executionResult = - mlirtrt::runtime::runExecutorExecutable(std::move(*executable), std::move(allocator)); + mlirtrt::runtime::runExecutorExecutable(std::move(*executable), + std::move(allocator)); if (!executionResult.isOk()) return emitError(UnknownLoc::get(&context)) << "failed to load and run executable: "
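Usage note: the OutputAllocator interface added above is what an embedder can implement if it wants to control where data-dependent-shape outputs land, with CustomTensorRTOutputAllocator as the default implementation. The sketch below is illustrative only: the class name PoolBackedOutputAllocator is hypothetical, the raw CUDA calls are stand-ins for whatever backing allocation the embedder prefers, and error handling is omitted. Registration would go through OutputAllocatorTracker::addAllocator, the same path setTensorAddressesAndOutputAllocatorsOrReport uses for the default allocator.

#include <cuda_runtime_api.h>
#include <vector>
#include "mlir-executor/Support/Allocators.h"

// Hypothetical user-side OutputAllocator; mirrors the interface declared in
// Allocators.h above.
class PoolBackedOutputAllocator : public mlirtrt::OutputAllocator {
public:
  void setGpuAllocator(mlirtrt::GpuAllocator *gpuAllocator) override {
    mGpu = gpuAllocator;
  }
  void setTensorName(const char *tensorName) override { mName = tensorName; }
  void setCurrentMemory(void *currentMemory) override { mPtr = currentMemory; }
  void setOutputSize(const int64_t size) override {
    mCapacity = static_cast<uint64_t>(size);
  }

  void *reallocateOutputAsync(char const *tensorName, void *currentMemory,
                              uint64_t size, uint64_t alignment,
                              cudaStream_t *stream) override {
    if (size <= mCapacity && currentMemory)
      return currentMemory; // existing buffer is already large enough
    // (A real implementation would also release any buffer it previously
    // allocated here.)
    void *fresh = nullptr;
    if (mGpu) {
      // Prefer the registered GpuAllocator, as the in-tree implementation does.
      fresh = mGpu->allocate(size, alignment, 0 /* flags */, stream);
    } else if (stream) {
      cudaMallocAsync(&fresh, size, *stream); // stand-in backing allocation
    } else {
      cudaMalloc(&fresh, size);
    }
    mPtr = fresh;
    mCapacity = size;
    return fresh;
  }

  void notifyShape(char const *tensorName,
                   std::vector<int64_t> &dims) override {
    mShape = dims; // final data-dependent shape, available after enqueue
  }

private:
  mlirtrt::GpuAllocator *mGpu{nullptr};
  const char *mName{nullptr};
  void *mPtr{nullptr};
  uint64_t mCapacity{0};
  std::vector<int64_t> mShape;
};

// Registration (mirrors the in-tree path):
//   outputAllocatorTracker.addAllocator(
//       std::make_unique<PoolBackedOutputAllocator>());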