Skip to content

Commit 9334481

Browse files
committed
Fix memory leak
1 parent 6e647ca commit 9334481

File tree

9 files changed

+75
-59
lines changed

9 files changed

+75
-59
lines changed

mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -795,11 +795,11 @@ class AllocTracker {
795795
/// Returns true if the ptr is released internally.
796796
bool isReleasedInternally(uintptr_t ptr) const;
797797

798-
/// Set the pointer is allocated by TensorRT.
799-
void setTensorRTAllocated(uintptr_t ptr);
798+
/// Mark pointer for release after consumption
799+
void markForReleaseAfterConsumption(uintptr_t ptr);
800800

801-
/// Get that pointer is allocated by TensorRT.
802-
bool getTensorRTAllocated(uintptr_t ptr);
801+
/// Check if pointer is marked for release after consumption
802+
bool isMarkedForReleaseAfterConsumption(uintptr_t ptr);
803803

804804
private:
805805
struct Metadata {
@@ -808,7 +808,7 @@ class AllocTracker {
808808
// if this is true then it should be truly released and untracked
809809
// when decrementExternalCount causes count to go to zero
810810
bool releasedInternally{false};
811-
bool tensorrtAllocated{false};
811+
bool releaseAfterConsumption{false};
812812
PointerInfo info;
813813
};
814814

@@ -877,7 +877,7 @@ class RuntimeSession {
877877

878878
ExecutableView getExecutable() const { return executable; }
879879

880-
PinnedMemoryAllocator &getPinnedMemorAllocator() {
880+
PinnedMemoryAllocator &getPinnedMemoryAllocator() {
881881
return *pinnedMemoryAllocator;
882882
}
883883

@@ -975,7 +975,7 @@ class RuntimeClient {
975975
ResourceTracker &getResourceTracker() { return resourceTracker; }
976976

977977
/// Return the PinnedMemoryAllocator.
978-
PinnedMemoryAllocator &getPinnedMemorAllocator() {
978+
PinnedMemoryAllocator &getPinnedMemoryAllocator() {
979979
return pinnedMemoryAllocator;
980980
}
981981

mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,11 @@ class PinnedMemoryAllocator {
104104
PinnedMemoryAllocator();
105105
~PinnedMemoryAllocator();
106106

107-
/// Untracks
107+
/// Marks a pointer as client-managed, deferring its deallocation.
108+
/// This method is used when a pinned memory pointer is returned to the client
109+
/// and its lifecycle is no longer managed by the PinnedMemoryAllocator.
110+
/// Pointers marked this way will not be automatically freed in the
111+
/// allocator's destructor.
108112
void untrack(uintptr_t ptr);
109113

110114
StatusOr<PinnedMemoryBlock> allocate(size_t size);
@@ -117,8 +121,8 @@ class PinnedMemoryAllocator {
117121
private:
118122
EventPool eventPool;
119123

120-
/// Tracks all the pointers which need not to freed up.
121-
static std::vector<uintptr_t> untrackedPtrs;
124+
/// Stores pointers to memory blocks that are now managed by the client.
125+
static std::vector<uintptr_t> clientManagedPtrs;
122126

123127
/// Tracks all blocks allocated by the allocator.
124128
struct BlockTracker;

mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir-executor/Runtime/API/API.h"
2828
#include "mlir-executor/Runtime/API/ExecutableFlatbuffer.h"
2929
#include "mlir-executor/Runtime/Backend/Lua/LuaRuntime.h"
30+
#include "mlir-executor/Runtime/Support/Support.h"
3031
#include "mlir-executor/Support/Status.h"
3132
#include "mlir/Support/FileUtilities.h"
3233
#include "llvm/Support/Debug.h"
@@ -324,9 +325,9 @@ MTRT_Status mtrtMemRefCreateExternal(
324325

325326
MTRT_Status mtrtMemRefValueDestroyAsync(MTRT_MemRefValue buffer,
326327
MTRT_Stream stream) {
327-
328328
MemRefValue *memref = unwrap(buffer);
329-
llvm::dbgs() << "[MLIR-TRT] Deallocating memref pointer " << memref->getMemory() << "\n";
329+
MTRT_DBGF("destroying memref pointer 0x%lx asynchronously",
330+
memref->getMemory());
330331
Status s = memref->getClient()->deallocate(
331332
std::unique_ptr<MemRefValue>(memref),
332333
mtrtStreamIsNull(stream) ? std::nullopt
@@ -338,7 +339,7 @@ MTRT_Status mtrtMemRefValueDestroyAsync(MTRT_MemRefValue buffer,
338339

339340
MTRT_Status mtrtMemRefValueDestroy(MTRT_MemRefValue buffer) {
340341
MemRefValue *memref = unwrap(buffer);
341-
llvm::dbgs() << "[MLIR-TRT] Deallocating memref pointer " << memref->getMemory() << "\n";
342+
MTRT_DBGF("destroying memref pointer 0x%lx", memref->getMemory());
342343
Status s =
343344
memref->getClient()->deallocate(std::unique_ptr<MemRefValue>(memref));
344345
if (!s.isOk())

mlir-tensorrt/executor/lib/Runtime/API/API.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -396,18 +396,18 @@ AllocTracker::~AllocTracker() {
396396
MTRT_DBGF("freed %zu bytes of unfreed memory", totalSize);
397397
}
398398

399-
void AllocTracker::setTensorRTAllocated(uintptr_t ptr) {
399+
void AllocTracker::markForReleaseAfterConsumption(uintptr_t ptr) {
400400
assert(llvm::is_contained(map, ptr) &&
401401
llvm::formatv("Untracked pointer {0}", ptr).str().c_str());
402402
std::unique_ptr<Metadata> const &metadata = map.at(ptr);
403-
metadata->tensorrtAllocated = true;
403+
metadata->releaseAfterConsumption = true;
404404
}
405405

406-
bool AllocTracker::getTensorRTAllocated(uintptr_t ptr) {
406+
bool AllocTracker::isMarkedForReleaseAfterConsumption(uintptr_t ptr) {
407407
assert(llvm::is_contained(map, ptr) &&
408408
llvm::formatv("Untracked pointer {0}", ptr).str().c_str());
409409
std::unique_ptr<Metadata> const &metadata = map.at(ptr);
410-
return metadata->tensorrtAllocated;
410+
return metadata->releaseAfterConsumption;
411411
}
412412

413413
void AllocTracker::markReleasedInternally(uintptr_t ptr) {
@@ -486,8 +486,8 @@ void AllocTracker::track(PointerInfo info) {
486486
auto value = std::make_unique<Metadata>();
487487
value->externalReferenceCount.store(0);
488488
value->releasedInternally = false;
489+
value->releaseAfterConsumption = false;
489490
value->info = info;
490-
value->tensorrtAllocated = false;
491491
if (!contains(info.ptr)) {
492492
map.insert(std::make_pair(info.ptr, std::move(value)));
493493
return;
@@ -1001,7 +1001,7 @@ RuntimeClient::copyToDevice(const MemRefValue &hostBufferImpl,
10011001
// TODO: Currently, this implementation supports only row major packed
10021002
// canonical layout (no padding).
10031003
StatusOr<mlirtrt::PinnedMemoryBlock> pinnedMemory =
1004-
this->getPinnedMemorAllocator().allocate(totalBufferSize);
1004+
this->getPinnedMemoryAllocator().allocate(totalBufferSize);
10051005
if (!pinnedMemory.isOk())
10061006
return pinnedMemory.getStatus();
10071007

@@ -1020,7 +1020,7 @@ RuntimeClient::copyToDevice(const MemRefValue &hostBufferImpl,
10201020
reinterpret_cast<cudaStream_t>(*cudaStream)));
10211021

10221022
// Free pinned host memory asynchronously.
1023-
getPinnedMemorAllocator().freeAsync(pinnedMemory->ptr, *cudaStream);
1023+
getPinnedMemoryAllocator().freeAsync(pinnedMemory->ptr, *cudaStream);
10241024
} else {
10251025
MTRT_DBG("synchronously copying {0} (host) to {1} (device), size={2} bytes",
10261026
hostBufferImpl.getVoidPtr(), (*deviceMemRef)->getVoidPtr(),

mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ LuaRuntimeSession::create(RuntimeSessionOptions options,
152152

153153
// Register builtin methods.
154154
registerLuaRuntimeMethods(lua.lua_state(), session->getOptions(),
155-
&session->getPinnedMemorAllocator(),
155+
&session->getPinnedMemoryAllocator(),
156156
&session->getAllocTracker(),
157157
&session->getResourceTracker());
158158

@@ -624,22 +624,22 @@ parseResults(const sol::protected_function_result &pfr,
624624
"This ptr is registered with the session and will now be tracked "
625625
"by the client as well.",
626626
allocPtr, static_cast<void *>(&session.getAllocTracker()),
627-
static_cast<void *>(&session.getPinnedMemorAllocator()),
627+
static_cast<void *>(&session.getPinnedMemoryAllocator()),
628628
static_cast<void *>(*client),
629629
static_cast<void *>(&(*client)->getAllocTracker()));
630630

631631
// What we actually need here is to "release" the pointer from the session's
632632
// ownership and have the client assume it.
633633
PointerInfo info = session.getAllocTracker().get(allocPtr);
634634
session.getAllocTracker().untrack(info.ptr);
635-
636-
// It is possible that pinned memory also tracks the memory for
637-
// deallocation.
638-
session.getPinnedMemorAllocator().untrack(info.ptr);
639-
640-
AllocTracker &allocator = (*client)->getAllocTracker();
641-
// if (!allocator.contains(info.ptr))
642-
allocator.track(info);
635+
(*client)->getAllocTracker().track(info);
636+
637+
// Defer deallocation of this pinned memory pointer
638+
// This pointer is likely still in use by the client and should not be
639+
// immediately freed. By untracking it here, we ensure it won't be
640+
// deallocated in the PinnedMemoryAllocator's destructor, allowing
641+
// the client to manage its lifecycle.
642+
session.getPinnedMemoryAllocator().untrack(info.ptr);
643643

644644
// Create a memref so that client now tracks it.
645645
auto memref = MemRefValue::create(

mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/CUDA/CUDAModule.cpp

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -432,12 +432,14 @@ registerCudaMemoryManagementOps(sol::state_view &lua,
432432
cudaMemcpyDeviceToHost,
433433
stream),
434434
state);
435-
if (allocTracker->getTensorRTAllocated(
436-
reinterpret_cast<uintptr_t>(srcPtr))) {
437-
// Free tensorrt allocate source pointer, since there it won't be
438-
// released by external memref.
439-
SET_LUA_ERROR_IF_CUDART_ERROR(cudaFreeAsync(srcPtr, stream), state);
440-
allocTracker->untrack(reinterpret_cast<uintptr_t>(srcPtr));
435+
// Check if the source pointer is marked for release after consumption
436+
if (allocTracker->isMarkedForReleaseAfterConsumption(src)) {
437+
// This pointer was allocated by TensorRT and used in a device-device
438+
// or device-host copy operation. It's not wrapped in a memref, so it
439+
// won't be released by external memref destruction. We need to
440+
// explicitly free it.
441+
SET_LUA_ERROR_IF_ERROR(runtime::safeDeallocate(*allocTracker, src),
442+
state);
441443
}
442444
};
443445

@@ -487,12 +489,14 @@ registerCudaMemoryManagementOps(sol::state_view &lua,
487489
cudaMemcpyDeviceToHost,
488490
stream),
489491
state);
490-
if (allocTracker->getTensorRTAllocated(
491-
reinterpret_cast<uintptr_t>(srcPtr))) {
492-
// Free tensorrt allocate source pointer, since there it won't be
493-
// released by external memref.
494-
SET_LUA_ERROR_IF_CUDART_ERROR(cudaFreeAsync(srcPtr, stream), state);
495-
allocTracker->untrack(reinterpret_cast<uintptr_t>(srcPtr));
492+
// Check if the source pointer is marked for release after consumption
493+
if (allocTracker->isMarkedForReleaseAfterConsumption(src)) {
494+
// This pointer was allocated by TensorRT and used in a device-device
495+
// or device-host copy operation. It's not wrapped in a memref, so it
496+
// won't be released by external memref destruction. We need to
497+
// explicitly free it.
498+
SET_LUA_ERROR_IF_ERROR(runtime::safeDeallocate(*allocTracker, src),
499+
state);
496500
}
497501
};
498502
lua["__cuda_memcpy_device2device"] = [allocTracker](
@@ -518,12 +522,14 @@ registerCudaMemoryManagementOps(sol::state_view &lua,
518522
cudaMemcpyDeviceToDevice,
519523
stream),
520524
state);
521-
if (allocTracker->getTensorRTAllocated(
522-
reinterpret_cast<uintptr_t>(srcPtr))) {
523-
// Free tensorrt allocate source pointer, since there it won't be
524-
// released by external memref.
525-
SET_LUA_ERROR_IF_CUDART_ERROR(cudaFreeAsync(srcPtr, stream), state);
526-
allocTracker->untrack(reinterpret_cast<uintptr_t>(srcPtr));
525+
// Check if the source pointer is marked for release after consumption
526+
if (allocTracker->isMarkedForReleaseAfterConsumption(src)) {
527+
// This pointer was allocated by TensorRT and used in a device-device
528+
// or device-host copy operation. It's not wrapped in a memref, so it
529+
// won't be released by external memref destruction. We need to
530+
// explicitly free it.
531+
SET_LUA_ERROR_IF_ERROR(runtime::safeDeallocate(*allocTracker, src),
532+
state);
527533
}
528534
return;
529535
};

mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,12 @@ class OutputAllocatorImpl : public nvinfer1::IOutputAllocator {
155155
if (memory.isOk()) {
156156
mOutputPtr = (*memory).ptr;
157157
mOutputSize = memory->size;
158-
mTracker->setTensorRTAllocated(memory->ptr);
158+
// Mark the output pointer for release after consumption
159+
// This is necessary because TensorRT-allocated pointers used in device-device
160+
// or device-host copies may not be wrapped in a memref and tracked by the client.
161+
// By marking it here, we ensure it will be explicitly freed after it's consumed
162+
// in copy operations, preventing memory leaks.
163+
mTracker->markForReleaseAfterConsumption(mOutputPtr);
159164
MTRT_DBGF(
160165
"tensorrt module output allocator allocating %lu bytes at 0x%lx",
161166
mOutputSize, mOutputPtr);

mlir-tensorrt/executor/lib/Support/Allocators.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ static void cudaFreeHostWrapper(uintptr_t ptr) {
206206
#endif
207207
}
208208

209-
std::vector<uintptr_t> PinnedMemoryAllocator::untrackedPtrs;
209+
std::vector<uintptr_t> PinnedMemoryAllocator::clientManagedPtrs;
210210

211211
struct PinnedMemoryAllocator::BlockTracker {
212212
std::set<Block *, BlockComparison> blocks;
@@ -218,13 +218,14 @@ struct PinnedMemoryAllocator::BlockTracker {
218218
"[PinnedMemoryAllocator] Releasing block tracker that has %lu blocks",
219219
blocks.size());
220220
for (Block *block : blocks) {
221-
if (std::find(PinnedMemoryAllocator::untrackedPtrs.begin(),
222-
PinnedMemoryAllocator::untrackedPtrs.end(),
223-
block->ptr) == PinnedMemoryAllocator::untrackedPtrs.end()) {
221+
if (std::find(clientManagedPtrs.begin(), clientManagedPtrs.end(),
222+
block->ptr) == clientManagedPtrs.end()) {
224223
ALLOC_DBGF("[PinnedMemoryAllocator] releasing block %lu of size %lu",
225224
block->ptr, block->size);
226225
cudaFreeHostWrapper(block->ptr);
227226
}
227+
// Blocks found in clientManagedPtrs are not freed here, as they are now
228+
// managed by the client
228229
}
229230
}
230231
};
@@ -251,7 +252,7 @@ StatusOr<PinnedMemoryBlock> PinnedMemoryAllocator::allocate(size_t size) {
251252
if (lowerBound != freeBlocks->set.end()) {
252253
Block *result = *lowerBound;
253254
freeBlocks->set.erase(result);
254-
ALLOC_DBGF("re-using block %lu of size %lu", result->ptr, result->size);
255+
ALLOC_DBGF("re-using block %lx of size %lu", result->ptr, result->size);
255256
return PinnedMemoryBlock{result->ptr, result->size};
256257
}
257258

@@ -275,10 +276,11 @@ StatusOr<PinnedMemoryBlock> PinnedMemoryAllocator::allocate(size_t size) {
275276
#endif
276277
}
277278

278-
// Free the given block.
279+
std::vector<uintptr_t> clientManagedPtrs;
280+
279281
void PinnedMemoryAllocator::untrack(uintptr_t ptr) {
280-
if (!llvm::is_contained(untrackedPtrs, ptr)) {
281-
untrackedPtrs.emplace_back(ptr);
282+
if (!llvm::is_contained(clientManagedPtrs, ptr)) {
283+
clientManagedPtrs.emplace_back(ptr);
282284
}
283285
}
284286

mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,8 +346,6 @@ static std::unique_ptr<PyMemRefValue>
346346
createMemRefViewFromDLPack(PyRuntimeClient &client, py::capsule capsule,
347347
std::optional<bool> assertCanonicalStrides) {
348348

349-
llvm::dbgs() << "Creating a memref view from DL pack tensors\n";
350-
351349
DLManagedTensor *managedTensor = static_cast<DLManagedTensor *>(
352350
PyCapsule_GetPointer(capsule.ptr(), "dltensor"));
353351

0 commit comments

Comments
 (0)