diff --git a/sycl/source/accessor.cpp b/sycl/source/accessor.cpp index b8c4d82dbc6b3..7321f7fa7c821 100644 --- a/sycl/source/accessor.cpp +++ b/sycl/source/accessor.cpp @@ -14,15 +14,6 @@ namespace sycl { inline namespace _V1 { namespace detail { -device getDeviceFromHandler(handler &cgh) { - assert((cgh.MQueue || getSyclObjImpl(cgh)->MGraph) && - "One of MQueue or MGraph should be nonnull!"); - if (cgh.MQueue) - return cgh.MQueue->get_device(); - - return getSyclObjImpl(cgh)->MGraph->getDevice(); -} - // property::no_init is supported now for // accessor // host_accessor diff --git a/sycl/source/detail/context_impl.hpp b/sycl/source/detail/context_impl.hpp index 5a79591ba87eb..84787ad50c5e2 100644 --- a/sycl/source/detail/context_impl.hpp +++ b/sycl/source/detail/context_impl.hpp @@ -29,7 +29,7 @@ inline namespace _V1 { // Forward declaration class device; namespace detail { -class context_impl : std::enable_shared_from_this { +class context_impl : public std::enable_shared_from_this { struct private_tag { explicit private_tag() = default; }; diff --git a/sycl/source/detail/graph_impl.cpp b/sycl/source/detail/graph_impl.cpp index abb2eaea4bd68..ed406f19fa411 100644 --- a/sycl/source/detail/graph_impl.cpp +++ b/sycl/source/detail/graph_impl.cpp @@ -504,7 +504,7 @@ graph_impl::add(std::function CGF, std::vector> &Deps) { (void)Args; #ifdef __INTEL_PREVIEW_BREAKING_CHANGES - detail::handler_impl HandlerImpl{shared_from_this()}; + detail::handler_impl HandlerImpl{*this}; sycl::handler Handler{&HandlerImpl, std::shared_ptr{}}; #else sycl::handler Handler{shared_from_this()}; @@ -2284,7 +2284,7 @@ void dynamic_command_group_impl::finalizeCGFList( // Handler defined inside the loop so it doesn't appear to the runtime // as a single command-group with multiple commands inside. #ifdef __INTEL_PREVIEW_BREAKING_CHANGES - detail::handler_impl HandlerImpl{MGraph}; + detail::handler_impl HandlerImpl{*MGraph}; sycl::handler Handler{&HandlerImpl, std::shared_ptr{}}; #else sycl::handler Handler{MGraph}; diff --git a/sycl/source/detail/handler_impl.hpp b/sycl/source/detail/handler_impl.hpp index e54aaa662a335..72fe112375a64 100644 --- a/sycl/source/detail/handler_impl.hpp +++ b/sycl/source/detail/handler_impl.hpp @@ -10,6 +10,7 @@ #include "sycl/handler.hpp" #include +#include #include #include #include @@ -31,15 +32,13 @@ enum class HandlerSubmissionState : std::uint8_t { class handler_impl { public: - handler_impl(queue_impl *SubmissionSecondaryQueue, bool EventNeeded) + handler_impl(queue_impl &Queue, queue_impl *SubmissionSecondaryQueue, + bool EventNeeded) : MSubmissionSecondaryQueue(SubmissionSecondaryQueue), - MEventNeeded(EventNeeded) {}; + MEventNeeded(EventNeeded), MQueueOrGraph{Queue} {}; - handler_impl( - std::shared_ptr Graph) - : MGraph{Graph} {} - - handler_impl() = default; + handler_impl(ext::oneapi::experimental::detail::graph_impl &Graph) + : MQueueOrGraph{Graph} {} void setStateExplicitKernelBundle() { if (MSubmissionState == HandlerSubmissionState::SPEC_CONST_SET_STATE) @@ -165,8 +164,46 @@ class handler_impl { /// manipulations with version are required detail::CGType MCGType = detail::CGType::None; - /// The graph that is associated with this handler. - std::shared_ptr MGraph; + // This handler is associated with either a queue or a graph. + using graph_impl = ext::oneapi::experimental::detail::graph_impl; + const std::variant, + std::reference_wrapper> + MQueueOrGraph; + + queue_impl *get_queue_or_null() { + auto *Queue = + std::get_if>(&MQueueOrGraph); + return Queue ? &Queue->get() : nullptr; + } + queue_impl &get_queue() { + return std::get>(MQueueOrGraph).get(); + } + graph_impl *get_graph_or_null() { + auto *Graph = + std::get_if>(&MQueueOrGraph); + return Graph ? &Graph->get() : nullptr; + } + graph_impl &get_graph() { + return std::get>(MQueueOrGraph).get(); + } + + // Make the following methods templates to avoid circular dependencies for the + // includes. + template detail::device_impl &get_device() { + Self *self = this; + if (auto *Queue = self->get_queue_or_null()) + return Queue->getDeviceImpl(); + else + return self->get_graph().getDeviceImpl(); + } + template context_impl &get_context() { + Self *self = this; + if (auto *Queue = self->get_queue_or_null()) + return *Queue->getContextImplPtr(); + else + return *self->get_graph().getContextImplPtr(); + } + /// If we are submitting a graph using ext_oneapi_graph this will be the graph /// to be executed. std::shared_ptr diff --git a/sycl/source/detail/queue_impl.cpp b/sycl/source/detail/queue_impl.cpp index 4f6b303e677a7..387ad8195b5cb 100644 --- a/sycl/source/detail/queue_impl.cpp +++ b/sycl/source/detail/queue_impl.cpp @@ -310,7 +310,7 @@ queue_impl::submit_impl(const detail::type_erased_cgfo_ty &CGF, const detail::code_location &Loc, bool IsTopCodeLoc, const v1::SubmissionInfo &SubmitInfo) { #ifdef __INTEL_PREVIEW_BREAKING_CHANGES - detail::handler_impl HandlerImplVal(SecondaryQueue, CallerNeedsEvent); + detail::handler_impl HandlerImplVal(*this, SecondaryQueue, CallerNeedsEvent); detail::handler_impl *HandlerImpl = &HandlerImplVal; // Inlining `Self` results in a crash when SYCL RT is built using MSVC with // optimizations enabled. No crash if built using OneAPI. diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp index 59cbcb5384b8e..4dd9cd65791f3 100644 --- a/sycl/source/handler.cpp +++ b/sycl/source/handler.cpp @@ -43,13 +43,15 @@ inline namespace _V1 { namespace detail { +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES +// TODO: Check if two ABI exports below are still necessary. +#endif device_impl &getDeviceImplFromHandler(handler &CGH) { - assert((CGH.MQueue || getSyclObjImpl(CGH)->MGraph) && - "One of MQueue or MGraph should be nonnull!"); - if (CGH.MQueue) - return CGH.MQueue->getDeviceImpl(); + return getSyclObjImpl(CGH)->get_device(); +} - return getSyclObjImpl(CGH)->MGraph->getDeviceImpl(); +device getDeviceFromHandler(handler &CGH) { + return createSyclObjFromImpl(getSyclObjImpl(CGH)->get_device()); } bool isDeviceGlobalUsedInKernel(const void *DeviceGlobalPtr) { @@ -316,8 +318,8 @@ fill_copy_args(detail::handler_impl *impl, handler::handler(const std::shared_ptr &Queue, bool CallerNeedsEvent) - : MImplOwner( - std::make_shared(nullptr, CallerNeedsEvent)), + : MImplOwner(std::make_shared(*Queue, nullptr, + CallerNeedsEvent)), impl(MImplOwner.get()), MQueue(Queue) {} handler::handler(detail::handler_impl *HandlerImpl, @@ -328,7 +330,8 @@ handler::handler(detail::handler_impl *HandlerImpl, handler::handler(std::shared_ptr Queue, bool CallerNeedsEvent) - : impl(std::make_shared(nullptr, CallerNeedsEvent)), + : impl(std::make_shared(*Queue, nullptr, + CallerNeedsEvent)), MQueue(std::move(Queue)) {} #ifndef __INTEL_PREVIEW_BREAKING_CHANGES @@ -338,20 +341,20 @@ handler::handler( std::shared_ptr Queue, [[maybe_unused]] std::shared_ptr PrimaryQueue, std::shared_ptr SecondaryQueue, bool CallerNeedsEvent) - : impl(std::make_shared(SecondaryQueue.get(), + : impl(std::make_shared(*Queue, SecondaryQueue.get(), CallerNeedsEvent)), MQueue(Queue) {} #endif handler::handler(std::shared_ptr Queue, detail::queue_impl *SecondaryQueue, bool CallerNeedsEvent) - : impl(std::make_shared(SecondaryQueue, + : impl(std::make_shared(*Queue, SecondaryQueue, CallerNeedsEvent)), MQueue(std::move(Queue)) {} handler::handler( std::shared_ptr Graph) - : impl(std::make_shared(Graph)) {} + : impl(std::make_shared(*Graph)) {} #endif @@ -380,11 +383,11 @@ bool handler::isStateExplicitKernelBundle() const { std::shared_ptr handler::getOrInsertHandlerKernelBundle(bool Insert) const { if (!impl->MKernelBundle && Insert) { - auto Ctx = - impl->MGraph ? impl->MGraph->getContext() : MQueue->get_context(); - auto Dev = impl->MGraph ? impl->MGraph->getDevice() : MQueue->get_device(); - impl->MKernelBundle = detail::getSyclObjImpl( - get_kernel_bundle(Ctx, {Dev}, {})); + context Ctx = detail::createSyclObjFromImpl(impl->get_context()); + impl->MKernelBundle = + detail::getSyclObjImpl(get_kernel_bundle( + Ctx, {detail::createSyclObjFromImpl(impl->get_device())}, + {})); } return impl->MKernelBundle; } @@ -416,12 +419,14 @@ event handler::finalize() { MIsFinalized = true; const auto &type = getType(); + detail::queue_impl *Queue = impl->get_queue_or_null(); + ext::oneapi::experimental::detail::graph_impl *Graph = + impl->get_graph_or_null(); const bool KernelFastPath = - (MQueue && !impl->MGraph && !impl->MSubgraphNode && - !MQueue->hasCommandGraph() && !impl->CGData.MRequirements.size() && - !MStreamStorage.size() && + (Queue && !Graph && !impl->MSubgraphNode && !Queue->hasCommandGraph() && + !impl->CGData.MRequirements.size() && !MStreamStorage.size() && detail::Scheduler::areEventsSafeForSchedulerBypass( - impl->CGData.MEvents, MQueue->getContextImplPtr())); + impl->CGData.MEvents, Queue->getContextImplPtr())); // Extract arguments from the kernel lambda, if required. // Skipping this is currently limited to simple kernels on the fast path. @@ -482,12 +487,12 @@ event handler::finalize() { KernelBundleImpPtr->hasSYCLOfflineImages()) && !KernelBundleImpPtr->tryGetKernel(toKernelNameStrT(MKernelName), KernelBundleImpPtr)) { - auto Dev = - impl->MGraph ? impl->MGraph->getDevice() : MQueue->get_device(); + detail::device_impl &Dev = impl->get_device(); kernel_id KernelID = detail::ProgramManager::getInstance().getSYCLKernelID( toKernelNameStrT(MKernelName)); - bool KernelInserted = KernelBundleImpPtr->add_kernel(KernelID, Dev); + bool KernelInserted = KernelBundleImpPtr->add_kernel( + KernelID, detail::createSyclObjFromImpl(Dev)); // If kernel was not inserted and the bundle is in input mode we try // building it and trying to find the kernel in executable mode if (!KernelInserted && @@ -499,7 +504,8 @@ event handler::finalize() { build(KernelBundle); KernelBundleImpPtr = detail::getSyclObjImpl(ExecKernelBundle); setHandlerKernelBundle(KernelBundleImpPtr); - KernelInserted = KernelBundleImpPtr->add_kernel(KernelID, Dev); + KernelInserted = KernelBundleImpPtr->add_kernel( + KernelID, detail::createSyclObjFromImpl(Dev)); } // If the kernel was not found in executable mode we throw an exception if (!KernelInserted) @@ -544,7 +550,7 @@ event handler::finalize() { #endif bool DiscardEvent = - !impl->MEventNeeded && MQueue->supportsDiscardingPiEvents(); + !impl->MEventNeeded && impl->get_queue().supportsDiscardingPiEvents(); if (DiscardEvent) { // Kernel only uses assert if it's non interop one bool KernelUsesAssert = @@ -705,11 +711,11 @@ event handler::finalize() { break; case detail::CGType::EnqueueNativeCommand: case detail::CGType::CodeplayHostTask: { - auto context = impl->MGraph - ? detail::getSyclObjImpl(impl->MGraph->getContext()) - : MQueue->getContextImplPtr(); + detail::context_impl &Context = impl->get_context(); + detail::queue_impl *Queue = impl->get_queue_or_null(); CommandGroup.reset(new detail::CGHostTask( - std::move(impl->MHostTask), MQueue, context, std::move(impl->MArgs), + std::move(impl->MHostTask), Queue ? Queue->shared_from_this() : nullptr, + Context.shared_from_this(), std::move(impl->MArgs), std::move(impl->CGData), getType(), MCodeLoc)); break; } @@ -756,14 +762,15 @@ event handler::finalize() { break; } case detail::CGType::ExecCommandBuffer: { + detail::queue_impl *Queue = impl->get_queue_or_null(); std::shared_ptr ParentGraph = - MQueue ? MQueue->getCommandGraph() : impl->MGraph; + Queue ? Queue->getCommandGraph() : impl->get_graph().shared_from_this(); // If a parent graph is set that means we are adding or recording a subgraph // and we don't want to actually execute this command graph submission. if (ParentGraph) { ext::oneapi::experimental::detail::graph_impl::WriteLock ParentLock; - if (MQueue) { + if (Queue) { ParentLock = ext::oneapi::experimental::detail::graph_impl::WriteLock( ParentGraph->MMutex); } @@ -828,7 +835,7 @@ event handler::finalize() { // If there is a graph associated with the handler we are in the explicit // graph mode, so we store the CG instead of submitting it to the scheduler, // so it can be retrieved by the graph later. - if (impl->MGraph) { + if (impl->get_graph_or_null()) { impl->MGraphNodeCG = std::move(CommandGroup); auto EventImpl = std::make_shared(); #ifdef __INTEL_PREVIEW_BREAKING_CHANGES @@ -1356,9 +1363,8 @@ void handler::verifyUsedKernelBundleInternal(detail::string_view KernelName) { return; kernel_id KernelID = detail::get_kernel_id_impl(KernelName); - device Dev = impl->MGraph ? impl->MGraph->getDevice() - : detail::getDeviceFromHandler(*this); - if (!UsedKernelBundleImplPtr->has_kernel(KernelID, Dev)) + if (!UsedKernelBundleImplPtr->has_kernel( + KernelID, detail::createSyclObjFromImpl(impl->get_device()))) throw sycl::exception( make_error_code(errc::kernel_not_supported), "The kernel bundle in use does not contain the kernel"); @@ -1555,8 +1561,10 @@ void handler::ext_oneapi_copy( MDstPtr = Dest; ur_exp_image_copy_flags_t ImageCopyFlags = detail::getUrImageCopyFlags( - get_pointer_type(Src, MQueue->get_context()), - get_pointer_type(Dest, MQueue->get_context())); + get_pointer_type(Src, + createSyclObjFromImpl(impl->get_context())), + get_pointer_type(Dest, + createSyclObjFromImpl(impl->get_context()))); if (ImageCopyFlags == UR_EXP_IMAGE_COPY_FLAG_HOST_TO_DEVICE || ImageCopyFlags == UR_EXP_IMAGE_COPY_FLAG_DEVICE_TO_HOST) { @@ -1585,8 +1593,10 @@ void handler::ext_oneapi_copy( MDstPtr = Dest; ur_exp_image_copy_flags_t ImageCopyFlags = detail::getUrImageCopyFlags( - get_pointer_type(Src, MQueue->get_context()), - get_pointer_type(Dest, MQueue->get_context())); + get_pointer_type(Src, + createSyclObjFromImpl(impl->get_context())), + get_pointer_type(Dest, + createSyclObjFromImpl(impl->get_context()))); // Fill the host extent based on the type of copy. if (ImageCopyFlags == UR_EXP_IMAGE_COPY_FLAG_HOST_TO_DEVICE) { @@ -1741,8 +1751,10 @@ void handler::ext_oneapi_copy( MDstPtr = Dest; ur_exp_image_copy_flags_t ImageCopyFlags = detail::getUrImageCopyFlags( - get_pointer_type(Src, MQueue->get_context()), - get_pointer_type(Dest, MQueue->get_context())); + get_pointer_type(Src, + createSyclObjFromImpl(impl->get_context())), + get_pointer_type(Dest, + createSyclObjFromImpl(impl->get_context()))); if (ImageCopyFlags == UR_EXP_IMAGE_COPY_FLAG_DEVICE_TO_DEVICE || ImageCopyFlags == UR_EXP_IMAGE_COPY_FLAG_HOST_TO_HOST) { @@ -1768,8 +1780,10 @@ void handler::ext_oneapi_copy( MDstPtr = Dest; ur_exp_image_copy_flags_t ImageCopyFlags = detail::getUrImageCopyFlags( - get_pointer_type(Src, MQueue->get_context()), - get_pointer_type(Dest, MQueue->get_context())); + get_pointer_type(Src, + createSyclObjFromImpl(impl->get_context())), + get_pointer_type(Dest, + createSyclObjFromImpl(impl->get_context()))); if (ImageCopyFlags == UR_EXP_IMAGE_COPY_FLAG_DEVICE_TO_DEVICE || ImageCopyFlags == UR_EXP_IMAGE_COPY_FLAG_HOST_TO_HOST) { @@ -1897,9 +1911,9 @@ void handler::ext_oneapi_signal_external_semaphore( void handler::use_kernel_bundle( const kernel_bundle &ExecBundle) { - if ((!impl->MGraph && (MQueue->get_context() != ExecBundle.get_context())) || - (impl->MGraph && - (impl->MGraph->getContext() != ExecBundle.get_context()))) + + if (&impl->get_context() != + detail::getSyclObjImpl(ExecBundle.get_context()).get()) throw sycl::exception( make_error_code(errc::invalid), "Context associated with the primary queue is different from the " @@ -1949,14 +1963,14 @@ void handler::depends_on(const detail::EventImplPtr &EventImpl) { if (MQueue && EventGraph) { auto QueueGraph = MQueue->getCommandGraph(); - if (EventGraph->getContext() != MQueue->get_context()) { + if (EventGraph->getContextImplPtr().get() != &impl->get_context()) { throw sycl::exception( make_error_code(errc::invalid), "Cannot submit to a queue with a dependency from a graph that is " "associated with a different context."); } - if (EventGraph->getDevice() != MQueue->get_device()) { + if (&EventGraph->getDeviceImpl() != &impl->get_device()) { throw sycl::exception( make_error_code(errc::invalid), "Cannot submit to a queue with a dependency from a graph that is " @@ -2081,17 +2095,15 @@ bool handler::supportsUSMMemset2D() { } id<2> handler::computeFallbackKernelBounds(size_t Width, size_t Height) { - device Dev = MQueue->get_device(); + device_impl &Dev = impl->get_device(); range<2> ItemLimit = Dev.get_info>() * Dev.get_info(); return id<2>{std::min(ItemLimit[0], Height), std::min(ItemLimit[1], Width)}; } +// TODO: do we need this still? backend handler::getDeviceBackend() const { - if (impl->MGraph) - return impl->MGraph->getDevice().get_backend(); - else - return MQueue->getDeviceImpl().getBackend(); + return impl->get_device().getBackend(); } void handler::ext_intel_read_host_pipe(detail::string_view Name, void *Ptr, @@ -2180,10 +2192,10 @@ void handler::memcpyFromHostOnlyDeviceGlobal(void *Dest, const std::shared_ptr & handler::getContextImplPtr() const { - if (impl->MGraph) { - return impl->MGraph->getContextImplPtr(); + if (auto *Graph = impl->get_graph_or_null()) { + return Graph->getContextImplPtr(); } - return MQueue->getContextImplPtr(); + return impl->get_queue().getContextImplPtr(); } void handler::setKernelCacheConfig(handler::StableKernelCacheConfig Config) { @@ -2228,15 +2240,11 @@ void handler::ext_oneapi_graph( std::shared_ptr handler::getCommandGraph() const { - if (impl->MGraph) { - return impl->MGraph; + if (auto *Graph = impl->get_graph_or_null()) { + return Graph->shared_from_this(); } - if (this->MQueue) - return MQueue->getCommandGraph(); - // We should never reach here. MGraph and MQueue can not be null - // simultaneously. - return nullptr; + return impl->get_queue().getCommandGraph(); } void handler::setUserFacingNodeType(ext::oneapi::experimental::node_type Type) { @@ -2244,7 +2252,7 @@ void handler::setUserFacingNodeType(ext::oneapi::experimental::node_type Type) { } std::optional> handler::getMaxWorkGroups() { - device_impl &DeviceImpl = detail::getDeviceImplFromHandler(*this); + device_impl &DeviceImpl = impl->get_device(); std::array UrResult = {}; auto Ret = DeviceImpl.getAdapter()->call_nocheck( DeviceImpl.getHandleRef(), @@ -2272,12 +2280,13 @@ void handler::registerDynamicParameter( ext::oneapi::experimental::detail::dynamic_parameter_impl *DynamicParamImpl, int ArgIndex) { - if (MQueue && MQueue->hasCommandGraph()) { + if (queue_impl *Queue = impl->get_queue_or_null(); + Queue && Queue->hasCommandGraph()) { throw sycl::exception( make_error_code(errc::invalid), "Dynamic Parameters cannot be used with Graph Queue recording."); } - if (!impl->MGraph) { + if (!impl->get_graph_or_null()) { throw sycl::exception( make_error_code(errc::invalid), "Dynamic Parameters cannot be used with normal SYCL submissions");