Skip to content
Merged
8 changes: 7 additions & 1 deletion cuda_core/cuda/core/_cpp/resource_handles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent)

namespace {
// Pairs a raw CUgraphNode with a handle to a graph (read back by
// graph_node_get_graph via `h_graph`).
// NOTE(review): presumably this is the object get_box() returns for a
// GraphNodeHandle — confirm against get_box's definition.
struct GraphNodeBox {
CUgraphNode resource;
// `mutable` lets invalidate_graph_node_handle() null this field even though
// it only receives a const GraphNodeHandle&.
mutable CUgraphNode resource;
GraphHandle h_graph;
};
} // namespace
Expand All @@ -978,6 +978,12 @@ GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept {
return h ? get_box(h)->h_graph : GraphHandle{};
}

// Null out the CUgraphNode stored behind `h`. An empty handle is left alone.
// The handle itself is never reset or rebound; only the value it refers to
// is cleared.
void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept {
    if (!h) {
        return;
    }
    get_box(h)->resource = nullptr;
}

// ============================================================================
// Graphics Resource Handles
// ============================================================================
Expand Down
3 changes: 3 additions & 0 deletions cuda_core/cuda/core/_cpp/resource_handles.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,9 @@ GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_
// Extract the owning graph handle from a node handle.
GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept;

// Zero the CUgraphNode resource inside the handle, marking it invalid.
// Passing an empty handle is a no-op. The handle is taken by const reference
// because it is never reset or rebound; only the node value it refers to is
// cleared, so every handle sharing that target observes the invalidation.
void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we make the invalidate_graph_node_handle take a mutable reference? Rather than internally mutating through a const& handle?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it will improve anything. GraphNodeHandle is a std::shared_ptr, and this "invalidate" function relies on T& std::shared_ptr::operator*() const noexcept to get a non-const reference to the target object (note the function is marked const but the return type is a non-const T&; doc). A const qualifier on GraphNodeHandle only refers to whether the pointer can be reset or rebound to a new target through something like assignment/move, not whether the target itself is mutable. The "invalidate" function marks the target as having been invalidated by a call to cuGraphNodeDestroy but does not change the target location, so a const handle is correct.

Removing const from this function would not have the effect you might expect. Consider these three handles to two targets:

GraphNodeHandle h1 = ...;      // a graph node handle
const GraphNodeHandle h2 = h1; // a second handle to the same graph node
GraphNodeHandle h3 = ...;      // a handle to an unrelated graph node

Marking h2 with const protects the handle, not the GraphNode it points to. In this case, h2 = h3 would be disallowed. If we changed the declaration as you suggest, invalidate_graph_node_handle(h2) would be disallowed; however, invalidate_graph_node_handle(h1) would still be allowed and would invalidate the shared graph node.

In general, handle targets are always marked const because resource handles by design only permit attach, detach, and dereference operations. The GraphNodeHandle specifically and oddly needs invalidation because of the strange (I'd say broken) semantics of cuGraphNodeDestroy.

One way to support invalidation would be removing const from this declaration:

using GraphNodeHandle = std::shared_ptr<const CUgraphNode>

Then one could write *h2 = NULL to invalidate. That's not a good solution because it allows mutation everywhere.

The better approach (what's in this PR) is to keep the current definition, consistent with other handle types, while marking CUgraphNode as mutable in GraphNodeBox and providing an "invalidate" function to invalidate the handle when needed. This keeps the GraphNodeHandle similar to all other handles, while supporting this odd destroy behavior that is needed because of cuGraphNodeDestroy.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function should probably be called invalidate_graph_node


// ============================================================================
// Graphics resource handle functions
// ============================================================================
Expand Down
246 changes: 246 additions & 0 deletions cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

"""Mutable-set proxy for graph node predecessors and successors."""

from libc.stddef cimport size_t
from libcpp.vector cimport vector
from cuda.bindings cimport cydriver
from cuda.core._graph._graph_def._graph_node cimport GraphNode
from cuda.core._resource_handles cimport (
GraphHandle,
GraphNodeHandle,
as_cu,
graph_node_get_graph,
)
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
from collections.abc import MutableSet


# ---- Python MutableSet wrapper ----------------------------------------------

class AdjacencySetProxy(MutableSet):
"""Mutable set proxy for a node's predecessors or successors. Mutations
write through to the underlying CUDA graph."""

__slots__ = ("_core",)

def __init__(self, node, bint is_fwd):
    """Bind the proxy to *node*.

    ``is_fwd`` picks the direction: true proxies the node's successors,
    false its predecessors. All driver access is delegated to the core.
    """
    core = _AdjacencySetCore(node, is_fwd)
    self._core = core

# Hook used by the binary set operators (&, |, ^) to build their result.
# A plain ``set`` is returned so the result is not a write-through proxy.
@classmethod
def _from_iterable(cls, it):
    plain = set(it)
    return plain

# --- abstract methods required by MutableSet ---

def __contains__(self, x):
    """Return True when *x* is a GraphNode currently adjacent to this node.

    Anything that is not a GraphNode is reported as absent without
    querying the graph.
    """
    if isinstance(x, GraphNode):
        return x in (<_AdjacencySetCore>self._core).query()
    return False
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may be premature optimization, but is it required to instantiate an entire list that is a copy of the elements on the C++ side just to test for in? We'd have to check that this works in Cython, but having query return an iterator comprehension rather than a list comprehension may be enough.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this?

--- a/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx
+++ b/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx
@@ -68,7 +68,7 @@ class AdjacencySetProxy(MutableSet):
         """Remove all edges in a single driver call."""
         members = (<_AdjacencySetCore>self._core).query()
         if members:
-            (<_AdjacencySetCore>self._core).remove_edges(members)
+            (<_AdjacencySetCore>self._core).remove_edges(list(members))

     def __isub__(self, it):
         """Remove edges to all nodes in *it* in a single driver call."""
@@ -139,7 +139,7 @@ cdef class _AdjacencySetCore:
             c_from[0] = as_cu(other._h_node)
             c_to[0] = as_cu(self._h_node)

-    cdef list query(self):
+    cdef query(self):
         cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
         if c_node == NULL:
             return []
@@ -153,8 +153,8 @@ cdef class _AdjacencySetCore:
         with nogil:
             HANDLE_RETURN(self._query_fn(
                 c_node, nodes_vec.data(), &count))
-        return [GraphNode._create(self._h_graph, nodes_vec[i])
-                for i in range(count)]
+        return (GraphNode._create(self._h_graph, nodes_vec[i])
+                for i in range(count))

     cdef Py_ssize_t count(self):
         cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about this, too. Definitely a good idea to see whether query can return a generator. I was thinking of seeing whether I can implement contains() directly on AdjacencySetCore. We have to query the driver in batch, but we don't need to reconstruct the Python objects at all just to check containment.


def __iter__(self):
    """Iterate over a snapshot of the adjacent nodes taken at call time."""
    snapshot = (<_AdjacencySetCore>self._core).query()
    return iter(snapshot)

def __len__(self):
    """Return the number of adjacent nodes as reported by the core."""
    size = (<_AdjacencySetCore>self._core).count()
    return size

def add(self, value):
    """Add an edge between this node and *value*.

    Raises TypeError when *value* is not a GraphNode. Adding a node that
    is already present is a no-op.
    """
    if not isinstance(value, GraphNode):
        raise TypeError(
            f"expected GraphNode, got {type(value).__name__}")
    if value not in self:
        (<_AdjacencySetCore>self._core).add_edge(<GraphNode>value)

def discard(self, value):
    """Remove the edge to *value* if there is one.

    Non-GraphNode values and nodes that are not currently adjacent are
    ignored silently.
    """
    if isinstance(value, GraphNode) and value in self:
        (<_AdjacencySetCore>self._core).remove_edge(<GraphNode>value)

# --- override for bulk efficiency ---

def clear(self):
    """Drop every edge in this set with one batched removal."""
    current = (<_AdjacencySetCore>self._core).query()
    if not current:
        # Nothing adjacent: skip the removal call entirely.
        return
    (<_AdjacencySetCore>self._core).remove_edges(current)

def __isub__(self, it):
    """Remove edges to all nodes in *it* in a single driver call.

    Elements of *it* that are not GraphNode instances, or that are not
    currently members of this set, are ignored. Returns ``self`` so the
    augmented assignment rebinds to the same proxy.
    """
    if it is self:
        self.clear()
        return self
    # Snapshot membership once instead of re-querying the driver for every
    # candidate element.
    present = set(self)
    to_remove = []
    for v in it:
        if isinstance(v, GraphNode) and v in present:
            # Drop it from the snapshot so a node repeated in *it* is
            # submitted only once; the batched removal would otherwise be
            # asked to remove the same edge twice.
            present.discard(v)
            to_remove.append(v)
    if to_remove:
        (<_AdjacencySetCore>self._core).remove_edges(to_remove)
    return self

def update(self, *others):
"""Add edges to multiple nodes at once."""
nodes = []
for other in others:
if isinstance(other, GraphNode):
nodes.append(other)
else:
nodes.extend(other)
if not nodes:
return
for n in nodes:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could move this check next to the nodes.extend(other) above, since we already know it to be true if isinstance(other, GraphNode) above.

if not isinstance(n, GraphNode):
raise TypeError(
f"expected GraphNode, got {type(n).__name__}")
new = [n for n in nodes if n not in self]
if new:
(<_AdjacencySetCore>self._core).add_edges(new)

def __ior__(self, it):
"""Add edges to all nodes in *it* in a single driver call.

Delegates to ``update`` and returns ``self`` so that ``s |= it`` keeps
``s`` bound to this proxy.
"""
self.update(it)
return self

def __repr__(self):
    """Render the current members in set-literal style, e.g. ``{a, b}``."""
    parts = [repr(n) for n in self]
    return "{" + ", ".join(parts) + "}"


# ---- cdef core holding a function pointer ------------------------------------

# Signature shared by driver_get_preds and driver_get_succs.
ctypedef cydriver.CUresult (*_adj_fn_t)(
cydriver.CUgraphNode, cydriver.CUgraphNode*, size_t*) noexcept nogil


cdef class _AdjacencySetCore:
"""Cythonized core implementing AdjacencySetProxy"""
# State:
#   _h_node   - handle of the node whose neighbours are exposed
#   _h_graph  - graph handle obtained from that node handle
#   _query_fn - driver wrapper used to list neighbours
#               (driver_get_succs when forward, driver_get_preds otherwise)
#   _is_fwd   - true for the successor direction, false for predecessors
cdef:
GraphNodeHandle _h_node
GraphHandle _h_graph
_adj_fn_t _query_fn
bint _is_fwd

# Capture the node handle, look up its graph, and select the query function
# matching the requested direction.
def __init__(self, GraphNode node, bint is_fwd):
self._h_node = node._h_node
self._h_graph = graph_node_get_graph(node._h_node)
self._is_fwd = is_fwd
self._query_fn = driver_get_succs if is_fwd else driver_get_preds

# Write the (from, to) endpoints of the edge between this node and `other`.
# Forward direction: this node -> other. Otherwise: other -> this node.
cdef inline void _resolve_edge(
self, GraphNode other,
cydriver.CUgraphNode* c_from,
cydriver.CUgraphNode* c_to) noexcept:
if self._is_fwd:
c_from[0] = as_cu(self._h_node)
c_to[0] = as_cu(other._h_node)
else:
c_from[0] = as_cu(other._h_node)
c_to[0] = as_cu(self._h_node)

# Return the current neighbours as a list of GraphNode objects.
cdef list query(self):
cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
# A NULL raw node yields an empty result
# (presumably an invalidated handle - confirm).
if c_node == NULL:
return []
cdef size_t count = 0
# First call passes a NULL output buffer so only the count is filled in.
with nogil:
HANDLE_RETURN(self._query_fn(c_node, NULL, &count))
if count == 0:
return []
cdef vector[cydriver.CUgraphNode] nodes_vec
nodes_vec.resize(count)
# Second call fills the buffer sized from the first call.
with nogil:
HANDLE_RETURN(self._query_fn(
c_node, nodes_vec.data(), &count))
return [GraphNode._create(self._h_graph, nodes_vec[i])
for i in range(count)]

# Return the number of neighbours without building any Python objects.
cdef Py_ssize_t count(self):
cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
# A NULL raw node is reported as having no neighbours.
if c_node == NULL:
return 0
cdef size_t n = 0
with nogil:
HANDLE_RETURN(self._query_fn(c_node, NULL, &n))
return <Py_ssize_t>n

# Add the single edge between this node and `other`.
cdef void add_edge(self, GraphNode other):
cdef cydriver.CUgraphNode c_from, c_to
self._resolve_edge(other, &c_from, &c_to)
with nogil:
HANDLE_RETURN(driver_add_edges(as_cu(self._h_graph), &c_from, &c_to, 1))

# Add edges to every node in `nodes` with one driver_add_edges call.
# Each list element is cast to GraphNode without a runtime check.
cdef void add_edges(self, list nodes):
cdef size_t n = len(nodes)
cdef vector[cydriver.CUgraphNode] from_vec
cdef vector[cydriver.CUgraphNode] to_vec
from_vec.resize(n)
to_vec.resize(n)
cdef size_t i
# Fill the parallel from/to arrays, one edge per list element.
for i in range(n):
self._resolve_edge(<GraphNode>nodes[i], &from_vec[i], &to_vec[i])
with nogil:
HANDLE_RETURN(driver_add_edges(
as_cu(self._h_graph), from_vec.data(), to_vec.data(), n))

# Remove the single edge between this node and `other`.
cdef void remove_edge(self, GraphNode other):
cdef cydriver.CUgraphNode c_from, c_to
self._resolve_edge(other, &c_from, &c_to)
with nogil:
HANDLE_RETURN(driver_remove_edges(as_cu(self._h_graph), &c_from, &c_to, 1))

# Remove edges to every node in `nodes` with one driver_remove_edges call.
# Each list element is cast to GraphNode without a runtime check.
cdef void remove_edges(self, list nodes):
cdef size_t n = len(nodes)
cdef vector[cydriver.CUgraphNode] from_vec
cdef vector[cydriver.CUgraphNode] to_vec
from_vec.resize(n)
to_vec.resize(n)
cdef size_t i
# Fill the parallel from/to arrays, one edge per list element.
for i in range(n):
self._resolve_edge(<GraphNode>nodes[i], &from_vec[i], &to_vec[i])
with nogil:
HANDLE_RETURN(driver_remove_edges(
as_cu(self._h_graph), from_vec.data(), to_vec.data(), n))


# ---- driver wrappers: absorb CUDA version differences ----

# Query the dependencies of `node` through cuGraphNodeGetDependencies and
# return the raw CUresult. `out` may be NULL, as the callers do when they
# only want `count`.
# Builds with CUDA_CORE_BUILD_MAJOR >= 13 pass one extra NULL argument
# (presumably the optional edge-data output - confirm against the driver API).
cdef inline cydriver.CUresult driver_get_preds(
cydriver.CUgraphNode node, cydriver.CUgraphNode* out,
size_t* count) noexcept nogil:
IF CUDA_CORE_BUILD_MAJOR >= 13:
return cydriver.cuGraphNodeGetDependencies(node, out, NULL, count)
ELSE:
return cydriver.cuGraphNodeGetDependencies(node, out, count)


# Query the dependent nodes of `node` through cuGraphNodeGetDependentNodes
# and return the raw CUresult. `out` may be NULL, as the callers do when they
# only want `count`.
# Builds with CUDA_CORE_BUILD_MAJOR >= 13 pass one extra NULL argument
# (presumably the optional edge-data output - confirm against the driver API).
cdef inline cydriver.CUresult driver_get_succs(
cydriver.CUgraphNode node, cydriver.CUgraphNode* out,
size_t* count) noexcept nogil:
IF CUDA_CORE_BUILD_MAJOR >= 13:
return cydriver.cuGraphNodeGetDependentNodes(node, out, NULL, count)
ELSE:
return cydriver.cuGraphNodeGetDependentNodes(node, out, count)


# Add `count` edges to `graph` through cuGraphAddDependencies, taking the
# endpoints pairwise from `from_arr` and `to_arr`. Returns the raw CUresult.
# Builds with CUDA_CORE_BUILD_MAJOR >= 13 pass one extra NULL argument
# (presumably the optional edge-data input - confirm against the driver API).
cdef inline cydriver.CUresult driver_add_edges(
cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr,
cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil:
IF CUDA_CORE_BUILD_MAJOR >= 13:
return cydriver.cuGraphAddDependencies(
graph, from_arr, to_arr, NULL, count)
ELSE:
return cydriver.cuGraphAddDependencies(
graph, from_arr, to_arr, count)


# Remove `count` edges from `graph` through cuGraphRemoveDependencies, taking
# the endpoints pairwise from `from_arr` and `to_arr`. Returns the raw CUresult.
# Builds with CUDA_CORE_BUILD_MAJOR >= 13 pass one extra NULL argument
# (presumably the optional edge-data input - confirm against the driver API).
cdef inline cydriver.CUresult driver_remove_edges(
cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr,
cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil:
IF CUDA_CORE_BUILD_MAJOR >= 13:
return cydriver.cuGraphRemoveDependencies(
graph, from_arr, to_arr, NULL, count)
ELSE:
return cydriver.cuGraphRemoveDependencies(
graph, from_arr, to_arr, count)
16 changes: 8 additions & 8 deletions cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -314,12 +314,12 @@ cdef class GraphDef:
with nogil:
HANDLE_RETURN(cydriver.cuGraphDebugDotPrint(as_cu(self._h_graph), c_path, flags))
def nodes(self) -> tuple:
def nodes(self) -> set:
"""Return all nodes in the graph.

Returns
-------
tuple of GraphNode
set of GraphNode
All nodes in the graph.
"""
cdef size_t num_nodes = 0
Expand All @@ -328,21 +328,21 @@ cdef class GraphDef:
HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), NULL, &num_nodes))
if num_nodes == 0:
return ()
return set()
cdef vector[cydriver.CUgraphNode] nodes_vec
nodes_vec.resize(num_nodes)
with nogil:
HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes))
return tuple(GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(num_nodes))
return set(GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(num_nodes))
def edges(self) -> tuple:
def edges(self) -> set:
"""Return all edges in the graph as (from_node, to_node) pairs.

Returns
-------
tuple of tuple
set of tuple
Each element is a (from_node, to_node) pair representing
a dependency edge in the graph.
"""
Expand All @@ -355,7 +355,7 @@ cdef class GraphDef:
HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, &num_edges))
if num_edges == 0:
return ()
return set()
cdef vector[cydriver.CUgraphNode] from_nodes
cdef vector[cydriver.CUgraphNode] to_nodes
Expand All @@ -369,7 +369,7 @@ cdef class GraphDef:
HANDLE_RETURN(cydriver.cuGraphGetEdges(
as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), &num_edges))
return tuple(
return set(
(GraphNode._create(self._h_graph, from_nodes[i]),
GraphNode._create(self._h_graph, to_nodes[i]))
for i in range(num_edges)
Expand Down
2 changes: 0 additions & 2 deletions cuda_core/cuda/core/_graph/_graph_def/_graph_node.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ from cuda.core._resource_handles cimport GraphHandle, GraphNodeHandle
cdef class GraphNode:
cdef:
GraphNodeHandle _h_node
tuple _pred_cache
tuple _succ_cache
object __weakref__

@staticmethod
Expand Down
Loading
Loading