@@ -191,6 +191,9 @@ class PyRuntimeValue : public PyMTRTWrapper<PyRuntimeValue, MTRT_RuntimeValue> {
191
191
// routing.
192
192
class PyGpuAllocator {
193
193
public:
194
+ py::object pySelf;
195
+ PyGpuAllocator (py::object self) : pySelf(self) {}
196
+
194
197
virtual ~PyGpuAllocator () = default ;
195
198
virtual std::uintptr_t allocate (uint64_t size) = 0;
196
199
virtual bool deallocate (std::uintptr_t ptr) = 0;
@@ -201,6 +204,7 @@ class PyGpuAllocator {
201
204
private:
202
205
// Trampoline function: Routes C-style allocation calls to C++ virtual method.
203
206
static void *pyGpuAllocatorAllocate (void *self, uint64_t size) {
207
+ py::gil_scoped_acquire acquire;
204
208
auto *allocator = static_cast <PyGpuAllocator *>(self);
205
209
std::uintptr_t ptr = allocator->allocate (size);
206
210
return reinterpret_cast <void *>(ptr);
@@ -209,6 +213,7 @@ class PyGpuAllocator {
209
213
// Trampoline function: Routes C-style deallocation calls to C++ virtual
210
214
// method.
211
215
static bool pyGpuAllocatorDeallocate (void *self, void *memory) {
216
+ py::gil_scoped_acquire acquire;
212
217
auto *allocator = static_cast <PyGpuAllocator *>(self);
213
218
return allocator->deallocate (reinterpret_cast <std::uintptr_t >(memory));
214
219
}
@@ -969,7 +974,8 @@ PYBIND11_MODULE(_api, m) {
969
974
py::arg (" nccl_uuid" ) = py::str (" " ));
970
975
971
976
py::class_<PyGpuAllocator, PyGpuAllocatorTrampoline>(m, " GpuAllocator" )
972
- .def (py::init<>())
977
+ .def (py::init<>(
978
+ [](py::object self) { return new PyGpuAllocatorTrampoline (self); }))
973
979
.def (" allocate" , &PyGpuAllocator::allocate)
974
980
.def (" deallocate" , &PyGpuAllocator::deallocate)
975
981
.def (" get_capi_object" , &PyGpuAllocator::getCApiObject);
@@ -983,7 +989,8 @@ PYBIND11_MODULE(_api, m) {
983
989
if (gpu_allocator.is_none ()) {
984
990
// Create session without custom allocator
985
991
s = mtrtRuntimeSessionCreate (
986
- options, exe, MTRT_GpuAllocator{nullptr }, &session);
992
+ options, exe, MTRT_GpuAllocator{nullptr , nullptr , nullptr },
993
+ &session);
987
994
} else {
988
995
try {
989
996
PyGpuAllocator &allocator =
0 commit comments