Skip to content

Commit 153eb76

Browse files
committed
Regards #690: We've extracted the free() calls out of the async sub-namespace + some comment tweaks and redundancy removals
1 parent 74cbfe1 commit 153eb76

File tree

3 files changed

+56
-28
lines changed

3 files changed

+56
-28
lines changed

src/cuda/api/memory.hpp

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -229,41 +229,58 @@ inline region_t allocate(
229229
return allocate_in_current_context(size_in_bytes, stream_handle);
230230
}
231231

232-
} // namespace detail_
233-
234-
/// Free a region of device-side memory (regardless of how it was allocated)
235-
inline void free(void* ptr)
232+
#if CUDA_VERSION >= 11020
233+
inline void free(
234+
context::handle_t context_handle,
235+
void* allocated_region_start,
236+
optional<stream::handle_t> stream_handle = {})
237+
#else
238+
inline void free(
239+
context::handle_t context_handle,
240+
void* allocated_region_start)
241+
#endif
236242
{
237-
auto result = cuMemFree(address(ptr));
243+
#if CUDA_VERSION >= 11020
244+
if (stream_handle) {
245+
auto status = cuMemFreeAsync(device::address(allocated_region_start), *stream_handle);
246+
throw_if_error_lazy(status,
247+
"Failed scheduling an asynchronous freeing of the global memory region starting at "
248+
+ cuda::detail_::ptr_as_hex(allocated_region_start) + " on "
249+
+ stream::detail_::identify(*stream_handle, context_handle));
250+
return;
251+
}
252+
#endif
253+
auto result = cuMemFree(address(allocated_region_start));
238254
#ifdef CAW_THROW_ON_FREE_IN_DESTROYED_CONTEXT
239255
if (result == status::success) { return; }
240256
#else
241257
if (result == status::success or result == status::context_is_destroyed) { return; }
242258
#endif
243-
throw runtime_error(result, "Freeing device memory at " + cuda::detail_::ptr_as_hex(ptr));
259+
throw runtime_error(result, "Freeing device memory at " + cuda::detail_::ptr_as_hex(allocated_region_start));
244260
}
245261

246-
/// @copydoc free(void*)
247-
inline void free(region_t region) { free(region.start()); }
262+
} // namespace detail_
248263

264+
/// Free a region of device-side memory (regardless of how it was allocated)
249265
#if CUDA_VERSION >= 11020
250-
namespace async {
251-
252-
namespace detail_ {
266+
inline void free(void* region_start, optional_ref<const stream_t> stream = {});
267+
#else
268+
inline void free(void* ptr);
269+
#endif
253270

254-
inline void free(
255-
context::handle_t context_handle,
256-
stream::handle_t stream_handle,
257-
void* allocated_region_start)
271+
/// @copydoc free(void*, optional_ref<const stream_t>)
272+
#if CUDA_VERSION >= 11020
273+
inline void free(region_t region, optional_ref<const stream_t> stream = {})
274+
#else
275+
inline void free(region_t region)
276+
#endif
258277
{
259-
auto status = cuMemFreeAsync(device::address(allocated_region_start), stream_handle);
260-
throw_if_error_lazy(status,
261-
"Failed scheduling an asynchronous freeing of the global memory region starting at "
262-
+ cuda::detail_::ptr_as_hex(allocated_region_start) + " on "
263-
+ stream::detail_::identify(stream_handle, context_handle) );
278+
free(region.start(), stream);
264279
}
265280

266-
} // namespace detail_
281+
#if CUDA_VERSION >= 11020
282+
283+
namespace async {
267284

268285
/**
269286
* Schedule a de-allocation of device-side memory on a CUDA stream.

src/cuda/api/multi_wrapper_impls/memory.hpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,24 @@ inline region_t allocate(size_t size_in_bytes, optional_ref<const stream_t> stre
121121
detail_::allocate_in_current_context(size_in_bytes);
122122
}
123123

124-
namespace async {
124+
#endif // CUDA_VERSION >= 11020
125125

126-
inline void free(const stream_t& stream, void* region_start)
126+
#if CUDA_VERSION >= 11020
127+
inline void free(void* region_start, optional_ref<const stream_t> stream)
128+
#else
129+
inline void free(void* region_start) // named to match the shared function body, which refers to `region_start` in both CUDA-version configurations
130+
#endif // CUDA_VERSION >= 11020
127131
{
128-
return detail_::free(stream.context().handle(), stream.handle(), region_start);
132+
auto cch = context::current::detail_::get_handle();
133+
#if CUDA_VERSION >= 11020
134+
if (stream) {
135+
detail_::free(cch, region_start, stream->handle());
136+
}
137+
#endif
138+
detail_::free(cch,region_start);
129139
}
130-
#endif // CUDA_VERSION >= 11020
140+
141+
namespace async {
131142

132143
template <typename T>
133144
inline void typed_set(T* start, const T& value, size_t num_elements, const stream_t& stream)

src/cuda/api/stream.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -601,14 +601,14 @@ class stream_t {
601601
///@{
602602
void free(void* region_start) const
603603
{
604-
memory::device::async::free(associated_stream, region_start);
604+
memory::device::free(region_start, associated_stream);
605605
}
606606

607607
void free(memory::region_t region) const
608608
{
609-
memory::device::async::free(associated_stream, region);
609+
memory::device::free(region, associated_stream);
610610
}
611-
#endif
611+
#endif // CUDA_VERSION >= 11020
612612

613613
/**
614614
* Sets the attachment of a region of managed memory (i.e. in the address space visible

0 commit comments

Comments
 (0)