Skip to content

Commit

Permalink
Add IOutputAllocator
Browse files Browse the repository at this point in the history
  • Loading branch information
jhalakpatel committed Aug 22, 2024
1 parent 539fb56 commit 94cbb5a
Show file tree
Hide file tree
Showing 10 changed files with 375 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,7 @@ class RuntimeSession {
std::unique_ptr<PinnedMemoryAllocator> pinnedMemoryAllocator,
std::unique_ptr<AllocTracker> allocTracker,
std::unique_ptr<ResourceTracker> resourceTracker,
std::unique_ptr<OutputAllocatorTracker> outputAllocatorTracker,
std::unique_ptr<GpuAllocator> gpuAllocator);

ExecutableView getExecutable() const { return executable; }
Expand All @@ -891,6 +892,7 @@ class RuntimeSession {
std::unique_ptr<PinnedMemoryAllocator> pinnedMemoryAllocator;
std::unique_ptr<AllocTracker> allocTracker;
std::unique_ptr<ResourceTracker> resourceTracker;
std::unique_ptr<OutputAllocatorTracker> outputAllocatorTracker;
std::unique_ptr<GpuAllocator> gpuAllocator;
sol::state state;
};
Expand Down Expand Up @@ -973,6 +975,10 @@ class RuntimeClient {
return pinnedMemoryAllocator;
}

OutputAllocatorTracker& getOutputAllocatorTracker() {
return outputAllocatorTracker;
}

private:
RuntimeClient(llvm::SmallVector<std::unique_ptr<Device>> devices)
: devices(std::move(devices)) {}
Expand All @@ -981,6 +987,7 @@ class RuntimeClient {
PinnedMemoryAllocator pinnedMemoryAllocator;
AllocTracker allocTracker;
ResourceTracker resourceTracker;
OutputAllocatorTracker outputAllocatorTracker;
};

//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ void registerLuaRuntimeMethods(lua_State *state,
const RuntimeSessionOptions &options,
PinnedMemoryAllocator *pinnedMemoryAllocator,
AllocTracker *allocTracker,
ResourceTracker *resourceTracker, GpuAllocator* allocator);
ResourceTracker *resourceTracker,
OutputAllocatorTracker *outputAllocatorTracker,
GpuAllocator *allocator);

} // namespace mlirtrt::runtime
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ namespace mlirtrt::runtime {
/// `main` function. It is assumed that `main` takes no arguments and returns an
/// integer result (which is returned if the execution is successful).
/// TODO: this should take a handle to a function for streaming output/errors.
StatusOr<int64_t> runExecutorLuaScript(std::string_view luaScript, GpuAllocator* allocator);
StatusOr<int64_t> runExecutorLuaScript(std::string_view luaScript,
GpuAllocator *allocator);

/// Synchronously run a serialized executor Executable one time. An `Executable`
/// is essentially a Lua script packaged with metadata and serialized constants
Expand All @@ -48,12 +49,15 @@ StatusOr<int64_t> runExecutorLuaScript(std::string_view luaScript, GpuAllocator*
/// execution is successful).
/// TODO: this should take a handle to a function for
/// streaming output/errors.
StatusOr<int64_t> runExecutorExecutable(std::unique_ptr<Executable> executable, std::unique_ptr<GpuAllocator> allocator);
StatusOr<int64_t>
runExecutorExecutable(std::unique_ptr<Executable> executable,
std::unique_ptr<GpuAllocator> allocator);

/// Create an execution state. This will setup a Lua environment and invoke
/// global initialization.
StatusOr<std::unique_ptr<RuntimeSession>>
createRuntimeSessionWithLuaBackend(ExecutableView executable, std::unique_ptr<GpuAllocator> allocator,
createRuntimeSessionWithLuaBackend(ExecutableView executable,
std::unique_ptr<GpuAllocator> allocator,
const RuntimeSessionOptions &options);

/// Set the primary stream for the loaded executable to use.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class ResourceTracker;
/// Lua state.
void registerExecutorTensorRTModuleLuaRuntimeMethods(
lua_State *luaState, PinnedMemoryAllocator *pinnedMemoryAllocator,
AllocTracker *allocTracker, ResourceTracker *resourceTracker, GpuAllocator* allocator);
AllocTracker *allocTracker, ResourceTracker *resourceTracker,
OutputAllocatorTracker *outputAllocatorTracker, GpuAllocator *allocator);

} // namespace mlirtrt::runtime

Expand Down
113 changes: 113 additions & 0 deletions mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ namespace mlirtrt {

struct EventPool;

//===----------------------------------------------------------------------===//
// GpuAllocator and CustomTensorRTAllocator
//===----------------------------------------------------------------------===//

class GpuAllocator {
public:
GpuAllocator() = default;
Expand All @@ -56,6 +60,115 @@ class CustomTensorRTAllocator : public GpuAllocator {
cudaStream_t* stream) override;
};

//===----------------------------------------------------------------------===//
// OutputAllocator and CustomTensorRTOuputAllocator
//===----------------------------------------------------------------------===//

//!
//! Interface for allocating memory for outputs with data-dependent shapes.
//! The sizes of such outputs are not known until runtime, so memory cannot be
//! pre-allocated.
//!
class OutputAllocator {
public:
  virtual ~OutputAllocator() = default;

  /// Sets the GPU allocator used to service (re)allocation requests.
  virtual void setGpuAllocator(GpuAllocator *gpuAllocator) = 0;

  /// Associates this allocator with the named output tensor.
  virtual void setTensorName(const char *tensorName) = 0;

  /// Provides the currently allocated output buffer (may be null).
  virtual void setCurrentMemory(void *currentMemory) = 0;

  /// Records the size in bytes of the current output buffer.
  /// (Top-level `const` on a by-value parameter has no effect on the
  /// signature and is omitted.)
  virtual void setOutputSize(int64_t size) = 0;

  /// Called when the output tensor needs at least `size` bytes. Returns a
  /// pointer to memory of at least `size` bytes with the requested
  /// `alignment`, or nullptr if the allocation could not be satisfied.
  virtual void *reallocateOutputAsync(char const *tensorName,
                                      void *currentMemory, uint64_t size,
                                      uint64_t alignment,
                                      cudaStream_t * /*stream*/) = 0;

  /// Called once the final shape of the output tensor is known.
  virtual void notifyShape(char const *tensorName,
                           std::vector<int64_t> &dims) = 0;
};

class CustomTensorRTOuputAllocator : public OutputAllocator {
public:
CustomTensorRTOuputAllocator() = default;
~CustomTensorRTOuputAllocator() {
if (mOutputPtr != nullptr) {
cudaFree(mOutputPtr);
}
}

void setGpuAllocator(GpuAllocator* gpuAllocator) override {
mGpuAllocator = gpuAllocator;
}

//! Methods are called just after construction. TODO: can they be called
//! during construction?
void setTensorName(const char *tensorName) override {
mTensorName = tensorName;
}

void setCurrentMemory(void *currentMemory) override {
mCurrentMemory = currentMemory;
}

void setOutputSize(int64_t outputSize) override { mOutputSize = outputSize; }

void *reallocateOutputAsync(char const *tensorName, void *currentMemory,
uint64_t size, uint64_t alignment,
cudaStream_t * /*stream*/) override;

void notifyShape(char const *tensorName,
std::vector<int64_t> &dims) override {}

//! nullptr if memory could not be allocated
void *mOutputPtr{nullptr};

//! Size of allocation pointed to by output.
uint64_t mOutputSize{0};

bool mReallocateOutputCalled{false};

bool mNotifyShapeCalled{false};

//! Dimensions of tensor.
std::vector<int64_t> outputDims;

private:
GpuAllocator* mGpuAllocator;
const char *mTensorName;
void *mCurrentMemory;
};

class OutputAllocatorTracker {
public:
OutputAllocatorTracker() = default;
~OutputAllocatorTracker() = default;

OutputAllocatorTracker(const OutputAllocatorTracker &) = delete;
OutputAllocatorTracker &operator=(const OutputAllocatorTracker &) = delete;
OutputAllocatorTracker(OutputAllocatorTracker &&) = default;
OutputAllocatorTracker &operator=(OutputAllocatorTracker &&) = default;

// Add a new OutputAllocator
void addAllocator(std::unique_ptr<OutputAllocator> allocator) {
outputAllocators.emplace_back(std::move(allocator));
}

// Get a reference to an OutputAllocator
OutputAllocator *getAllocator(size_t index) {
if (index < outputAllocators.size()) {
return outputAllocators[index].get();
}
return nullptr;
}

// Get the last added OutputAllocator
OutputAllocator *getLastAllocator() {
if (!outputAllocators.empty()) {
return outputAllocators.back().get();
}
return nullptr;
}

private:
std::vector<std::unique_ptr<OutputAllocator>> outputAllocators;
};

//===----------------------------------------------------------------------===//
// PoolTrackedCudaEvent
//===----------------------------------------------------------------------===//
Expand Down
6 changes: 4 additions & 2 deletions mlir-tensorrt/executor/lib/Runtime/API/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,12 +358,14 @@ RuntimeSession::RuntimeSession(
std::unique_ptr<PinnedMemoryAllocator> pinnedMemoryAllocator,
std::unique_ptr<AllocTracker> allocTracker,
std::unique_ptr<ResourceTracker> resourceTracker,
std::unique_ptr<OutputAllocatorTracker> outputAllocatorTracker,
std::unique_ptr<GpuAllocator> gpuAllocator)
: options(std::move(options)), executable(exe),
pinnedMemoryAllocator(std::move(pinnedMemoryAllocator)),
allocTracker(std::move(allocTracker)),
resourceTracker(std::move(resourceTracker)), gpuAllocator(std::move(gpuAllocator)),
state(std::move(state)) {}
resourceTracker(std::move(resourceTracker)),
outputAllocatorTracker(std::move(outputAllocatorTracker)),
gpuAllocator(std::move(gpuAllocator)), state(std::move(state)) {}

//===----------------------------------------------------------------------===//
// AllocTracker
Expand Down
35 changes: 21 additions & 14 deletions mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ static void registerDefaultDeviceDependentMethods(lua_State *state,

static void registerLuaRuntimeMethodsCommon(
lua_State *state, PinnedMemoryAllocator *pinnedMemoryAllocator,
AllocTracker *allocTracker, ResourceTracker *resourceTracker, GpuAllocator* allocator) {
AllocTracker *allocTracker, ResourceTracker *resourceTracker,
GpuAllocator *allocator, OutputAllocatorTracker *outputAllocatorTracker) {
registerExecutorCoreModuleLuaRuntimeMethods(state, pinnedMemoryAllocator,
allocTracker);
registerExecutorCUDAModuleLuaRuntimeMethods(
Expand All @@ -84,15 +85,15 @@ static void registerLuaRuntimeMethodsCommon(
#endif

registerExecutorTensorRTModuleLuaRuntimeMethods(
state, pinnedMemoryAllocator, allocTracker, resourceTracker, allocator);
state, pinnedMemoryAllocator, allocTracker, resourceTracker, outputAllocatorTracker, allocator);
}

void mlirtrt::runtime::registerLuaRuntimeMethods(
lua_State *state, const RuntimeSessionOptions &options,
PinnedMemoryAllocator *pinnedMemoryAllocator, AllocTracker *allocTracker,
ResourceTracker *resourceTracker, GpuAllocator* allocator) {
ResourceTracker *resourceTracker, OutputAllocatorTracker* outputAllocatorTracker, GpuAllocator* allocator) {
registerLuaRuntimeMethodsCommon(state, pinnedMemoryAllocator, allocTracker,
resourceTracker, allocator);
resourceTracker, allocator, outputAllocatorTracker);
#ifdef MLIR_EXECUTOR_ENABLE_NCCL
registerExecutorNCCLModuleLuaRuntimeMethods(state, resourceTracker);
registerDeviceDependentNCCLMethods(state, options.getNumDevices(),
Expand All @@ -107,8 +108,8 @@ void mlirtrt::runtime::registerLuaRuntimeMethods(
#endif
}

StatusOr<int64_t>
mlirtrt::runtime::runExecutorLuaScript(std::string_view luaScript, GpuAllocator* allocator) {
StatusOr<int64_t> mlirtrt::runtime::runExecutorLuaScript(
std::string_view luaScript, GpuAllocator *allocator) {
ADD_RUNTIME_MODULE_RANGE("runtime_runExecutorLuaScript");

StatusOr<std::unique_ptr<RuntimeClient>> client = RuntimeClient::create();
Expand All @@ -120,7 +121,8 @@ mlirtrt::runtime::runExecutorLuaScript(std::string_view luaScript, GpuAllocator*
registerLuaRuntimeMethods(lua.lua_state(), RuntimeSessionOptions(),
&(*client)->getPinnedMemorAllocator(),
&(*client)->getAllocTracker(),
&(*client)->getResourceTracker(), allocator);
&(*client)->getResourceTracker(),
&(*client)->getOutputAllocatorTracker(), allocator);

sol::protected_function_result result = lua.script(luaScript);
if (!result.valid()) {
Expand Down Expand Up @@ -171,20 +173,22 @@ static Status maybeCheckForValidNcclUuid(const RuntimeSessionOptions &options) {
/// global initialization.
StatusOr<std::unique_ptr<RuntimeSession>>
mlirtrt::runtime::createRuntimeSessionWithLuaBackend(
ExecutableView executable, std::unique_ptr<GpuAllocator> allocator, const RuntimeSessionOptions &options) {
ExecutableView executable, std::unique_ptr<GpuAllocator> allocator,
const RuntimeSessionOptions &options) {
ADD_RUNTIME_MODULE_RANGE("runtime_loadExecutable");

MTRT_RETURN_IF_ERROR(maybeCheckForValidNcclUuid(options));

auto pinnedMemoryAllocator = std::make_unique<PinnedMemoryAllocator>();
auto allocTracker = std::make_unique<AllocTracker>();
auto resourceTracker = std::make_unique<ResourceTracker>();
auto outputAllocatorTracker = std::make_unique<OutputAllocatorTracker>();

sol::state lua;
lua.open_libraries(sol::lib::base, sol::lib::string);
registerLuaRuntimeMethods(lua.lua_state(), options,
pinnedMemoryAllocator.get(), allocTracker.get(),
resourceTracker.get(), allocator.get());
registerLuaRuntimeMethods(
lua.lua_state(), options, pinnedMemoryAllocator.get(), allocTracker.get(),
resourceTracker.get(), outputAllocatorTracker.get(), allocator.get());

// Load globals into the context.
// TODO: eliminate this copy, we already own the executable.
Expand Down Expand Up @@ -225,11 +229,13 @@ mlirtrt::runtime::createRuntimeSessionWithLuaBackend(
}
return std::make_unique<RuntimeSession>(
options, executable, std::move(lua), std::move(pinnedMemoryAllocator),
std::move(allocTracker), std::move(resourceTracker), std::move(allocator));
std::move(allocTracker), std::move(resourceTracker),
std::move(outputAllocatorTracker), std::move(allocator));
}

StatusOr<int64_t> mlirtrt::runtime::runExecutorExecutable(
std::unique_ptr<Executable> executable, std::unique_ptr<GpuAllocator> allocator) {
std::unique_ptr<Executable> executable,
std::unique_ptr<GpuAllocator> allocator) {

StatusOr<std::unique_ptr<RuntimeClient>> client = RuntimeClient::create();
if (!client.isOk())
Expand All @@ -245,7 +251,8 @@ StatusOr<int64_t> mlirtrt::runtime::runExecutorExecutable(
return options.getStatus();

StatusOr<std::unique_ptr<RuntimeSession>> session =
createRuntimeSessionWithLuaBackend(executable->getView(), std::move(allocator), *options);
createRuntimeSessionWithLuaBackend(executable->getView(),
std::move(allocator), *options);
if (!session.isOk())
return session.getStatus();

Expand Down
Loading

0 comments on commit 94cbb5a

Please sign in to comment.