Skip to content

Commit 2b67ffe

Browse files
authored
Merge pull request #31 from mtreinish/add-retworkx-support
Add support for using retworkx as input and output from Matching
2 parents b64cca6 + 73bfb6f commit 2b67ffe

File tree

3 files changed

+270
-5
lines changed

3 files changed

+270
-5
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,12 @@ def build_extension(self, ext):
127127
packages=find_packages("src"),
128128
package_dir={'':'src'},
129129
cmdclass=dict(build_ext=CMakeBuild),
130-
install_requires=['scipy', 'numpy', 'networkx','matplotlib'],
130+
install_requires=['scipy', 'numpy', 'networkx','retworkx>=0.11.0','matplotlib'],
131131
classifiers=[
132132
"License :: OSI Approved :: Apache Software License"
133133
],
134134
long_description=long_description,
135135
long_description_content_type='text/markdown',
136136
python_requires='>=3',
137137
zip_safe=False,
138-
)
138+
)

src/pymatching/matching.py

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import matplotlib.cbook
1919
import numpy as np
2020
import networkx as nx
21+
import retworkx as rx
2122
import scipy
2223
from scipy.sparse import csc_matrix
2324

@@ -60,7 +61,7 @@ class Matching:
6061
fault ids, boundaries and error probabilities.
6162
"""
6263
def __init__(self,
63-
H: Union[scipy.sparse.spmatrix, np.ndarray, nx.Graph, List[List[int]]] = None,
64+
H: Union[scipy.sparse.spmatrix, np.ndarray, rx.PyGraph, nx.Graph, List[List[int]]] = None,
6465
spacelike_weights: Union[float, np.ndarray, List[float]] = None,
6566
error_probabilities: Union[float, np.ndarray, List[float]] = None,
6667
repetitions: int = None,
@@ -156,7 +157,7 @@ def __init__(self,
156157
self.matching_graph = MatchingGraph()
157158
if H is None:
158159
return
159-
if not isinstance(H, nx.Graph):
160+
if not isinstance(H, (nx.Graph, rx.PyGraph)):
160161
try:
161162
H = csc_matrix(H)
162163
except TypeError:
@@ -165,8 +166,10 @@ def __init__(self,
165166
self.load_from_check_matrix(H, spacelike_weights, error_probabilities,
166167
repetitions, timelike_weights, measurement_error_probabilities,
167168
**kwargs)
168-
else:
169+
elif isinstance(H, nx.Graph):
169170
self.load_from_networkx(H)
171+
else:
172+
self.load_from_retworkx(H)
170173
if precompute_shortest_paths:
171174
self.matching_graph.compute_all_pairs_shortest_paths()
172175

@@ -310,6 +313,74 @@ def load_from_networkx(self, graph: nx.Graph) -> None:
310313
g.add_edge(u, v, fault_ids, weight, e_prob, 0 <= e_prob <= 1)
311314
self.matching_graph = g
312315

316+
def load_from_retworkx(self, graph: rx.PyGraph) -> None:
317+
r"""
318+
Load a matching graph from a retworkX graph
319+
320+
Parameters
321+
----------
322+
graph : retworkx.PyGraph
323+
Each edge in the retworkx graph can have dictionary payload with keys
324+
``fault_ids``, ``weight`` and ``error_probability``. ``fault_ids`` should be
325+
an int or a set of ints. Each fault id corresponds to a self-inverse fault
326+
that is flipped when the corresponding edge is flipped. These self-inverse
327+
faults could correspond to physical Pauli errors (physical frame changes)
328+
or to the logical observables that are flipped by the fault
329+
(a logical frame change, equivalent to an obersvable ID in an error instruction in a Stim
330+
detector error model). The `fault_ids` attribute was previously named `qubit_id` in an
331+
earlier version of PyMatching, and `qubit_id` is still accepted instead of `fault_ids` in order
332+
to maintain backward compatibility.
333+
Each ``weight`` attribute should be a non-negative float. If
334+
every edge is assigned an error_probability between zero and one,
335+
then the ``add_noise`` method can be used to simulate noise and
336+
flip edges independently in the graph.
337+
338+
Examples
339+
--------
340+
>>> import pymatching
341+
>>> import retworkx as rx
342+
>>> import math
343+
>>> g = rx.PyGraph()
344+
>>> matching = g.add_nodes_from([{} for _ in range(3)])
345+
>>> edge_a =g.add_edge(0, 1, dict(fault_ids=0, weight=math.log((1-0.1)/0.1), error_probability=0.1))
346+
>>> edge_b = g.add_edge(1, 2, dict(fault_ids=1, weight=math.log((1-0.15)/0.15), error_probability=0.15))
347+
>>> g[0]['is_boundary'] = True
348+
>>> g[2]['is_boundary'] = True
349+
>>> m = pymatching.Matching(g)
350+
>>> m
351+
<pymatching.Matching object with 1 detector, 2 boundary nodes, and 2 edges>
352+
"""
353+
if not isinstance(graph, rx.PyGraph):
354+
raise TypeError("G must be a retworkx graph")
355+
boundary = {i for i in graph.node_indices() if graph[i].get("is_boundary", False)}
356+
num_nodes = len(graph)
357+
g = MatchingGraph(self.num_detectors, boundary)
358+
for (u, v, attr) in graph.weighted_edge_list():
359+
u, v = int(u), int(v)
360+
if "fault_ids" in attr and "qubit_id" in attr:
361+
raise ValueError("Both `fault_ids` and `qubit_id` were provided as edge attributes, however use "
362+
"of `qubit_id` has been deprecated in favour of `fault_ids`. Please only supply "
363+
"`fault_ids` as an edge attribute.")
364+
if "fault_ids" not in attr and "qubit_id" in attr:
365+
fault_ids = attr["qubit_id"] # Still accept qubit_id as well for now
366+
else:
367+
fault_ids = attr.get("fault_ids", set())
368+
if isinstance(fault_ids, (int, np.integer)):
369+
fault_ids = {int(fault_ids)} if fault_ids != -1 else set()
370+
else:
371+
try:
372+
fault_ids = set(fault_ids)
373+
if not all(isinstance(q, (int, np.integer)) for q in fault_ids):
374+
raise ValueError("fault_ids must be a set of ints, not {}".format(fault_ids))
375+
except:
376+
raise ValueError(
377+
"fault_ids property must be an int or a set of int"\
378+
" (or convertible to a set), not {}".format(fault_ids))
379+
weight = attr.get("weight", 1) # Default weight is 1 if not provided
380+
e_prob = attr.get("error_probability", -1)
381+
g.add_edge(u, v, fault_ids, weight, e_prob, 0 <= e_prob <= 1)
382+
self.matching_graph = g
383+
313384
def load_from_check_matrix(self,
314385
H: Union[scipy.sparse.spmatrix, np.ndarray, List[List[int]]],
315386
spacelike_weights: Union[float, np.ndarray, List[float]] = None,
@@ -748,6 +819,28 @@ def to_networkx(self) -> nx.Graph:
748819
is_boundary = i in boundary
749820
G.nodes[i]['is_boundary'] = is_boundary
750821
return G
822+
823+
def to_retworkx(self) -> rx.PyGraph:
824+
"""Convert to retworkx graph
825+
826+
Returns a retworkx graph object corresponding to the matching graph. Each edge
827+
payload is a ``dict`` with keys `fault_ids`, `weight` and `error_probability` and
828+
each node has a ``dict`` payload with the key ``is_boundary`` and the value is
829+
a boolean.
830+
831+
Returns
832+
-------
833+
retworkx.PyGraph
834+
retworkx graph corresponding to the matching graph
835+
"""
836+
G = rx.PyGraph(multigraph=False)
837+
G.add_nodes_from([{} for _ in range(self.num_nodes)])
838+
G.extend_from_weighted_edge_list(self.edges())
839+
boundary = self.boundary
840+
for i in G.node_indices():
841+
is_boundary = i in boundary
842+
G[i]['is_boundary'] = is_boundary
843+
return G
751844

752845
def draw(self) -> None:
753846
"""Draw the matching graph using matplotlib

tests/test_matching.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from scipy.sparse import csc_matrix, csr_matrix
1818
import pytest
1919
import networkx as nx
20+
import retworkx as rx
2021
import matplotlib.pyplot as plt
2122

2223
from pymatching._cpp_mwpm import MatchingGraph
@@ -60,6 +61,21 @@ def test_boundary_from_networkx():
6061
assert np.array_equal(m.decode(np.array([0,1,1,0])), np.array([0,0,1,0,0]))
6162
assert np.array_equal(m.decode(np.array([0,0,1,0])), np.array([0,0,0,1,1]))
6263

64+
def test_boundary_from_retworkx():
65+
g = rx.PyGraph()
66+
g.add_nodes_from([{} for _ in range(5)])
67+
g.add_edge(4,0, dict(fault_ids=0))
68+
g.add_edge(0,1, dict(fault_ids=1))
69+
g.add_edge(1,2, dict(fault_ids=2))
70+
g.add_edge(2,3, dict(fault_ids=3))
71+
g.add_edge(3,4, dict(fault_ids=4))
72+
g[4]['is_boundary'] = True
73+
m = Matching(g)
74+
assert m.boundary == {4}
75+
assert np.array_equal(m.decode(np.array([1,0,0,0])), np.array([1,0,0,0,0]))
76+
assert np.array_equal(m.decode(np.array([0,1,0,0])), np.array([1,1,0,0,0]))
77+
assert np.array_equal(m.decode(np.array([0,1,1,0])), np.array([0,0,1,0,0]))
78+
assert np.array_equal(m.decode(np.array([0,0,1,0])), np.array([0,0,0,1,1]))
6379

6480
def test_boundaries_from_networkx():
6581
g = nx.Graph()
@@ -78,6 +94,23 @@ def test_boundaries_from_networkx():
7894
assert np.array_equal(m.decode(np.array([0,0,1,1,0])), np.array([0,0,1,0,0]))
7995
assert np.array_equal(m.decode(np.array([0,0,0,1,0])), np.array([0,0,0,1,1]))
8096

97+
def test_boundaries_from_retworkx():
98+
g = rx.PyGraph()
99+
g.add_nodes_from([{} for _ in range(6)])
100+
g.add_edge(0,1, dict(fault_ids=0))
101+
g.add_edge(1,2, dict(fault_ids=1))
102+
g.add_edge(2,3, dict(fault_ids=2))
103+
g.add_edge(3,4, dict(fault_ids=3))
104+
g.add_edge(4,5, dict(fault_ids=4))
105+
g.add_edge(0,5, dict(fault_ids=-1, weight=0.0))
106+
g.nodes()[0]['is_boundary'] = True
107+
g.nodes()[5]['is_boundary'] = True
108+
m = Matching(g)
109+
assert m.boundary == {0, 5}
110+
assert np.array_equal(m.decode(np.array([0,1,0,0,0,0])), np.array([1,0,0,0,0]))
111+
assert np.array_equal(m.decode(np.array([0,0,1,0,0])), np.array([1,1,0,0,0]))
112+
assert np.array_equal(m.decode(np.array([0,0,1,1,0])), np.array([0,0,1,0,0]))
113+
assert np.array_equal(m.decode(np.array([0,0,0,1,0])), np.array([0,0,0,1,1]))
81114

82115
def test_nonzero_matrix_elements_not_one_raises_value_error():
83116
H = csr_matrix(np.array([[0,1.01,1.01],[1.01,1.01,0]]))
@@ -200,6 +233,79 @@ def test_mwpm_from_networkx():
200233
assert(m.matching_graph.shortest_path(0,2) == [0,2])
201234

202235

236+
def test_unweighted_stabiliser_graph_from_retworkx():
237+
w = rx.PyGraph()
238+
w.add_nodes_from([{} for _ in range(6)])
239+
w.add_edge(0, 1, dict(fault_ids=0, weight=7.0))
240+
w.add_edge(0, 5, dict(fault_ids=1, weight=14.0))
241+
w.add_edge(0, 2, dict(fault_ids=2, weight=9.0))
242+
w.add_edge(1, 2, dict(fault_ids=-1, weight=10.0))
243+
w.add_edge(1, 3, dict(fault_ids=3, weight=15.0))
244+
w.add_edge(2, 5, dict(fault_ids=4, weight=2.0))
245+
w.add_edge(2, 3, dict(fault_ids=-1, weight=11.0))
246+
w.add_edge(3, 4, dict(fault_ids=5, weight=6.0))
247+
w.add_edge(4, 5, dict(fault_ids=6, weight=9.0))
248+
m = Matching(w)
249+
assert(m.num_fault_ids == 7)
250+
assert(m.num_detectors == 6)
251+
assert(m.matching_graph.shortest_path(3, 5) == [3, 2, 5])
252+
assert(m.matching_graph.distance(5, 0) == pytest.approx(11.0))
253+
assert(np.array_equal(
254+
m.decode(np.array([1,0,1,0,0,0])),
255+
np.array([0,0,1,0,0,0,0]))
256+
)
257+
with pytest.raises(ValueError):
258+
m.decode(np.array([1,1,0]))
259+
with pytest.raises(ValueError):
260+
m.decode(np.array([1,1,1,0,0,0]))
261+
assert(np.array_equal(
262+
m.decode(np.array([1,0,0,0,0,1])),
263+
np.array([0,0,1,0,1,0,0]))
264+
)
265+
assert(np.array_equal(
266+
m.decode(np.array([0,1,0,0,0,1])),
267+
np.array([0,0,0,0,1,0,0]))
268+
)
269+
270+
271+
def test_mwpm_from_retworkx():
272+
g = rx.PyGraph()
273+
g.add_nodes_from([{} for _ in range(3)])
274+
g.add_edge(0, 1, dict(fault_ids=0))
275+
g.add_edge(0, 2, dict(fault_ids=1))
276+
g.add_edge(1, 2, dict(fault_ids=2))
277+
m = Matching(g)
278+
assert(isinstance(m.matching_graph, MatchingGraph))
279+
assert(m.num_detectors == 3)
280+
assert(m.num_fault_ids == 3)
281+
assert(m.matching_graph.distance(0,2) == 1)
282+
assert(m.matching_graph.shortest_path(0,2) == [0,2])
283+
284+
g = rx.PyGraph()
285+
g.add_nodes_from([{} for _ in range(3)])
286+
g.add_edge(0, 1, {})
287+
g.add_edge(0, 2, {})
288+
g.add_edge(1, 2, {})
289+
m = Matching(g)
290+
assert(isinstance(m.matching_graph, MatchingGraph))
291+
assert(m.num_detectors == 3)
292+
assert(m.num_fault_ids == 0)
293+
assert(m.matching_graph.distance(0,2) == 1)
294+
assert(m.matching_graph.shortest_path(0,2) == [0,2])
295+
296+
g = rx.PyGraph()
297+
g.add_nodes_from([{} for _ in range(3)])
298+
g.add_edge(0, 1, dict(weight=1.5))
299+
g.add_edge(0, 2, dict(weight=1.7))
300+
g.add_edge(1, 2, dict(weight=1.2))
301+
m = Matching(g)
302+
assert(isinstance(m.matching_graph, MatchingGraph))
303+
assert(m.num_detectors == 3)
304+
assert(m.num_fault_ids == 0)
305+
assert(m.matching_graph.distance(0,2) == pytest.approx(1.7))
306+
assert(m.matching_graph.shortest_path(0,2) == [0,2])
307+
308+
203309
def test_repr():
204310
g = nx.Graph()
205311
g.add_edge(0, 1, fault_ids=0)
@@ -253,6 +359,47 @@ def test_qubit_id_accepted_via_networkx():
253359
assert es == expected_edges
254360

255361

362+
def test_matching_edges_from_retworkx():
363+
g = rx.PyGraph()
364+
g.add_nodes_from([{} for _ in range(4)])
365+
g.add_edge(0, 1, dict(fault_ids=0, weight=1.1, error_probability=0.1))
366+
g.add_edge(1, 2, dict(fault_ids=1, weight=2.1, error_probability=0.2))
367+
g.add_edge(2, 3, dict(fault_ids={2,3}, weight=0.9, error_probability=0.3))
368+
g[0]['is_boundary'] = True
369+
g[3]['is_boundary'] = True
370+
g.add_edge(0, 3, dict(weight=0.0))
371+
m = Matching(g)
372+
es = list(m.edges())
373+
expected_edges = [
374+
(0,1,{'fault_ids': {0}, 'weight': 1.1, 'error_probability': 0.1}),
375+
(1,2,{'fault_ids': {1}, 'weight': 2.1, 'error_probability': 0.2}),
376+
(2,3,{'fault_ids': {2,3}, 'weight': 0.9, 'error_probability': 0.3}),
377+
(0,3,{'fault_ids': set(), 'weight': 0.0, 'error_probability': -1.0}),
378+
]
379+
print(es)
380+
assert es == expected_edges
381+
382+
383+
def test_qubit_id_accepted_via_retworkx():
384+
g = rx.PyGraph()
385+
g.add_nodes_from([{} for _ in range(4)])
386+
g.add_edge(0, 1, dict(qubit_id=0, weight=1.1, error_probability=0.1))
387+
g.add_edge(1, 2, dict(qubit_id=1, weight=2.1, error_probability=0.2))
388+
g.add_edge(2, 3, dict(qubit_id={2, 3}, weight=0.9, error_probability=0.3))
389+
g[0]['is_boundary'] = True
390+
g[3]['is_boundary'] = True
391+
g.add_edge(0, 3, dict(weight=0.0))
392+
m = Matching(g)
393+
es = list(m.edges())
394+
expected_edges = [
395+
(0, 1, {'fault_ids': {0}, 'weight': 1.1, 'error_probability': 0.1}),
396+
(1, 2, {'fault_ids': {1}, 'weight': 2.1, 'error_probability': 0.2}),
397+
(2, 3, {'fault_ids': {2, 3}, 'weight': 0.9, 'error_probability': 0.3}),
398+
(0, 3, {'fault_ids': set(), 'weight': 0.0, 'error_probability': -1.0}),
399+
]
400+
assert es == expected_edges
401+
402+
256403
def test_qubit_id_accepted_using_add_edge():
257404
m = Matching()
258405
m.add_edge(0, 1, qubit_id=0)
@@ -304,6 +451,31 @@ def test_matching_to_networkx():
304451
assert sorted(gedges) == sorted(g2edges)
305452

306453

454+
def test_matching_to_retworkx():
455+
g = rx.PyGraph()
456+
g.add_nodes_from([{} for _ in range(4)])
457+
g.add_edge(0, 1, dict(fault_ids={0}, weight=1.1, error_probability=0.1))
458+
g.add_edge(1, 2, dict(fault_ids={1}, weight=2.1, error_probability=0.2))
459+
g.add_edge(2, 3, dict(fault_ids={2,3}, weight=0.9, error_probability=0.3))
460+
g[0]['is_boundary'] = True
461+
g[3]['is_boundary'] = True
462+
g.add_edge(0, 3, dict(weight=0.0))
463+
m = Matching(g)
464+
465+
edge_0_3 = g.get_edge_data(0, 3)
466+
edge_0_3['fault_ids'] = set()
467+
edge_0_3['error_probability'] = -1.0
468+
g[1]['is_boundary'] = False
469+
g[2]['is_boundary'] = False
470+
471+
g2 = m.to_retworkx()
472+
473+
assert g.node_indices() == g2.node_indices()
474+
gedges = [({s,t},d) for (s, t, d) in g.weighted_edge_list()]
475+
g2edges = [({s,t},d) for (s, t, d) in g.weighted_edge_list()]
476+
assert sorted(gedges) == sorted(g2edges)
477+
478+
307479
def test_draw_matching():
308480
g = nx.Graph()
309481
g.add_edge(0, 1, fault_ids={0}, weight=1.1, error_probability=0.1)

0 commit comments

Comments
 (0)