Skip to content

Commit 8554d30

Browse files
committed
Rename AdjacencySet to AdjacencySetProxy, add bulk ops and safety guards
Rename class and file to AdjacencySetProxy to clarify write-through semantics. Add bulk-efficient clear(), __isub__(), __ior__() overrides and remove_edges() on the Cython core. Guard GraphNode.discard() against double-destroy via membership check. Filter duplicates in update(). Add error-path tests for wrong types, cross-graph edges, and self-edges. Made-with: Cursor
1 parent b27dd93 commit 8554d30

File tree

3 files changed

+144
-29
lines changed

3 files changed

+144
-29
lines changed

cuda_core/cuda/core/_graph/_graph_def/_adjacency_set.pyx renamed to cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,16 @@ from collections.abc import MutableSet
2020

2121
# ---- Python MutableSet wrapper ----------------------------------------------
2222

23-
class AdjacencySet(MutableSet):
24-
"""Mutable set-like view of a node's predecessors or successors."""
23+
class AdjacencySetProxy(MutableSet):
24+
"""Mutable set proxy for a node's predecessors or successors. Mutations
25+
write through to the underlying CUDA graph."""
2526

2627
__slots__ = ("_core",)
2728

2829
def __init__(self, node, bint is_fwd):
    """Create the proxy for *node*.

    Both arguments are handed unchanged to ``_AdjacencySetCore``; the core
    selects the successor query when *is_fwd* is true and the predecessor
    query otherwise.
    """
    core = _AdjacencySetCore(node, is_fwd)
    self._core = core
3031

32+
# Hook used by the binary set operators inherited from the Set mixin
# (such as & | ^) to build their result: a plain set, not another proxy.
@classmethod
def _from_iterable(cls, it):
    """Materialize *it* as an ordinary built-in ``set``."""
    detached = set(it)
    return detached
@@ -62,6 +64,22 @@ class AdjacencySet(MutableSet):
6264

6365
# --- override for bulk efficiency ---
6466

67+
def clear(self):
    """Remove every edge of this adjacency set.

    The current members are queried once; when there are any, they are
    passed together to one bulk ``remove_edges`` call on the core.
    """
    cdef _AdjacencySetCore core = <_AdjacencySetCore>self._core
    current = core.query()
    if not current:
        # Nothing to remove: skip the bulk call entirely.
        return
    core.remove_edges(current)
72+
73+
def __isub__(self, it):
    """Remove edges to all nodes in *it* with one bulk ``remove_edges`` call.

    Elements of *it* that are not ``GraphNode`` instances, or that are not
    currently members of this set, are skipped. Repeated elements are
    collapsed so that each edge is submitted for removal at most once.

    Returns ``self`` so that ``s -= other`` leaves ``s`` bound to the proxy.
    """
    if it is self:
        # Same special case as the collections.abc.MutableSet mixin:
        # subtracting a set from itself is just clearing it.
        self.clear()
        return self
    # dict.fromkeys de-duplicates while keeping first-seen order. Without
    # it, an input such as [a, a] would put the same edge into the batch
    # twice, although set semantics say the second occurrence is a no-op.
    to_remove = list(dict.fromkeys(
        v for v in it if isinstance(v, GraphNode) and v in self))
    if to_remove:
        (<_AdjacencySetCore>self._core).remove_edges(to_remove)
    return self
82+
6583
def update(self, *others):
6684
"""Add edges to multiple nodes at once."""
6785
nodes = []
@@ -76,21 +94,28 @@ class AdjacencySet(MutableSet):
7694
if not isinstance(n, GraphNode):
7795
raise TypeError(
7896
f"expected GraphNode, got {type(n).__name__}")
79-
(<_AdjacencySetCore>self._core).add_edges(nodes)
97+
new = [n for n in nodes if n not in self]
98+
if new:
99+
(<_AdjacencySetCore>self._core).add_edges(new)
100+
101+
def __ior__(self, it):
    """In-place union: pass *it* to ``update`` and return this proxy.

    Returning ``self`` keeps ``s |= other`` bound to the proxy.
    """
    self.update(it)
    return self
80105

81106
def __repr__(self):
    """Render the members as a brace-delimited, comma-separated list of reprs."""
    body = ", ".join(map(repr, self))
    return "{" + body + "}"
83108

84109

85-
# ---- cdef core holding function pointer ------------------------------------
110+
# ---- cdef core holding a function pointer ------------------------------------
86111

87-
# Signature shared by _get_preds and _get_succs.
112+
# Signature shared by driver_get_preds and driver_get_succs.
88113
ctypedef cydriver.CUresult (*_adj_fn_t)(
89114
cydriver.CUgraphNode, cydriver.CUgraphNode*, size_t*) noexcept nogil
90115

91116

92117
cdef class _AdjacencySetCore:
93-
"""Cythonized core implementing AdjacencySet"""
118+
"""Cythonized core implementing AdjacencySetProxy"""
94119
cdef:
95120
GraphNodeHandle _h_node
96121
GraphHandle _h_graph
@@ -101,7 +126,7 @@ cdef class _AdjacencySetCore:
101126
self._h_node = node._h_node
102127
self._h_graph = graph_node_get_graph(node._h_node)
103128
self._is_fwd = is_fwd
104-
self._query_fn = _get_succs if is_fwd else _get_preds
129+
self._query_fn = driver_get_succs if is_fwd else driver_get_preds
105130

106131
cdef inline void _resolve_edge(
107132
self, GraphNode other,
@@ -144,15 +169,28 @@ cdef class _AdjacencySetCore:
144169
cdef cydriver.CUgraphNode c_from, c_to
145170
self._resolve_edge(other, &c_from, &c_to)
146171
with nogil:
147-
HANDLE_RETURN(_add_edge(as_cu(self._h_graph), &c_from, &c_to, 1))
172+
HANDLE_RETURN(driver_add_edges(as_cu(self._h_graph), &c_from, &c_to, 1))
173+
174+
cdef void add_edges(self, list nodes):
    """Add one edge per entry of *nodes* using a single driver_add_edges call.

    Each entry is passed to ``_resolve_edge``, which fills the matching
    slots of the two endpoint arrays (presumably ordering the pair
    according to the direction of this set -- confirm against
    ``_resolve_edge``). Both arrays are then submitted together.
    """
    cdef:
        size_t count = len(nodes)
        size_t idx
        vector[cydriver.CUgraphNode] src_nodes
        vector[cydriver.CUgraphNode] dst_nodes
    # Size both arrays up front so the slots can be written by address.
    src_nodes.resize(count)
    dst_nodes.resize(count)
    for idx in range(count):
        self._resolve_edge(
            <GraphNode>nodes[idx], &src_nodes[idx], &dst_nodes[idx])
    with nogil:
        HANDLE_RETURN(driver_add_edges(
            as_cu(self._h_graph), src_nodes.data(), dst_nodes.data(), count))
148186

149187
cdef void remove_edge(self, GraphNode other):
150188
cdef cydriver.CUgraphNode c_from, c_to
151189
self._resolve_edge(other, &c_from, &c_to)
152190
with nogil:
153-
HANDLE_RETURN(_remove_edge(as_cu(self._h_graph), &c_from, &c_to, 1))
191+
HANDLE_RETURN(driver_remove_edges(as_cu(self._h_graph), &c_from, &c_to, 1))
154192

155-
cdef void add_edges(self, list nodes):
193+
cdef void remove_edges(self, list nodes):
156194
cdef size_t n = len(nodes)
157195
cdef vector[cydriver.CUgraphNode] from_vec
158196
cdef vector[cydriver.CUgraphNode] to_vec
@@ -162,13 +200,13 @@ cdef class _AdjacencySetCore:
162200
for i in range(n):
163201
self._resolve_edge(<GraphNode>nodes[i], &from_vec[i], &to_vec[i])
164202
with nogil:
165-
HANDLE_RETURN(_add_edge(
203+
HANDLE_RETURN(driver_remove_edges(
166204
as_cu(self._h_graph), from_vec.data(), to_vec.data(), n))
167205

168206

169207
# ---- driver wrappers: absorb CUDA version differences ----
170208

171-
cdef cydriver.CUresult _get_preds(
209+
cdef inline cydriver.CUresult driver_get_preds(
172210
cydriver.CUgraphNode node, cydriver.CUgraphNode* out,
173211
size_t* count) noexcept nogil:
174212
IF CUDA_CORE_BUILD_MAJOR >= 13:
@@ -177,7 +215,7 @@ cdef cydriver.CUresult _get_preds(
177215
return cydriver.cuGraphNodeGetDependencies(node, out, count)
178216

179217

180-
cdef cydriver.CUresult _get_succs(
218+
cdef inline cydriver.CUresult driver_get_succs(
181219
cydriver.CUgraphNode node, cydriver.CUgraphNode* out,
182220
size_t* count) noexcept nogil:
183221
IF CUDA_CORE_BUILD_MAJOR >= 13:
@@ -186,7 +224,7 @@ cdef cydriver.CUresult _get_succs(
186224
return cydriver.cuGraphNodeGetDependentNodes(node, out, count)
187225

188226

189-
cdef cydriver.CUresult _add_edge(
227+
cdef inline cydriver.CUresult driver_add_edges(
190228
cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr,
191229
cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil:
192230
IF CUDA_CORE_BUILD_MAJOR >= 13:
@@ -197,7 +235,7 @@ cdef cydriver.CUresult _add_edge(
197235
graph, from_arr, to_arr, count)
198236

199237

200-
cdef cydriver.CUresult _remove_edge(
238+
cdef inline cydriver.CUresult driver_remove_edges(
201239
cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr,
202240
cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil:
203241
IF CUDA_CORE_BUILD_MAJOR >= 13:

cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ from cuda.core._graph._utils cimport (
5757
)
5858

5959
from cuda.core import Device
60-
from cuda.core._graph._graph_def._adjacency_set import AdjacencySet
60+
from cuda.core._graph._graph_def._adjacency_set_proxy import AdjacencySetProxy
6161
from cuda.core._utils.cuda_utils import driver, handle_return
6262

6363

@@ -124,30 +124,35 @@ cdef class GraphNode:
124124
return as_py(self._h_node)
125125

126126
def discard(self):
127-
"""Discard this node and remove all its edges from the parent graph."""
127+
"""Discard this node and remove all its edges from the parent graph.
128+
129+
Safe to call on an already-discarded node (no-op).
130+
"""
131+
if self not in self.graph.nodes():
132+
return
128133
cdef cydriver.CUgraphNode node = as_cu(self._h_node)
129134
with nogil:
130135
HANDLE_RETURN(cydriver.cuGraphDestroyNode(node))
131136

132137
@property
133138
def pred(self):
134139
"""A mutable set-like view of this node's predecessors."""
135-
return AdjacencySet(self, False)
140+
return AdjacencySetProxy(self, False)
136141

137142
@pred.setter
138143
def pred(self, value):
139-
p = AdjacencySet(self, False)
144+
p = AdjacencySetProxy(self, False)
140145
p.clear()
141146
p.update(value)
142147

143148
@property
144149
def succ(self):
145150
"""A mutable set-like view of this node's successors."""
146-
return AdjacencySet(self, True)
151+
return AdjacencySetProxy(self, True)
147152

148153
@succ.setter
149154
def succ(self, value):
150-
s = AdjacencySet(self, True)
155+
s = AdjacencySetProxy(self, True)
151156
s.clear()
152157
s.update(value)
153158

cuda_core/tests/graph/test_graphdef_mutation.py

Lines changed: 80 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
"""Tests for mutating a graph definition (edge changes, node removal)."""
55

66
import numpy as np
7+
import pytest
78
from helpers.collection_interface_testers import assert_mutable_set_interface
89
from helpers.graph_kernels import compile_parallel_kernels
910
from helpers.marks import requires_module
1011

12+
from cuda.bindings import driver
1113
from cuda.core import Device, LaunchConfig, LegacyPinnedMemoryResource
1214
from cuda.core._graph._graph_def import GraphDef, KernelNode, MemsetNode
15+
from cuda.core._utils.cuda_utils import CUDAError
1316

1417

1518
class YRig:
@@ -29,6 +32,9 @@ class YRig:
2932
Node r computes result ``combine(R, A, B) = (A << 16) | (B & 0xFFFF)``,
3033
encoding both arms' results into a single int. j is a joining (empty) node
3134
preceding r.
35+
36+
The affine operation a * m + b is noncommutative, so we can be sure the
37+
graph has exactly the topology we expect by checking the final value.
3238
"""
3339

3440
def __init__(self):
@@ -148,8 +154,8 @@ def test_baseline(self, init_cuda):
148154
rig.close()
149155

150156
def test_discard_a1(self, init_cuda):
151-
"""Discard a1 (creates a race). Arm b yields the expected value, and the
152-
final result is correctly ordered after b."""
157+
"""Discard a1 (creates a race on arm a). Arm b yields the expected
158+
value, and the final step is correctly ordered after b completes."""
153159
rig = YRig()
154160
rig.a[1].discard()
155161
rig.run()
@@ -204,7 +210,7 @@ def test_insert_b(self, init_cuda):
204210

205211

206212
def test_adjacency_set_interface(init_cuda):
207-
"""Exercise every MutableSet method on AdjacencySet."""
213+
"""Exercise every MutableSet method on AdjacencySetProxy."""
208214
g = GraphDef()
209215
hub = g.join()
210216
items = [g.join() for _ in range(5)]
@@ -260,6 +266,67 @@ def test_adjacency_set_property_setter(init_cuda):
260266
assert hub.pred == set()
261267

262268

269+
def test_discarded_node(init_cuda):
    """Test uses of discarded nodes.

    Builds a two-memset chain ``a -> b``, discards ``b``, and checks that
    ``b`` no longer appears among the graph's nodes or edges while its
    Python-side attributes remain readable. A second ``discard()`` must
    succeed quietly.
    """
    mr = LegacyPinnedMemoryResource()
    buf = mr.allocate(4)
    arr = np.from_dlpack(buf).view(np.int32)
    arr[:] = 0
    ptr = arr[0:].ctypes.data

    g = GraphDef()
    a = g.memset(ptr, 0, 4)
    b = a.memset(ptr, 42, 4)

    assert b in g.nodes()
    assert (a, b) in g.edges()

    b.discard()

    # b is removed from the graph but still usable
    assert b not in g.nodes()
    assert (a, b) not in g.edges()
    assert isinstance(b, MemsetNode)
    # b was created by memset(), so its node type is MEMSET, not KERNEL.
    assert b.type == driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_MEMSET
    assert b.pred == set()
    assert b.succ == set()
    assert b.handle != 0
    assert b.dptr == ptr
    assert b.value == 42
    assert b.width == 4

    # Repeated discard succeeds quietly.
    b.discard()
300+
301+
302+
def test_add_wrong_type(init_cuda):
    """Adding anything that is not a GraphNode must raise TypeError."""
    graph_def = GraphDef()
    target = graph_def.join()
    # One string and one int, checked in the same order as before.
    for bogus in ("not a node", 42):
        with pytest.raises(TypeError, match="expected GraphNode"):
            target.succ.add(bogus)
310+
311+
312+
def test_cross_graph_edge(init_cuda):
    """Linking nodes that belong to two different graphs raises CUDAError."""
    # Keep both graph definitions referenced for the duration of the test.
    first_graph, second_graph = GraphDef(), GraphDef()
    src = first_graph.join()
    dst = second_graph.join()
    with pytest.raises(CUDAError):
        src.succ.add(dst)
320+
321+
322+
def test_self_edge(init_cuda):
    """A node cannot be added as its own successor: CUDAError is raised."""
    graph_def = GraphDef()
    loner = graph_def.join()
    with pytest.raises(CUDAError):
        loner.succ.add(loner)
328+
329+
263330
@requires_module(np, "2.1")
264331
def test_convert_linear_to_fan_in(init_cuda):
265332
"""Chain four computations sequentially, then rewire so all pairs run in
@@ -293,17 +360,22 @@ def test_convert_linear_to_fan_in(init_cuda):
293360
g = GraphDef()
294361
prev = g
295362
for i, val in enumerate(values):
296-
prev = prev.memset(ptrs[i], val, 1).launch(config, affine, ptrs[i], 2, 1)
363+
prev = prev.memset(ptrs[i], val, 1)
364+
prev = prev.launch(config, affine, ptrs[i], 2, 1)
297365
reduce_node = g.launch(config, reduce_kern, ptrs[4], ptrs[0], 4)
298366

299367
# Rewire:
300368
# - drop preds from memsets
301-
# - connect results to reduction
369+
# - connect kernel launches to the reduction
370+
assert len(g.edges()) == 7
371+
302372
for node in g.nodes():
303-
if isinstance(node, MemsetNode) and len(node.pred):
304-
node.pred = set()
373+
if isinstance(node, MemsetNode):
374+
node.pred.clear()
305375
elif isinstance(node, KernelNode) and node != reduce_node:
306-
node.succ = {reduce_node}
376+
node.succ.add(reduce_node)
377+
378+
assert len(g.edges()) == 8
307379

308380
stream = Device().create_stream()
309381
graph = g.instantiate()

0 commit comments

Comments
 (0)