diff --git a/cuda_core/cuda/core/_cpp/resource_handles.cpp b/cuda_core/cuda/core/_cpp/resource_handles.cpp index 0e3d2d7857..2355d64717 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.cpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.cpp @@ -957,7 +957,7 @@ GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent) namespace { struct GraphNodeBox { - CUgraphNode resource; + mutable CUgraphNode resource; GraphHandle h_graph; }; } // namespace @@ -978,6 +978,12 @@ GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept { return h ? get_box(h)->h_graph : GraphHandle{}; } +void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept { + if (h) { + get_box(h)->resource = nullptr; + } +} + // ============================================================================ // Graphics Resource Handles // ============================================================================ diff --git a/cuda_core/cuda/core/_cpp/resource_handles.hpp b/cuda_core/cuda/core/_cpp/resource_handles.hpp index 92d3cd4669..064f1406f6 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.hpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.hpp @@ -415,6 +415,9 @@ GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_ // Extract the owning graph handle from a node handle. GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept; +// Zero the CUgraphNode resource inside the handle, marking it invalid. 
+void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept; + // ============================================================================ // Graphics resource handle functions // ============================================================================ diff --git a/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx b/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx new file mode 100644 index 0000000000..5c5dae1ddd --- /dev/null +++ b/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx @@ -0,0 +1,246 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Mutable-set proxy for graph node predecessors and successors.""" + +from libc.stddef cimport size_t +from libcpp.vector cimport vector +from cuda.bindings cimport cydriver +from cuda.core._graph._graph_def._graph_node cimport GraphNode +from cuda.core._resource_handles cimport ( + GraphHandle, + GraphNodeHandle, + as_cu, + graph_node_get_graph, +) +from cuda.core._utils.cuda_utils cimport HANDLE_RETURN +from collections.abc import MutableSet + + +# ---- Python MutableSet wrapper ---------------------------------------------- + +class AdjacencySetProxy(MutableSet): + """Mutable set proxy for a node's predecessors or successors. Mutations + write through to the underlying CUDA graph.""" + + __slots__ = ("_core",) + + def __init__(self, node, bint is_fwd): + self._core = _AdjacencySetCore(node, is_fwd) + + # Used by operators such as &|^ to create non-proxy views when needed. 
+ @classmethod + def _from_iterable(cls, it): + return set(it) + + # --- abstract methods required by MutableSet --- + + def __contains__(self, x): + if not isinstance(x, GraphNode): + return False + return x in (<_AdjacencySetCore>self._core).query() + + def __iter__(self): + return iter((<_AdjacencySetCore>self._core).query()) + + def __len__(self): + return (<_AdjacencySetCore>self._core).count() + + def add(self, value): + if not isinstance(value, GraphNode): + raise TypeError( + f"expected GraphNode, got {type(value).__name__}") + if value in self: + return + (<_AdjacencySetCore>self._core).add_edge(value) + + def discard(self, value): + if not isinstance(value, GraphNode): + return + if value not in self: + return + (<_AdjacencySetCore>self._core).remove_edge(value) + + # --- override for bulk efficiency --- + + def clear(self): + """Remove all edges in a single driver call.""" + members = (<_AdjacencySetCore>self._core).query() + if members: + (<_AdjacencySetCore>self._core).remove_edges(members) + + def __isub__(self, it): + """Remove edges to all nodes in *it* in a single driver call.""" + if it is self: + self.clear() + else: + to_remove = [v for v in it if isinstance(v, GraphNode) and v in self] + if to_remove: + (<_AdjacencySetCore>self._core).remove_edges(to_remove) + return self + + def update(self, *others): + """Add edges to multiple nodes at once.""" + nodes = [] + for other in others: + if isinstance(other, GraphNode): + nodes.append(other) + else: + nodes.extend(other) + if not nodes: + return + for n in nodes: + if not isinstance(n, GraphNode): + raise TypeError( + f"expected GraphNode, got {type(n).__name__}") + new = [n for n in nodes if n not in self] + if new: + (<_AdjacencySetCore>self._core).add_edges(new) + + def __ior__(self, it): + """Add edges to all nodes in *it* in a single driver call.""" + self.update(it) + return self + + def __repr__(self): + return "{" + ", ".join(repr(n) for n in self) + "}" + + +# ---- cdef core holding a 
function pointer ------------------------------------ + +# Signature shared by driver_get_preds and driver_get_succs. +ctypedef cydriver.CUresult (*_adj_fn_t)( + cydriver.CUgraphNode, cydriver.CUgraphNode*, size_t*) noexcept nogil + + +cdef class _AdjacencySetCore: + """Cythonized core implementing AdjacencySetProxy""" + cdef: + GraphNodeHandle _h_node + GraphHandle _h_graph + _adj_fn_t _query_fn + bint _is_fwd + + def __init__(self, GraphNode node, bint is_fwd): + self._h_node = node._h_node + self._h_graph = graph_node_get_graph(node._h_node) + self._is_fwd = is_fwd + self._query_fn = driver_get_succs if is_fwd else driver_get_preds + + cdef inline void _resolve_edge( + self, GraphNode other, + cydriver.CUgraphNode* c_from, + cydriver.CUgraphNode* c_to) noexcept: + if self._is_fwd: + c_from[0] = as_cu(self._h_node) + c_to[0] = as_cu(other._h_node) + else: + c_from[0] = as_cu(other._h_node) + c_to[0] = as_cu(self._h_node) + + cdef list query(self): + cdef cydriver.CUgraphNode c_node = as_cu(self._h_node) + if c_node == NULL: + return [] + cdef size_t count = 0 + with nogil: + HANDLE_RETURN(self._query_fn(c_node, NULL, &count)) + if count == 0: + return [] + cdef vector[cydriver.CUgraphNode] nodes_vec + nodes_vec.resize(count) + with nogil: + HANDLE_RETURN(self._query_fn( + c_node, nodes_vec.data(), &count)) + return [GraphNode._create(self._h_graph, nodes_vec[i]) + for i in range(count)] + + cdef Py_ssize_t count(self): + cdef cydriver.CUgraphNode c_node = as_cu(self._h_node) + if c_node == NULL: + return 0 + cdef size_t n = 0 + with nogil: + HANDLE_RETURN(self._query_fn(c_node, NULL, &n)) + return n + + cdef void add_edge(self, GraphNode other): + cdef cydriver.CUgraphNode c_from, c_to + self._resolve_edge(other, &c_from, &c_to) + with nogil: + HANDLE_RETURN(driver_add_edges(as_cu(self._h_graph), &c_from, &c_to, 1)) + + cdef void add_edges(self, list nodes): + cdef size_t n = len(nodes) + cdef vector[cydriver.CUgraphNode] from_vec + cdef 
vector[cydriver.CUgraphNode] to_vec + from_vec.resize(n) + to_vec.resize(n) + cdef size_t i + for i in range(n): + self._resolve_edge(nodes[i], &from_vec[i], &to_vec[i]) + with nogil: + HANDLE_RETURN(driver_add_edges( + as_cu(self._h_graph), from_vec.data(), to_vec.data(), n)) + + cdef void remove_edge(self, GraphNode other): + cdef cydriver.CUgraphNode c_from, c_to + self._resolve_edge(other, &c_from, &c_to) + with nogil: + HANDLE_RETURN(driver_remove_edges(as_cu(self._h_graph), &c_from, &c_to, 1)) + + cdef void remove_edges(self, list nodes): + cdef size_t n = len(nodes) + cdef vector[cydriver.CUgraphNode] from_vec + cdef vector[cydriver.CUgraphNode] to_vec + from_vec.resize(n) + to_vec.resize(n) + cdef size_t i + for i in range(n): + self._resolve_edge(nodes[i], &from_vec[i], &to_vec[i]) + with nogil: + HANDLE_RETURN(driver_remove_edges( + as_cu(self._h_graph), from_vec.data(), to_vec.data(), n)) + + +# ---- driver wrappers: absorb CUDA version differences ---- + +cdef inline cydriver.CUresult driver_get_preds( + cydriver.CUgraphNode node, cydriver.CUgraphNode* out, + size_t* count) noexcept nogil: + IF CUDA_CORE_BUILD_MAJOR >= 13: + return cydriver.cuGraphNodeGetDependencies(node, out, NULL, count) + ELSE: + return cydriver.cuGraphNodeGetDependencies(node, out, count) + + +cdef inline cydriver.CUresult driver_get_succs( + cydriver.CUgraphNode node, cydriver.CUgraphNode* out, + size_t* count) noexcept nogil: + IF CUDA_CORE_BUILD_MAJOR >= 13: + return cydriver.cuGraphNodeGetDependentNodes(node, out, NULL, count) + ELSE: + return cydriver.cuGraphNodeGetDependentNodes(node, out, count) + + +cdef inline cydriver.CUresult driver_add_edges( + cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr, + cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil: + IF CUDA_CORE_BUILD_MAJOR >= 13: + return cydriver.cuGraphAddDependencies( + graph, from_arr, to_arr, NULL, count) + ELSE: + return cydriver.cuGraphAddDependencies( + graph, from_arr, to_arr, count) + + +cdef 
inline cydriver.CUresult driver_remove_edges( + cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr, + cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil: + IF CUDA_CORE_BUILD_MAJOR >= 13: + return cydriver.cuGraphRemoveDependencies( + graph, from_arr, to_arr, NULL, count) + ELSE: + return cydriver.cuGraphRemoveDependencies( + graph, from_arr, to_arr, count) diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx b/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx index d45c72ba2a..03673844d5 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx @@ -314,12 +314,12 @@ cdef class GraphDef: with nogil: HANDLE_RETURN(cydriver.cuGraphDebugDotPrint(as_cu(self._h_graph), c_path, flags)) - def nodes(self) -> tuple: + def nodes(self) -> set: """Return all nodes in the graph. Returns ------- - tuple of GraphNode + set of GraphNode All nodes in the graph. """ cdef size_t num_nodes = 0 @@ -328,21 +328,21 @@ cdef class GraphDef: HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), NULL, &num_nodes)) if num_nodes == 0: - return () + return set() cdef vector[cydriver.CUgraphNode] nodes_vec nodes_vec.resize(num_nodes) with nogil: HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes)) - return tuple(GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(num_nodes)) + return set(GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(num_nodes)) - def edges(self) -> tuple: + def edges(self) -> set: """Return all edges in the graph as (from_node, to_node) pairs. Returns ------- - tuple of tuple + set of tuple Each element is a (from_node, to_node) pair representing a dependency edge in the graph. 
""" @@ -355,7 +355,7 @@ cdef class GraphDef: HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, &num_edges)) if num_edges == 0: - return () + return set() cdef vector[cydriver.CUgraphNode] from_nodes cdef vector[cydriver.CUgraphNode] to_nodes @@ -369,7 +369,7 @@ cdef class GraphDef: HANDLE_RETURN(cydriver.cuGraphGetEdges( as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), &num_edges)) - return tuple( + return set( (GraphNode._create(self._h_graph, from_nodes[i]), GraphNode._create(self._h_graph, to_nodes[i])) for i in range(num_edges) diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pxd b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pxd index 7a9f82f33f..0a87b70ad6 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pxd +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pxd @@ -9,8 +9,6 @@ from cuda.core._resource_handles cimport GraphHandle, GraphNodeHandle cdef class GraphNode: cdef: GraphNodeHandle _h_node - tuple _pred_cache - tuple _succ_cache object __weakref__ @staticmethod diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx index 17c2c072f7..4048c9ee06 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx @@ -48,6 +48,7 @@ from cuda.core._resource_handles cimport ( create_graph_handle_ref, create_graph_node_handle, graph_node_get_graph, + invalidate_graph_node_handle, ) from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, _parse_fill_value @@ -57,6 +58,7 @@ from cuda.core._graph._utils cimport ( ) from cuda.core import Device +from cuda.core._graph._graph_def._adjacency_set_proxy import AdjacencySetProxy from cuda.core._utils.cuda_utils import driver, handle_return @@ -123,32 +125,48 @@ cdef class GraphNode: return as_py(self._h_node) @property - def pred(self) -> tuple: - """Return the predecessor nodes (dependencies) of this node. 
+ def is_valid(self): + """Whether this node is valid (not destroyed). - Results are cached since a node's dependencies are immutable - once created. + Returns ``False`` after :meth:`destroy` has been called. + """ + return as_intptr(self._h_node) != 0 - Returns - ------- - tuple of GraphNode - The nodes that this node depends on. + def destroy(self): + """Destroy this node and remove all its edges from the parent graph. + + After this call, :attr:`is_valid` returns ``False`` and the node + cannot be re-added to any graph. Safe to call on an + already-destroyed node (no-op). """ - return GN_pred(self) + cdef cydriver.CUgraphNode node = as_cu(self._h_node) + if node == NULL: + return + with nogil: + HANDLE_RETURN(cydriver.cuGraphDestroyNode(node)) + invalidate_graph_node_handle(self._h_node) @property - def succ(self) -> tuple: - """Return the successor nodes (dependents) of this node. + def pred(self): + """A mutable set-like view of this node's predecessors.""" + return AdjacencySetProxy(self, False) - Results are cached and automatically invalidated when new - dependent nodes are added via builder methods. + @pred.setter + def pred(self, value): + p = AdjacencySetProxy(self, False) + p.clear() + p.update(value) - Returns - ------- - tuple of GraphNode - The nodes that depend on this node. - """ - return GN_succ(self) + @property + def succ(self): + """A mutable set-like view of this node's successors.""" + return AdjacencySetProxy(self, True) + + @succ.setter + def succ(self, value): + s = AdjacencySetProxy(self, True) + s.clear() + s.update(value) def launch(self, config: LaunchConfig, kernel: Kernel, *args) -> KernelNode: """Add a kernel launch node depending on this node. 
@@ -504,7 +522,6 @@ cdef inline ConditionalNode _make_conditional_node( n._cond_type = cond_type n._branches = branches - pred._succ_cache = None return n cdef inline GraphNode GN_create(GraphHandle h_graph, cydriver.CUgraphNode node): @@ -546,72 +563,6 @@ cdef inline GraphNode GN_create(GraphHandle h_graph, cydriver.CUgraphNode node): return n -cdef inline tuple GN_pred(GraphNode self): - if self._pred_cache is not None: - return self._pred_cache - - cdef cydriver.CUgraphNode node = as_cu(self._h_node) - if node == NULL: - self._pred_cache = () - return self._pred_cache - - cdef size_t num_deps = 0 - with nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependencies(node, NULL, NULL, &num_deps)) - ELSE: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependencies(node, NULL, &num_deps)) - - if num_deps == 0: - self._pred_cache = () - return self._pred_cache - - cdef vector[cydriver.CUgraphNode] deps - deps.resize(num_deps) - with nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependencies(node, deps.data(), NULL, &num_deps)) - ELSE: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependencies(node, deps.data(), &num_deps)) - - cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) - self._pred_cache = tuple(GraphNode._create(h_graph, deps[i]) for i in range(num_deps)) - return self._pred_cache - - -cdef inline tuple GN_succ(GraphNode self): - if self._succ_cache is not None: - return self._succ_cache - - cdef cydriver.CUgraphNode node = as_cu(self._h_node) - if node == NULL: - self._succ_cache = () - return self._succ_cache - - cdef size_t num_deps = 0 - with nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependentNodes(node, NULL, NULL, &num_deps)) - ELSE: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependentNodes(node, NULL, &num_deps)) - - if num_deps == 0: - self._succ_cache = () - return self._succ_cache - - cdef vector[cydriver.CUgraphNode] deps - deps.resize(num_deps) - with 
nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependentNodes(node, deps.data(), NULL, &num_deps)) - ELSE: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependentNodes(node, deps.data(), &num_deps)) - - cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) - self._succ_cache = tuple(GraphNode._create(h_graph, deps[i]) for i in range(num_deps)) - return self._succ_cache - - cdef inline KernelNode GN_launch(GraphNode self, LaunchConfig conf, Kernel ker, ParamHolder ker_args): cdef cydriver.CUDA_KERNEL_NODE_PARAMS node_params cdef cydriver.CUgraphNode new_node = NULL @@ -644,7 +595,6 @@ cdef inline KernelNode GN_launch(GraphNode self, LaunchConfig conf, Kernel ker, _attach_user_object(as_cu(h_graph), new KernelHandle(ker._h_kernel), _destroy_kernel_handle_copy) - self._succ_cache = None return KernelNode._create_with_params( create_graph_node_handle(new_node, h_graph), conf.grid, conf.block, conf.shmem_size, @@ -674,9 +624,6 @@ cdef inline EmptyNode GN_join(GraphNode self, tuple nodes): HANDLE_RETURN(cydriver.cuGraphAddEmptyNode( &new_node, as_cu(h_graph), deps_ptr, num_deps)) - self._succ_cache = None - for other in nodes: - (other)._succ_cache = None return EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph)) @@ -753,7 +700,6 @@ cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options): HANDLE_RETURN(cydriver.cuGraphAddMemAllocNode( &new_node, as_cu(h_graph), deps, num_deps, &alloc_params)) - self._succ_cache = None return AllocNode._create_with_params( create_graph_node_handle(new_node, h_graph), alloc_params.dptr, size, device_id, memory_type, tuple(peer_ids)) @@ -774,7 +720,6 @@ cdef inline FreeNode GN_free(GraphNode self, cydriver.CUdeviceptr c_dptr): HANDLE_RETURN(cydriver.cuGraphAddMemFreeNode( &new_node, as_cu(h_graph), deps, num_deps, c_dptr)) - self._succ_cache = None return FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr) @@ -810,7 +755,6 @@ cdef inline 
MemsetNode GN_memset( &new_node, as_cu(h_graph), deps, num_deps, &memset_params, ctx)) - self._succ_cache = None return MemsetNode._create_with_params( create_graph_node_handle(new_node, h_graph), c_dst, val, elem_size, width, height, pitch) @@ -872,7 +816,6 @@ cdef inline MemcpyNode GN_memcpy( HANDLE_RETURN(cydriver.cuGraphAddMemcpyNode( &new_node, as_cu(h_graph), deps, num_deps, ¶ms, ctx)) - self._succ_cache = None return MemcpyNode._create_with_params( create_graph_node_handle(new_node, h_graph), c_dst, c_src, size, c_dst_type, c_src_type) @@ -900,7 +843,6 @@ cdef inline ChildGraphNode GN_embed(GraphNode self, GraphDef child_def): cdef GraphHandle h_embedded = create_graph_handle_ref(embedded_graph, h_graph) - self._succ_cache = None return ChildGraphNode._create_with_params( create_graph_node_handle(new_node, h_graph), h_embedded) @@ -923,7 +865,6 @@ cdef inline EventRecordNode GN_record_event(GraphNode self, Event ev): _attach_user_object(as_cu(h_graph), new EventHandle(ev._h_event), _destroy_event_handle_copy) - self._succ_cache = None return EventRecordNode._create_with_params( create_graph_node_handle(new_node, h_graph), ev._h_event) @@ -946,7 +887,6 @@ cdef inline EventWaitNode GN_wait_event(GraphNode self, Event ev): _attach_user_object(as_cu(h_graph), new EventHandle(ev._h_event), _destroy_event_handle_copy) - self._succ_cache = None return EventWaitNode._create_with_params( create_graph_node_handle(new_node, h_graph), ev._h_event) @@ -974,7 +914,6 @@ cdef inline HostCallbackNode GN_callback(GraphNode self, object fn, object user_ &new_node, as_cu(h_graph), deps, num_deps, &node_params)) cdef object callable_obj = fn if not isinstance(fn, ct._CFuncPtr) else None - self._succ_cache = None return HostCallbackNode._create_with_params( create_graph_node_handle(new_node, h_graph), callable_obj, node_params.fn, node_params.userData) diff --git a/cuda_core/cuda/core/_resource_handles.pxd b/cuda_core/cuda/core/_resource_handles.pxd index 419106f04a..f847e60223 
100644 --- a/cuda_core/cuda/core/_resource_handles.pxd +++ b/cuda_core/cuda/core/_resource_handles.pxd @@ -186,6 +186,7 @@ cdef GraphHandle create_graph_handle_ref(cydriver.CUgraph graph, const GraphHand # Graph node handles cdef GraphNodeHandle create_graph_node_handle(cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil cdef GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept nogil +cdef void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept nogil # Graphics resource handles cdef GraphicsResourceHandle create_graphics_resource_handle( diff --git a/cuda_core/cuda/core/_resource_handles.pyx b/cuda_core/cuda/core/_resource_handles.pyx index 39b425b9ed..001f9b4a0c 100644 --- a/cuda_core/cuda/core/_resource_handles.pyx +++ b/cuda_core/cuda/core/_resource_handles.pyx @@ -159,6 +159,8 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil GraphHandle graph_node_get_graph "cuda_core::graph_node_get_graph" ( const GraphNodeHandle& h) noexcept nogil + void invalidate_graph_node_handle "cuda_core::invalidate_graph_node_handle" ( + const GraphNodeHandle& h) noexcept nogil # Graphics resource handles GraphicsResourceHandle create_graphics_resource_handle "cuda_core::create_graphics_resource_handle" ( diff --git a/cuda_core/tests/graph/test_graphdef.py b/cuda_core/tests/graph/test_graphdef.py index 3412d71847..be6da9515a 100644 --- a/cuda_core/tests/graph/test_graphdef.py +++ b/cuda_core/tests/graph/test_graphdef.py @@ -712,8 +712,8 @@ def test_graphdef_entry_is_virtual(sample_graphdef): """Internal entry node is virtual (no pred/succ, type is None).""" entry = sample_graphdef._entry assert isinstance(entry, GraphNode) - assert entry.pred == () - assert entry.succ == () + assert entry.pred == set() + assert entry.succ == set() assert entry.type is None diff --git a/cuda_core/tests/graph/test_graphdef_lifetime.py 
b/cuda_core/tests/graph/test_graphdef_lifetime.py index 133f2c7ca1..3b254d423f 100644 --- a/cuda_core/tests/graph/test_graphdef_lifetime.py +++ b/cuda_core/tests/graph/test_graphdef_lifetime.py @@ -68,7 +68,7 @@ def test_branches_survive_parent_deletion(init_cuda, builder, expected_count): gc.collect() for branch in branches: - assert branch.nodes() == () + assert branch.nodes() == set() @pytest.mark.parametrize("builder, expected_count", _COND_BUILDERS) @@ -108,7 +108,7 @@ def test_reconstructed_body_survives_parent_deletion(init_cuda): del g, condition, all_nodes, cond_nodes, branches gc.collect() - assert body.nodes() == () + assert body.nodes() == set() # ============================================================================= @@ -477,7 +477,7 @@ def test_kernel_node_reconstruction_preserves_validity(init_cuda): # Reconstruct the kernel node through DAG traversal # successor.pred -> GraphNode._create -> KernelNode._create_from_driver # -> create_kernel_handle_ref -> handle recovery - reconstructed = successor.pred[0] + reconstructed = next(iter(successor.pred)) assert isinstance(reconstructed, KernelNode) assert reconstructed.kernel.attributes.max_threads_per_block() > 0 diff --git a/cuda_core/tests/graph/test_graphdef_mutation.py b/cuda_core/tests/graph/test_graphdef_mutation.py new file mode 100644 index 0000000000..dcfd4aab89 --- /dev/null +++ b/cuda_core/tests/graph/test_graphdef_mutation.py @@ -0,0 +1,394 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +"""Tests for mutating a graph definition (edge changes, node removal).""" + +import numpy as np +import pytest +from helpers.collection_interface_testers import assert_mutable_set_interface +from helpers.graph_kernels import compile_parallel_kernels +from helpers.marks import requires_module + +from cuda.core import Device, LaunchConfig, LegacyPinnedMemoryResource +from cuda.core._graph._graph_def import GraphDef, KernelNode, MemsetNode +from cuda.core._utils.cuda_utils import CUDAError + + +class YRig: + """Test rigging for graph mutation tests. Constructs a Y-shaped graph with + two parallel arms joined by a combine node. Modifying the sequence of + operations along either arm changes the output. + + Topology:: + + a0 -- a1 -- a2 + \ + j -- r + / + b0 -- b1 + + Each a/b node applies ``affine(ptr, m, b)`` to its arm's int accumulator. + Node r computes result ``combine(R, A, B) = (A << 16) | (B & 0xFFFF)``, + encoding both arms' results into a single int. j is a joining (empty) node + preceding r. + + The affine operation a * m + b is noncommutative, so we can be sure the + graph has exactly the topology we expect by checking the final value. 
+ """ + + def __init__(self): + self.A_OPS = [(2, 1), (3, 2), (5, 3)] + self.B_OPS = [(2, 7), (3, 1)] + + mod = compile_parallel_kernels() + self.affine = mod.get_kernel("affine") + self.combine = mod.get_kernel("combine") + self.config = LaunchConfig(grid=1, block=1) + + self._mr = LegacyPinnedMemoryResource() + self._buf = self._mr.allocate(3 * 4) + self._arr = np.from_dlpack(self._buf).view(np.int32) + + self.ptr_a = self._arr[0:].ctypes.data + self.ptr_b = self._arr[1:].ctypes.data + self.ptr_r = self._arr[2:].ctypes.data + + self.graph_def = GraphDef() + self.stream = None + + # Arm A + self.a = [] + prev = self.graph_def + for m, b in self.A_OPS: + prev = prev.launch(self.config, self.affine, self.ptr_a, m, b) + self.a.append(prev) + + # Arm B + self.b = [] + prev = self.graph_def + for m, b in self.B_OPS: + prev = prev.launch(self.config, self.affine, self.ptr_b, m, b) + self.b.append(prev) + + # Join and combine + self.j = self.graph_def.join(self.a[-1], self.b[-1]) + self.r = self.j.launch(self.config, self.combine, self.ptr_r, self.ptr_a, self.ptr_b) + + def run(self): + if self.stream is None: + self.stream = Device().create_stream() + graph = self.graph_def.instantiate() + self.reset() + graph.launch(self.stream) + self.stream.sync() + + def reset(self): + self._arr[:] = 0 + + @property + def A_out(self): + return int(self._arr[0]) + + @property + def B_out(self): + return int(self._arr[1]) + + @property + def R_out(self): + return int(self._arr[2]) + + @property + def output(self): + return self.A_out, self.B_out, self.R_out + + @property + def expected_output(self): + """Expected (A, B, R) after one run from zero.""" + + def apply_affine(val, ops): + for m, b in ops: + val = val * m + b + return val + + a = apply_affine(0, self.A_OPS) + b = apply_affine(0, self.B_OPS) + r = (a << 16) | (b & 0xFFFF) + return (a, b, r) + + @property + def edges(self): + return self.graph_def.edges() + + @property + def initial_edges(self): + return ( + set(zip(self.a, 
self.a[1:])) + | set(zip(self.b, self.b[1:])) + | {(self.a[-1], self.j), (self.b[-1], self.j), (self.j, self.r)} + ) + + @property + def nodes(self): + return self.graph_def.nodes() + + @property + def initial_nodes(self): + return set(self.a + self.b + [self.j, self.r]) + + def close(self): + self._buf.close() + + +@requires_module(np, "2.1") +class TestMutateYRig: + """Tests that mutate the Y-shaped graph built by YRig.""" + + def test_baseline(self, init_cuda): + """Unmodified graph produces the expected results.""" + rig = YRig() + rig.run() + assert rig.output == rig.expected_output + assert rig.edges == rig.initial_edges + assert rig.nodes == rig.initial_nodes + rig.close() + + def test_destroy_a1(self, init_cuda): + """Destroy a1 (creates a race on arm a). Arm b yields the expected + value, and the final step is correctly ordered after b completes.""" + rig = YRig() + rig.a[1].destroy() + rig.run() + _, b_exp, _ = rig.expected_output + assert rig.B_out == b_exp + assert (rig.R_out & 0xFFFF) == b_exp + a0, a1, a2 = rig.a + assert rig.edges == rig.initial_edges - {(a0, a1), (a1, a2)} + assert rig.nodes == rig.initial_nodes - {a1} + rig.close() + + def test_destroy_a2(self, init_cuda): + """Destroy a2, connect a1--r""" + rig = YRig() + rig.a[2].destroy() + rig.a[1].succ.add(rig.r) + rig.A_OPS.pop() + rig.run() + assert rig.output == rig.expected_output + a0, a1, a2, j, r = rig.a + [rig.j, rig.r] + assert rig.edges == (rig.initial_edges - {(a1, a2), (a2, j)}) | {(a1, r)} + assert rig.nodes == rig.initial_nodes - {a2} + rig.close() + + def test_destroy_joint(self, init_cuda): + """Remove the joining node and instead add edges directly to r.""" + rig = YRig() + _, _, a2, _, b1, j, r = rig.a + rig.b + [rig.j, rig.r] + j.destroy() + r.pred = {a2, b1} + rig.run() + assert rig.output == rig.expected_output + assert rig.edges == (rig.initial_edges - {(a2, j), (b1, j), (j, r)}) | {(a2, r), (b1, r)} + assert rig.nodes == rig.initial_nodes - {j} + rig.close() + + def 
def test_adjacency_set_interface(init_cuda):
    """Exercise every MutableSet method on AdjacencySetProxy.

    Delegates the protocol walk-through to the shared
    ``assert_mutable_set_interface`` helper, feeding it five distinct
    nodes from the same graph as insertable items.
    """
    g = GraphDef()
    hub = g.join()
    items = [g.join() for _ in range(5)]
    assert_mutable_set_interface(hub.succ, items)


def test_adjacency_set_pred_direction(init_cuda):
    """Verify that pred works symmetrically with succ."""
    g = GraphDef()
    target = g.join()
    x, y, z = (g.join() for _ in range(3))

    pred = target.pred
    assert pred == set()

    pred.add(x)
    pred.add(y)
    assert pred == {x, y}

    # Verify the edge is visible from the other direction
    assert target in x.succ
    assert target in y.succ
    assert target not in z.succ

    pred.discard(x)
    assert pred == {y}
    assert target not in x.succ


def test_adjacency_set_property_setter(init_cuda):
    """Verify that assigning to node.pred or node.succ replaces all edges."""
    g = GraphDef()
    hub = g.join()
    a, b, c = (g.join() for _ in range(3))

    hub.succ = {a, b}
    assert hub.succ == {a, b}

    # Reassignment replaces, rather than augments, the edge set.
    hub.succ = {c}
    assert hub.succ == {c}
    assert a not in hub.succ

    hub.succ = set()
    assert hub.succ == set()

    hub.pred = {a, b}
    assert hub.pred == {a, b}

    # Assigning an empty set clears all predecessor edges.
    # (A duplicated copy of this assignment/assert pair was removed.)
    hub.pred = set()
    assert hub.pred == set()


def test_destroyed_node(init_cuda):
    """Test that destroy() invalidates a node."""
    mr = LegacyPinnedMemoryResource()
    buf = mr.allocate(4)
    arr = np.from_dlpack(buf).view(np.int32)
    arr[:] = 0
    ptr = arr[0:].ctypes.data

    g = GraphDef()
    a = g.memset(ptr, 0, 4)
    b = a.memset(ptr, 42, 4)

    assert a.is_valid
    assert b.is_valid
    assert b in g.nodes()
    assert (a, b) in g.edges()

    b.destroy()

    # The node disappears from the graph's node and edge views.
    assert not b.is_valid
    assert b not in g.nodes()
    assert (a, b) not in g.edges()

    # Python object is invalid but using it does not crash.
    assert isinstance(b, MemsetNode)
    assert b.type is None
    assert b.pred == set()
    assert b.succ == set()
    assert b.handle is None
    assert b.dptr == ptr  # tolerable
    assert b.value == 42  # tolerable
    assert b.width == 4  # tolerable

    # Adding an edge to a destroyed node fails.
    with pytest.raises(CUDAError):
        a.succ.add(b)

    # Repeated destroy succeeds quietly.
    b.destroy()
    assert not b.is_valid


def test_add_wrong_type(init_cuda):
    """Adding a non-GraphNode raises TypeError."""
    g = GraphDef()
    node = g.join()
    with pytest.raises(TypeError, match="expected GraphNode"):
        node.succ.add("not a node")
    with pytest.raises(TypeError, match="expected GraphNode"):
        node.succ.add(42)


def test_cross_graph_edge(init_cuda):
    """Adding an edge to a node from a different graph raises CUDAError."""
    g1 = GraphDef()
    g2 = GraphDef()
    a = g1.join()
    b = g2.join()
    with pytest.raises(CUDAError):
        a.succ.add(b)


def test_self_edge(init_cuda):
    """Adding a self-edge raises CUDAError."""
    g = GraphDef()
    node = g.join()
    with pytest.raises(CUDAError):
        node.succ.add(node)


@requires_module(np, "2.1")
def test_convert_linear_to_fan_in(init_cuda):
    """Chain four computations sequentially, then rewire so all chains run
    in parallel feeding into a reduce node.

    Initial topology (sequential)::

        memset0 -- launch0 -- memset1 -- launch1 -- memset2 -- launch2 -- memset3 -- launch3

    After rewiring (parallel)::

        memset0 -- launch0 --\\
        memset1 -- launch1 ---+-- reduce
        memset2 -- launch2 --/
        memset3 -- launch3 -/
    """
    mod = compile_parallel_kernels()
    affine = mod.get_kernel("affine")
    reduce_kern = mod.get_kernel("reduce")
    config = LaunchConfig(grid=1, block=1)

    mr = LegacyPinnedMemoryResource()
    buf = mr.allocate(5 * 4)
    arr = np.from_dlpack(buf).view(np.int32)
    arr[:] = 0

    values = np.array([10, 20, 30, 40], dtype=np.int32)
    ptrs = [arr[i:].ctypes.data for i in range(5)]

    # Create the initial graph.
    g = GraphDef()
    prev = g
    for i, val in enumerate(values):
        prev = prev.memset(ptrs[i], val, 1)
        prev = prev.launch(config, affine, ptrs[i], 2, 1)
    reduce_node = g.launch(config, reduce_kern, ptrs[4], ptrs[0], 4)

    # Rewire:
    # - drop preds from memsets
    # - connect kernel launches to the reduction
    assert len(g.edges()) == 7

    for node in g.nodes():
        if isinstance(node, MemsetNode):
            node.pred.clear()
        elif isinstance(node, KernelNode) and node != reduce_node:
            node.succ.add(reduce_node)

    assert len(g.edges()) == 8

    stream = Device().create_stream()
    graph = g.instantiate()
    graph.launch(stream)
    stream.sync()
    # Each chain computes v*2 + 1; the reduce kernel sums the four results.
    assert arr[4] == sum(2 * values + 1)

    buf.close()
def assert_mutable_set_interface(subject, items):
    """Drive *subject* through the full ``MutableSet`` protocol.

    Every mutation performed on *subject* is mirrored on a plain ``set``
    and the two are compared after each step.

    Parameters
    ----------
    subject : MutableSet
        An **empty** mutable-set-like object to test.
    items : sequence
        At least five distinct, hashable objects valid for insertion into
        *subject*.
    """
    assert len(items) >= 5
    a, b, c, d, e = items[:5]
    mirror = set()

    # ABC registration.
    assert isinstance(subject, Set)
    assert isinstance(subject, MutableSet)

    # A fresh subject must report itself empty in every observable way.
    assert len(subject) == 0
    assert subject == mirror
    assert subject == set()
    assert list(subject) == []

    # add()
    subject.add(a)
    mirror.add(a)
    assert subject == mirror
    assert a in subject
    assert b not in subject
    assert len(subject) == 1

    subject.add(b)
    subject.add(c)
    mirror |= {b, c}
    assert subject == mirror
    assert len(subject) == 3

    subject.add(a)  # adding a duplicate is a no-op
    assert subject == mirror

    # discard()
    subject.discard(b)
    mirror.discard(b)
    assert subject == mirror

    subject.discard(d)  # discarding a non-member is a no-op
    assert subject == mirror

    # remove() — raises KeyError for missing members.
    subject.add(b)
    mirror.add(b)
    subject.remove(b)
    mirror.remove(b)
    assert subject == mirror

    with pytest.raises(KeyError):
        subject.remove(d)

    # Equality against plain builtin sets.
    assert subject == {a, c}
    assert subject != {a, b}

    # isdisjoint()
    assert subject.isdisjoint({d, e})
    assert not subject.isdisjoint({a, d})

    # Subset / superset orderings.
    assert subject <= {a, c}
    assert subject <= {a, b, c}
    assert not (subject <= {a})
    assert subject < {a, b, c}
    assert not (subject < {a, c})
    assert {a, c} >= subject
    assert {a, b, c} > subject

    # Binary set operators.
    assert subject & {a, d} == {a}
    assert subject | {d} == {a, c, d}
    assert subject - {c} == {a}
    assert subject ^ {c, d} == {a, d}

    # In-place union (|=).
    subject |= {d, e}
    mirror |= {d, e}
    assert subject == mirror

    # In-place intersection (&=).
    subject &= {a, d, e}
    mirror &= {a, d, e}
    assert subject == mirror

    # In-place difference (-=).
    subject -= {e}
    mirror -= {e}
    assert subject == mirror

    # In-place symmetric difference (^=).
    subject ^= {a, b}
    mirror ^= {a, b}
    assert subject == mirror

    # pop() — removes and returns an arbitrary member.
    popped = subject.pop()
    mirror.discard(popped)
    assert popped not in subject
    assert subject == mirror

    # clear(), and pop() on an empty set raises KeyError.
    subject.clear()
    mirror.clear()
    assert subject == mirror
    assert len(subject) == 0

    with pytest.raises(KeyError):
        subject.pop()

    # Bulk insert via |=.
    subject |= {a, b, c}
    mirror |= {a, b, c}
    assert subject == mirror

    # __iter__
    assert set(subject) == mirror

    # __repr__
    text = repr(subject)
    assert isinstance(text, str)
    assert len(text) > 0
+ """ + code = """ + __global__ void affine(int *a, int m, int b) { *a = *a * m + b; } + __global__ void combine(int *s, int *a, int *b) { *s = (*a << 16) | (*b & 0xFFFF); } + __global__ void reduce(int *out, int *in, size_t n) { for(size_t i=0; i