From 9ef2ec29671f62fb80b834d6c3f942242c16a6c6 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Fri, 6 May 2022 13:43:55 -0400 Subject: [PATCH 1/4] Add support for using retworkx as input and output from Matching This commit adds support for using retworkx [1] as a graph library input and output in addition to NetworkX. This enables users to go from a retworkx PyGraph object to a PyMatching Matching object. The retworkx library is a higher performance general purpose graph library written in Rust that was inspired by NetworkX. While similar to NetworkX it is not a drop in replacement (see [2] for details of the differences). retworkx was originally written to accelerate the graph data structures and algorithms used internally by Qiskit and has since grown to be a high performance general graph library that is suitable for any application. Since it's adoption by Qiskit the library has gained some adoption in several QEC libraries such as flamingpy [3] and qtcodes [4], along with Qiskit's topological codes module in qiskit-ignis (which was recently spun off into a separate repository). [5] Being able to leverage retworkx graphs directly will simplify potentially using PyMatching in any of these packages. For this commit I opted to add retworkx to the requirements list, but it could easily be made an optional dependency if that would be preferrable so that two graph libraries aren't required to install PyMatching. [1] https://github.com/Qiskit/retworkx [2] https://qiskit.org/documentation/retworkx/networkx.html [3] https://github.com/XanaduAI/flamingpy [4] https://github.com/yaleqc/qtcodes [5] https://github.com/NCCR-SPIN/topological_codes --- setup.py | 4 +- src/pymatching/matching.py | 101 +++++++++++++++++++++- tests/test_matching.py | 172 +++++++++++++++++++++++++++++++++++++ 3 files changed, 272 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index cc900e1b..2e0acf48 100644 --- a/setup.py +++ b/setup.py @@ -127,7 +127,7 @@ def build_extension(self, ext): packages=find_packages("src"), package_dir={'':'src'}, cmdclass=dict(build_ext=CMakeBuild), - install_requires=['scipy', 'numpy', 'networkx','matplotlib'], + install_requires=['scipy', 'numpy', 'networkx','retworkx','matplotlib'], classifiers=[ "License :: OSI Approved :: Apache Software License" ], @@ -135,4 +135,4 @@ def build_extension(self, ext): long_description_content_type='text/markdown', python_requires='>=3', zip_safe=False, -) \ No newline at end of file +) diff --git a/src/pymatching/matching.py b/src/pymatching/matching.py index b03def22..ddf4108e 100644 --- a/src/pymatching/matching.py +++ b/src/pymatching/matching.py @@ -18,6 +18,7 @@ import matplotlib.cbook import numpy as np import networkx as nx +import retworkx as rx import scipy from scipy.sparse import csc_matrix @@ -60,7 +61,7 @@ class Matching: fault ids, boundaries and error probabilities. """ def __init__(self, - H: Union[scipy.sparse.spmatrix, np.ndarray, nx.Graph, List[List[int]]] = None, + H: Union[scipy.sparse.spmatrix, np.ndarray, rx.PyGraph, nx.Graph, List[List[int]]] = None, spacelike_weights: Union[float, np.ndarray, List[float]] = None, error_probabilities: Union[float, np.ndarray, List[float]] = None, repetitions: int = None, @@ -156,7 +157,7 @@ def __init__(self, self.matching_graph = MatchingGraph() if H is None: return - if not isinstance(H, nx.Graph): + if not isinstance(H, (nx.Graph, rx.PyGraph)): try: H = csc_matrix(H) except TypeError: @@ -165,8 +166,10 @@ def __init__(self, self.load_from_check_matrix(H, spacelike_weights, error_probabilities, repetitions, timelike_weights, measurement_error_probabilities, **kwargs) - else: + elif isinstance(H, nx.Graph): self.load_from_networkx(H) + else: + self.load_from_retworkx(H) if precompute_shortest_paths: self.matching_graph.compute_all_pairs_shortest_paths() @@ -310,6 +313,76 @@ def load_from_networkx(self, graph: nx.Graph) -> None: g.add_edge(u, v, fault_ids, weight, e_prob, 0 <= e_prob <= 1) self.matching_graph = g + def load_from_retworkx(self, graph: rx.PyGraph) -> None: + r""" + Load a matching graph from a retworkX graph + + Parameters + ---------- + graph : retworkx.PyGraph + Each edge in the retworkx graph can have dictionary payload with keys + ``fault_ids``, ``weight`` and ``error_probability``. ``fault_ids`` should be + an int or a set of ints. Each fault id corresponds to a self-inverse fault + that is flipped when the corresponding edge is flipped. These self-inverse + faults could correspond to physical Pauli errors (physical frame changes) + or to the logical observables that are flipped by the fault + (a logical frame change, equivalent to an obersvable ID in an error instruction in a Stim + detector error model). The `fault_ids` attribute was previously named `qubit_id` in an + earlier version of PyMatching, and `qubit_id` is still accepted instead of `fault_ids` in order + to maintain backward compatibility. + Each ``weight`` attribute should be a non-negative float. If + every edge is assigned an error_probability between zero and one, + then the ``add_noise`` method can be used to simulate noise and + flip edges independently in the graph. + + Examples + -------- + >>> import pymatching + >>> import retworkx as rx + >>> import math + >>> g = rx.PyGraph() + >>> g.add_nodes_from([{} for _ in range(3)]) + >>> g.add_edge(0, 1, dict(fault_ids=0, weight=math.log((1-0.1)/0.1), error_probability=0.1)) + >>> g.add_edge(1, 2, dict(fault_ids=1, weight=math.log((1-0.15)/0.15), error_probability=0.15)) + >>> g[0]['is_boundary'] = True + >>> g[2]['is_boundary'] = True + >>> m = pymatching.Matching(g) + >>> m + + """ + if not isinstance(graph, rx.PyGraph): + raise TypeError("G must be a retworkx graph") + boundary = {i for i in graph.node_indices() if graph[i].get("is_boundary", False)} + num_nodes = len(graph) + all_fault_ids = set() + g = MatchingGraph(self.num_detectors, boundary) + for (u, v, attr) in graph.weighted_edge_list(): + u, v = int(u), int(v) + if "fault_ids" in attr and "qubit_id" in attr: + raise ValueError("Both `fault_ids` and `qubit_id` were provided as edge attributes, however use " + "of `qubit_id` has been deprecated in favour of `fault_ids`. Please only supply " + "`fault_ids` as an edge attribute.") + if "fault_ids" not in attr and "qubit_id" in attr: + fault_ids = attr["qubit_id"] # Still accept qubit_id as well for now + else: + fault_ids = attr.get("fault_ids", set()) + if isinstance(fault_ids, (int, np.integer)): + fault_ids = {int(fault_ids)} if fault_ids != -1 else set() + else: + try: + fault_ids = set(fault_ids) + if not all(isinstance(q, (int, np.integer)) for q in fault_ids): + raise ValueError("fault_ids must be a set of ints, not {}".format(fault_ids)) + except: + raise ValueError( + "fault_ids property must be an int or a set of int"\ + " (or convertible to a set), not {}".format(fault_ids)) + all_fault_ids = all_fault_ids | fault_ids + weight = attr.get("weight", 1) # Default weight is 1 if not provided + e_prob = attr.get("error_probability", -1) + g.add_edge(u, v, fault_ids, weight, e_prob, 0 <= e_prob <= 1) + self.matching_graph = g + def load_from_check_matrix(self, H: Union[scipy.sparse.spmatrix, np.ndarray, List[List[int]]], spacelike_weights: Union[float, np.ndarray, List[float]] = None, @@ -748,6 +821,28 @@ def to_networkx(self) -> nx.Graph: is_boundary = i in boundary G.nodes[i]['is_boundary'] = is_boundary return G + + def to_retworkx(self) -> rx.PyGraph: + """Convert to retworkx graph + + Returns a retworkx graph object corresponding to the matching graph. Each edge + payload is a ``dict`` with keys `fault_ids`, `weight` and `error_probability` and + each node has a ``dict`` payload with the key ``is_boundary`` and the value is + a boolean. + + Returns + ------- + retworkx.PyGraph + retworkx graph corresponding to the matching graph + """ + G = rx.PyGraph(multigraph=False) + G.add_nodes_from([{} for _ in range(self.num_nodes)]) + G.extend_from_weighted_edge_list(self.edges()) + boundary = self.boundary + for i in G.node_indices(): + is_boundary = i in boundary + G[i]['is_boundary'] = is_boundary + return G def draw(self) -> None: """Draw the matching graph using matplotlib diff --git a/tests/test_matching.py b/tests/test_matching.py index 2cf224bc..9350240b 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -17,6 +17,7 @@ from scipy.sparse import csc_matrix, csr_matrix import pytest import networkx as nx +import retworkx as rx import matplotlib.pyplot as plt from pymatching._cpp_mwpm import MatchingGraph @@ -60,6 +61,21 @@ def test_boundary_from_networkx(): assert np.array_equal(m.decode(np.array([0,1,1,0])), np.array([0,0,1,0,0])) assert np.array_equal(m.decode(np.array([0,0,1,0])), np.array([0,0,0,1,1])) +def test_boundary_from_retworkx(): + g = rx.PyGraph() + g.add_nodes_from([{} for _ in range(5)]) + g.add_edge(4,0, dict(fault_ids=0)) + g.add_edge(0,1, dict(fault_ids=1)) + g.add_edge(1,2, dict(fault_ids=2)) + g.add_edge(2,3, dict(fault_ids=3)) + g.add_edge(3,4, dict(fault_ids=4)) + g[4]['is_boundary'] = True + m = Matching(g) + assert m.boundary == {4} + assert np.array_equal(m.decode(np.array([1,0,0,0])), np.array([1,0,0,0,0])) + assert np.array_equal(m.decode(np.array([0,1,0,0])), np.array([1,1,0,0,0])) + assert np.array_equal(m.decode(np.array([0,1,1,0])), np.array([0,0,1,0,0])) + assert np.array_equal(m.decode(np.array([0,0,1,0])), np.array([0,0,0,1,1])) def test_boundaries_from_networkx(): g = nx.Graph() @@ -78,6 +94,23 @@ def test_boundaries_from_networkx(): assert np.array_equal(m.decode(np.array([0,0,1,1,0])), np.array([0,0,1,0,0])) assert np.array_equal(m.decode(np.array([0,0,0,1,0])), np.array([0,0,0,1,1])) +def test_boundaries_from_retworkx(): + g = rx.PyGraph() + g.add_nodes_from([{} for _ in range(6)]) + g.add_edge(0,1, dict(fault_ids=0)) + g.add_edge(1,2, dict(fault_ids=1)) + g.add_edge(2,3, dict(fault_ids=2)) + g.add_edge(3,4, dict(fault_ids=3)) + g.add_edge(4,5, dict(fault_ids=4)) + g.add_edge(0,5, dict(fault_ids=-1, weight=0.0)) + g.nodes()[0]['is_boundary'] = True + g.nodes()[5]['is_boundary'] = True + m = Matching(g) + assert m.boundary == {0, 5} + assert np.array_equal(m.decode(np.array([0,1,0,0,0,0])), np.array([1,0,0,0,0])) + assert np.array_equal(m.decode(np.array([0,0,1,0,0])), np.array([1,1,0,0,0])) + assert np.array_equal(m.decode(np.array([0,0,1,1,0])), np.array([0,0,1,0,0])) + assert np.array_equal(m.decode(np.array([0,0,0,1,0])), np.array([0,0,0,1,1])) def test_nonzero_matrix_elements_not_one_raises_value_error(): H = csr_matrix(np.array([[0,1.01,1.01],[1.01,1.01,0]])) @@ -200,6 +233,79 @@ def test_mwpm_from_networkx(): assert(m.matching_graph.shortest_path(0,2) == [0,2]) +def test_unweighted_stabiliser_graph_from_retworkx(): + w = rx.PyGraph() + w.add_nodes_from([{} for _ in range(6)]) + w.add_edge(0, 1, dict(fault_ids=0, weight=7.0)) + w.add_edge(0, 5, dict(fault_ids=1, weight=14.0)) + w.add_edge(0, 2, dict(fault_ids=2, weight=9.0)) + w.add_edge(1, 2, dict(fault_ids=-1, weight=10.0)) + w.add_edge(1, 3, dict(fault_ids=3, weight=15.0)) + w.add_edge(2, 5, dict(fault_ids=4, weight=2.0)) + w.add_edge(2, 3, dict(fault_ids=-1, weight=11.0)) + w.add_edge(3, 4, dict(fault_ids=5, weight=6.0)) + w.add_edge(4, 5, dict(fault_ids=6, weight=9.0)) + m = Matching(w) + assert(m.num_fault_ids == 7) + assert(m.num_detectors == 6) + assert(m.matching_graph.shortest_path(3, 5) == [3, 2, 5]) + assert(m.matching_graph.distance(5, 0) == pytest.approx(11.0)) + assert(np.array_equal( + m.decode(np.array([1,0,1,0,0,0])), + np.array([0,0,1,0,0,0,0])) + ) + with pytest.raises(ValueError): + m.decode(np.array([1,1,0])) + with pytest.raises(ValueError): + m.decode(np.array([1,1,1,0,0,0])) + assert(np.array_equal( + m.decode(np.array([1,0,0,0,0,1])), + np.array([0,0,1,0,1,0,0])) + ) + assert(np.array_equal( + m.decode(np.array([0,1,0,0,0,1])), + np.array([0,0,0,0,1,0,0])) + ) + + +def test_mwpm_from_retworkx(): + g = rx.PyGraph() + g.add_nodes_from([{} for _ in range(3)]) + g.add_edge(0, 1, dict(fault_ids=0)) + g.add_edge(0, 2, dict(fault_ids=1)) + g.add_edge(1, 2, dict(fault_ids=2)) + m = Matching(g) + assert(isinstance(m.matching_graph, MatchingGraph)) + assert(m.num_detectors == 3) + assert(m.num_fault_ids == 3) + assert(m.matching_graph.distance(0,2) == 1) + assert(m.matching_graph.shortest_path(0,2) == [0,2]) + + g = rx.PyGraph() + g.add_nodes_from([{} for _ in range(3)]) + g.add_edge(0, 1, {}) + g.add_edge(0, 2, {}) + g.add_edge(1, 2, {}) + m = Matching(g) + assert(isinstance(m.matching_graph, MatchingGraph)) + assert(m.num_detectors == 3) + assert(m.num_fault_ids == 0) + assert(m.matching_graph.distance(0,2) == 1) + assert(m.matching_graph.shortest_path(0,2) == [0,2]) + + g = rx.PyGraph() + g.add_nodes_from([{} for _ in range(3)]) + g.add_edge(0, 1, dict(weight=1.5)) + g.add_edge(0, 2, dict(weight=1.7)) + g.add_edge(1, 2, dict(weight=1.2)) + m = Matching(g) + assert(isinstance(m.matching_graph, MatchingGraph)) + assert(m.num_detectors == 3) + assert(m.num_fault_ids == 0) + assert(m.matching_graph.distance(0,2) == pytest.approx(1.7)) + assert(m.matching_graph.shortest_path(0,2) == [0,2]) + + def test_repr(): g = nx.Graph() g.add_edge(0, 1, fault_ids=0) @@ -253,6 +359,47 @@ def test_qubit_id_accepted_via_networkx(): assert es == expected_edges +def test_matching_edges_from_retworkx(): + g = rx.PyGraph() + g.add_nodes_from([{} for _ in range(4)]) + g.add_edge(0, 1, dict(fault_ids=0, weight=1.1, error_probability=0.1)) + g.add_edge(1, 2, dict(fault_ids=1, weight=2.1, error_probability=0.2)) + g.add_edge(2, 3, dict(fault_ids={2,3}, weight=0.9, error_probability=0.3)) + g[0]['is_boundary'] = True + g[3]['is_boundary'] = True + g.add_edge(0, 3, dict(weight=0.0)) + m = Matching(g) + es = list(m.edges()) + expected_edges = [ + (0,1,{'fault_ids': {0}, 'weight': 1.1, 'error_probability': 0.1}), + (1,2,{'fault_ids': {1}, 'weight': 2.1, 'error_probability': 0.2}), + (2,3,{'fault_ids': {2,3}, 'weight': 0.9, 'error_probability': 0.3}), + (0,3,{'fault_ids': set(), 'weight': 0.0, 'error_probability': -1.0}), + ] + print(es) + assert es == expected_edges + + +def test_qubit_id_accepted_via_retworkx(): + g = rx.PyGraph() + g.add_nodes_from([{} for _ in range(4)]) + g.add_edge(0, 1, dict(qubit_id=0, weight=1.1, error_probability=0.1)) + g.add_edge(1, 2, dict(qubit_id=1, weight=2.1, error_probability=0.2)) + g.add_edge(2, 3, dict(qubit_id={2, 3}, weight=0.9, error_probability=0.3)) + g[0]['is_boundary'] = True + g[3]['is_boundary'] = True + g.add_edge(0, 3, dict(weight=0.0)) + m = Matching(g) + es = list(m.edges()) + expected_edges = [ + (0, 1, {'fault_ids': {0}, 'weight': 1.1, 'error_probability': 0.1}), + (1, 2, {'fault_ids': {1}, 'weight': 2.1, 'error_probability': 0.2}), + (2, 3, {'fault_ids': {2, 3}, 'weight': 0.9, 'error_probability': 0.3}), + (0, 3, {'fault_ids': set(), 'weight': 0.0, 'error_probability': -1.0}), + ] + assert es == expected_edges + + def test_qubit_id_accepted_using_add_edge(): m = Matching() m.add_edge(0, 1, qubit_id=0) @@ -304,6 +451,31 @@ def test_matching_to_networkx(): assert sorted(gedges) == sorted(g2edges) +def test_matching_to_retworkx(): + g = rx.PyGraph() + g.add_nodes_from([{} for _ in range(4)]) + g.add_edge(0, 1, dict(fault_ids={0}, weight=1.1, error_probability=0.1)) + g.add_edge(1, 2, dict(fault_ids={1}, weight=2.1, error_probability=0.2)) + g.add_edge(2, 3, dict(fault_ids={2,3}, weight=0.9, error_probability=0.3)) + g[0]['is_boundary'] = True + g[3]['is_boundary'] = True + g.add_edge(0, 3, dict(weight=0.0)) + m = Matching(g) + + edge_0_3 = g.get_edge_data(0, 3) + edge_0_3['fault_ids'] = set() + edge_0_3['error_probability'] = -1.0 + g[1]['is_boundary'] = False + g[2]['is_boundary'] = False + + g2 = m.to_retworkx() + + assert g.node_indices() == g2.node_indices() + gedges = [({s,t},d) for (s, t, d) in g.weighted_edge_list()] + g2edges = [({s,t},d) for (s, t, d) in g.weighted_edge_list()] + assert sorted(gedges) == sorted(g2edges) + + def test_draw_matching(): g = nx.Graph() g.add_edge(0, 1, fault_ids={0}, weight=1.1, error_probability=0.1) From 98bcc3c14007a45d2872bc1bf8a42b7f666ee402 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Sat, 7 May 2022 08:38:25 -0400 Subject: [PATCH 2/4] Fix doc test for load_from_retworkx --- src/pymatching/matching.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pymatching/matching.py b/src/pymatching/matching.py index ddf4108e..8ec141ce 100644 --- a/src/pymatching/matching.py +++ b/src/pymatching/matching.py @@ -341,9 +341,9 @@ def load_from_retworkx(self, graph: rx.PyGraph) -> None: >>> import retworkx as rx >>> import math >>> g = rx.PyGraph() - >>> g.add_nodes_from([{} for _ in range(3)]) - >>> g.add_edge(0, 1, dict(fault_ids=0, weight=math.log((1-0.1)/0.1), error_probability=0.1)) - >>> g.add_edge(1, 2, dict(fault_ids=1, weight=math.log((1-0.15)/0.15), error_probability=0.15)) + >>> matching = g.add_nodes_from([{} for _ in range(3)]) + >>> edge_a =g.add_edge(0, 1, dict(fault_ids=0, weight=math.log((1-0.1)/0.1), error_probability=0.1)) + >>> edge_b = g.add_edge(1, 2, dict(fault_ids=1, weight=math.log((1-0.15)/0.15), error_probability=0.15)) >>> g[0]['is_boundary'] = True >>> g[2]['is_boundary'] = True >>> m = pymatching.Matching(g) From ab027486f1330adc4e3ed4ac5bd0352bd090b602 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Sat, 7 May 2022 08:38:58 -0400 Subject: [PATCH 3/4] Set minimum version for retworkx requirement --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2e0acf48..99d4f52f 100644 --- a/setup.py +++ b/setup.py @@ -127,7 +127,7 @@ def build_extension(self, ext): packages=find_packages("src"), package_dir={'':'src'}, cmdclass=dict(build_ext=CMakeBuild), - install_requires=['scipy', 'numpy', 'networkx','retworkx','matplotlib'], + install_requires=['scipy', 'numpy', 'networkx','retworkx>=0.11.0','matplotlib'], classifiers=[ "License :: OSI Approved :: Apache Software License" ], From 73bfb6f4dc77551a4e52665ce924a79b7741a0fe Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Mon, 9 May 2022 10:43:59 -0400 Subject: [PATCH 4/4] Remove unnecessary all_fault_ids variable --- src/pymatching/matching.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/pymatching/matching.py b/src/pymatching/matching.py index 8ec141ce..b3a5c5e4 100644 --- a/src/pymatching/matching.py +++ b/src/pymatching/matching.py @@ -354,7 +354,6 @@ def load_from_retworkx(self, graph: rx.PyGraph) -> None: raise TypeError("G must be a retworkx graph") boundary = {i for i in graph.node_indices() if graph[i].get("is_boundary", False)} num_nodes = len(graph) - all_fault_ids = set() g = MatchingGraph(self.num_detectors, boundary) for (u, v, attr) in graph.weighted_edge_list(): u, v = int(u), int(v) @@ -377,7 +376,6 @@ def load_from_retworkx(self, graph: rx.PyGraph) -> None: raise ValueError( "fault_ids property must be an int or a set of int"\ " (or convertible to a set), not {}".format(fault_ids)) - all_fault_ids = all_fault_ids | fault_ids weight = attr.get("weight", 1) # Default weight is 1 if not provided e_prob = attr.get("error_probability", -1) g.add_edge(u, v, fault_ids, weight, e_prob, 0 <= e_prob <= 1)