Skip to content

Commit

Permalink
Fixes #614: Reduces code duplication regarding flags between `context_t` and `primary_context_t`
Browse files Browse the repository at this point in the history
  • Loading branch information
eyalroz committed Mar 16, 2024
1 parent 193da75 commit 148e00d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 54 deletions.
1 change: 1 addition & 0 deletions src/cuda/api/current_context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ inline context::flags_t get_flags()
context::flags_t result;
auto status = cuCtxGetFlags(&result);
throw_if_error_lazy(status, "Failed obtaining the current context's flags");
// Note: Not sanitizing the flags from having CU_CTX_MAP_HOST set
return result;
}

Expand Down
6 changes: 6 additions & 0 deletions src/cuda/api/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,12 @@ class device_t {

/**
 * Sets the flags of this device's primary context.
 *
 * @param new_flags the flags to apply; the CU_CTX_MAP_HOST bit, if set, is
 * cleared before the value is handed to the driver.
 *
 * @note A failure status from the driver call is reported via throw_if_error_lazy.
 */
void set_flags(flags_type new_flags) const
{
    // CU_CTX_MAP_HOST is (mostly) ignored since CUDA 3.2, and has been officially
    // deprecated in CUDA 11. Moreover, in CUDA 11 (and possibly other versions),
    // the flags obtained with cuDevicePrimaryCtxGetState() and cuCtxGetFlags()
    // differ on this particular flag - and cuDevicePrimaryCtxSetFlags() doesn't
    // like seeing it. So, strip it before making the call.
    const flags_type sanitized_flags = new_flags & ~CU_CTX_MAP_HOST;
    auto set_status = cuDevicePrimaryCtxSetFlags(id(), sanitized_flags);
    throw_if_error_lazy(set_status, "Failed setting (primary context) flags for device " + device::detail_::identify(id_));
}
Expand Down
57 changes: 3 additions & 54 deletions src/cuda/api/primary_context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,15 @@ inline state_t raw_state(device::id_t device_id)
{
state_t result;
auto status = cuDevicePrimaryCtxGetState(device_id, &result.flags, &result.is_active);
throw_if_error(status, "Failed obtaining the state of the primary context for " + device::detail_::identify(device_id));
throw_if_error(status, "Failed obtaining the state of the primary context for "
+ device::detail_::identify(device_id));
// Note: Not sanitizing the flags from having CU_CTX_MAP_HOST set
return result;
}

/**
 * Obtains the flags of the primary context of a device, with the
 * CU_CTX_MAP_HOST bit cleared.
 *
 * @param device_id id of the device whose primary context is queried
 * @return the flags reported by raw_state(), minus CU_CTX_MAP_HOST
 */
inline context::flags_t flags(device::id_t device_id)
{
    // CU_CTX_MAP_HOST is ignored since CUDA 3.2, and has been officially
    // deprecated in CUDA 11. Moreover, in CUDA 11 (and possibly other versions),
    // the flags obtained with cuDevicePrimaryCtxGetState() and cuCtxGetFlags()
    // differ on this particular flag - and cuDevicePrimaryCtxSetFlags() doesn't
    // like seeing it. Hence it is masked out of the result.
    const auto primary_context_state = raw_state(device_id);
    return primary_context_state.flags & ~CU_CTX_MAP_HOST;
}

inline bool is_active(device::id_t device_id)
Expand Down Expand Up @@ -141,28 +138,6 @@ class primary_context_t : public context_t {
// proxy, we are not making this context current on construction
// nor expecting it to be current throughout its lifetime.

protected:

// Returns the flags of the primary context of the proxied device, as
// reported by primary_context::detail_::flags() for device_id_.
//
// NOTE(review): the previous comment here said this hides the base class'
// _non-virtual_ flags() method, yet the method is declared `override`, which
// requires the base class method to be virtual - confirm which is intended.
context::flags_t flags() const override
{
    const auto primary_flags = primary_context::detail_::flags(device_id_);
    return primary_flags;
}

/**
 * Sets the flags of the primary context of the proxied device.
 *
 * @param new_flags the flags to apply; the CU_CTX_MAP_HOST bit, if set, is
 * cleared before the value is handed to the driver.
 *
 * @note A failure status from the driver call is reported via throw_if_error.
 */
void set_flags(flags_type new_flags) const
{
    // CU_CTX_MAP_HOST is (mostly) ignored since CUDA 3.2, and has been officially
    // deprecated in CUDA 11. Moreover, in CUDA 11 (and possibly other versions),
    // the flags obtained with cuDevicePrimaryCtxGetState() and cuCtxGetFlags()
    // differ on this particular flag - and cuDevicePrimaryCtxSetFlags() doesn't
    // like seeing it. So, a caller passing flags obtained from either of those
    // must not have the call fail on account of this bit.
    new_flags &= ~CU_CTX_MAP_HOST;
    auto status = cuDevicePrimaryCtxSetFlags(device_id_, new_flags);
    throw_if_error(status, "Failed setting primary context flags for " + device::detail_::identify(device_id_));
}

/**
 * Sets the flags of the proxied primary context from individual settings,
 * rather than from a ready-made flags value.
 *
 * @param sync_scheduling_policy the host-thread synchronization scheduling
 * policy to encode into the flags
 * @param keep_larger_local_mem_after_resize the local-memory-retention
 * setting to encode into the flags
 *
 * The two settings are combined by context::detail_::make_flags(), and the
 * result is passed to the single-argument set_flags().
 */
void set_flags(
    device::host_thread_sync_scheduling_policy_t
        sync_scheduling_policy = device::host_thread_sync_scheduling_policy_t::heuristic,
    bool keep_larger_local_mem_after_resize = true)
{
    const auto combined_flags =
        context::detail_::make_flags(sync_scheduling_policy, keep_larger_local_mem_after_resize);
    set_flags(combined_flags);
}

public:

stream_t default_stream() const noexcept;
Expand All @@ -182,7 +157,6 @@ class primary_context_t : public context_t {
}
}


primary_context_t(primary_context_t&& other) noexcept = default;

~primary_context_t() NOEXCEPT_IF_NDEBUG
Expand All @@ -201,31 +175,6 @@ class primary_context_t : public context_t {

primary_context_t& operator=(const primary_context_t& other) = delete;
primary_context_t& operator=(primary_context_t&& other) = default;

public: // mutators of the proxied primary context, but not of the proxy

/**
 * Replaces the synchronization scheduling policy bits in the flags of the
 * proxied primary context, leaving all other flag bits as they were.
 *
 * @param new_policy the policy whose value is placed into the bits covered
 * by CU_CTX_SCHED_MASK
 */
void set_sync_scheduling_policy(context::host_thread_sync_scheduling_policy_t new_policy) const
{
    const auto policy_bits = static_cast<flags_type>(new_policy);
    // Clear whatever policy is currently set, keeping everything else
    const auto preserved_flags = flags() & ~CU_CTX_SCHED_MASK;
    set_flags(preserved_flags | policy_bits);
}

/**
 * @return true if the CU_CTX_LMEM_RESIZE_TO_MAX bit is set in the flags
 * obtained from flags(), false otherwise
 */
bool keeping_larger_local_mem_after_resize() const
{
    const auto lmem_resize_bit = flags() & CU_CTX_LMEM_RESIZE_TO_MAX;
    return lmem_resize_bit != 0;
}

void keep_larger_local_mem_after_resize(bool keep = true) const
{
auto other_flags = flags() & ~CU_CTX_LMEM_RESIZE_TO_MAX;
flags_type new_flags = other_flags | (keep ? CU_CTX_LMEM_RESIZE_TO_MAX : 0);
set_flags(new_flags);
}

/**
 * Convenience counterpart of keep_larger_local_mem_after_resize(): invokes
 * it with `false`.
 */
void dont_keep_larger_local_mem_after_resize() const
{
    constexpr const bool keep = false;
    keep_larger_local_mem_after_resize(keep);
}
};

namespace primary_context {
Expand Down

0 comments on commit 148e00d

Please sign in to comment.