17
17
from scipy .sparse import csc_matrix , csr_matrix
18
18
import pytest
19
19
import networkx as nx
20
+ import retworkx as rx
20
21
import matplotlib .pyplot as plt
21
22
22
23
from pymatching ._cpp_mwpm import MatchingGraph
@@ -60,6 +61,21 @@ def test_boundary_from_networkx():
60
61
assert np .array_equal (m .decode (np .array ([0 ,1 ,1 ,0 ])), np .array ([0 ,0 ,1 ,0 ,0 ]))
61
62
assert np .array_equal (m .decode (np .array ([0 ,0 ,1 ,0 ])), np .array ([0 ,0 ,0 ,1 ,1 ]))
62
63
64
+ def test_boundary_from_retworkx ():
65
+ g = rx .PyGraph ()
66
+ g .add_nodes_from ([{} for _ in range (5 )])
67
+ g .add_edge (4 ,0 , dict (fault_ids = 0 ))
68
+ g .add_edge (0 ,1 , dict (fault_ids = 1 ))
69
+ g .add_edge (1 ,2 , dict (fault_ids = 2 ))
70
+ g .add_edge (2 ,3 , dict (fault_ids = 3 ))
71
+ g .add_edge (3 ,4 , dict (fault_ids = 4 ))
72
+ g [4 ]['is_boundary' ] = True
73
+ m = Matching (g )
74
+ assert m .boundary == {4 }
75
+ assert np .array_equal (m .decode (np .array ([1 ,0 ,0 ,0 ])), np .array ([1 ,0 ,0 ,0 ,0 ]))
76
+ assert np .array_equal (m .decode (np .array ([0 ,1 ,0 ,0 ])), np .array ([1 ,1 ,0 ,0 ,0 ]))
77
+ assert np .array_equal (m .decode (np .array ([0 ,1 ,1 ,0 ])), np .array ([0 ,0 ,1 ,0 ,0 ]))
78
+ assert np .array_equal (m .decode (np .array ([0 ,0 ,1 ,0 ])), np .array ([0 ,0 ,0 ,1 ,1 ]))
63
79
64
80
def test_boundaries_from_networkx ():
65
81
g = nx .Graph ()
@@ -78,6 +94,23 @@ def test_boundaries_from_networkx():
78
94
assert np .array_equal (m .decode (np .array ([0 ,0 ,1 ,1 ,0 ])), np .array ([0 ,0 ,1 ,0 ,0 ]))
79
95
assert np .array_equal (m .decode (np .array ([0 ,0 ,0 ,1 ,0 ])), np .array ([0 ,0 ,0 ,1 ,1 ]))
80
96
97
+ def test_boundaries_from_retworkx ():
98
+ g = rx .PyGraph ()
99
+ g .add_nodes_from ([{} for _ in range (6 )])
100
+ g .add_edge (0 ,1 , dict (fault_ids = 0 ))
101
+ g .add_edge (1 ,2 , dict (fault_ids = 1 ))
102
+ g .add_edge (2 ,3 , dict (fault_ids = 2 ))
103
+ g .add_edge (3 ,4 , dict (fault_ids = 3 ))
104
+ g .add_edge (4 ,5 , dict (fault_ids = 4 ))
105
+ g .add_edge (0 ,5 , dict (fault_ids = - 1 , weight = 0.0 ))
106
+ g .nodes ()[0 ]['is_boundary' ] = True
107
+ g .nodes ()[5 ]['is_boundary' ] = True
108
+ m = Matching (g )
109
+ assert m .boundary == {0 , 5 }
110
+ assert np .array_equal (m .decode (np .array ([0 ,1 ,0 ,0 ,0 ,0 ])), np .array ([1 ,0 ,0 ,0 ,0 ]))
111
+ assert np .array_equal (m .decode (np .array ([0 ,0 ,1 ,0 ,0 ])), np .array ([1 ,1 ,0 ,0 ,0 ]))
112
+ assert np .array_equal (m .decode (np .array ([0 ,0 ,1 ,1 ,0 ])), np .array ([0 ,0 ,1 ,0 ,0 ]))
113
+ assert np .array_equal (m .decode (np .array ([0 ,0 ,0 ,1 ,0 ])), np .array ([0 ,0 ,0 ,1 ,1 ]))
81
114
82
115
def test_nonzero_matrix_elements_not_one_raises_value_error ():
83
116
H = csr_matrix (np .array ([[0 ,1.01 ,1.01 ],[1.01 ,1.01 ,0 ]]))
@@ -200,6 +233,79 @@ def test_mwpm_from_networkx():
200
233
assert (m .matching_graph .shortest_path (0 ,2 ) == [0 ,2 ])
201
234
202
235
236
+ def test_unweighted_stabiliser_graph_from_retworkx ():
237
+ w = rx .PyGraph ()
238
+ w .add_nodes_from ([{} for _ in range (6 )])
239
+ w .add_edge (0 , 1 , dict (fault_ids = 0 , weight = 7.0 ))
240
+ w .add_edge (0 , 5 , dict (fault_ids = 1 , weight = 14.0 ))
241
+ w .add_edge (0 , 2 , dict (fault_ids = 2 , weight = 9.0 ))
242
+ w .add_edge (1 , 2 , dict (fault_ids = - 1 , weight = 10.0 ))
243
+ w .add_edge (1 , 3 , dict (fault_ids = 3 , weight = 15.0 ))
244
+ w .add_edge (2 , 5 , dict (fault_ids = 4 , weight = 2.0 ))
245
+ w .add_edge (2 , 3 , dict (fault_ids = - 1 , weight = 11.0 ))
246
+ w .add_edge (3 , 4 , dict (fault_ids = 5 , weight = 6.0 ))
247
+ w .add_edge (4 , 5 , dict (fault_ids = 6 , weight = 9.0 ))
248
+ m = Matching (w )
249
+ assert (m .num_fault_ids == 7 )
250
+ assert (m .num_detectors == 6 )
251
+ assert (m .matching_graph .shortest_path (3 , 5 ) == [3 , 2 , 5 ])
252
+ assert (m .matching_graph .distance (5 , 0 ) == pytest .approx (11.0 ))
253
+ assert (np .array_equal (
254
+ m .decode (np .array ([1 ,0 ,1 ,0 ,0 ,0 ])),
255
+ np .array ([0 ,0 ,1 ,0 ,0 ,0 ,0 ]))
256
+ )
257
+ with pytest .raises (ValueError ):
258
+ m .decode (np .array ([1 ,1 ,0 ]))
259
+ with pytest .raises (ValueError ):
260
+ m .decode (np .array ([1 ,1 ,1 ,0 ,0 ,0 ]))
261
+ assert (np .array_equal (
262
+ m .decode (np .array ([1 ,0 ,0 ,0 ,0 ,1 ])),
263
+ np .array ([0 ,0 ,1 ,0 ,1 ,0 ,0 ]))
264
+ )
265
+ assert (np .array_equal (
266
+ m .decode (np .array ([0 ,1 ,0 ,0 ,0 ,1 ])),
267
+ np .array ([0 ,0 ,0 ,0 ,1 ,0 ,0 ]))
268
+ )
269
+
270
+
271
+ def test_mwpm_from_retworkx ():
272
+ g = rx .PyGraph ()
273
+ g .add_nodes_from ([{} for _ in range (3 )])
274
+ g .add_edge (0 , 1 , dict (fault_ids = 0 ))
275
+ g .add_edge (0 , 2 , dict (fault_ids = 1 ))
276
+ g .add_edge (1 , 2 , dict (fault_ids = 2 ))
277
+ m = Matching (g )
278
+ assert (isinstance (m .matching_graph , MatchingGraph ))
279
+ assert (m .num_detectors == 3 )
280
+ assert (m .num_fault_ids == 3 )
281
+ assert (m .matching_graph .distance (0 ,2 ) == 1 )
282
+ assert (m .matching_graph .shortest_path (0 ,2 ) == [0 ,2 ])
283
+
284
+ g = rx .PyGraph ()
285
+ g .add_nodes_from ([{} for _ in range (3 )])
286
+ g .add_edge (0 , 1 , {})
287
+ g .add_edge (0 , 2 , {})
288
+ g .add_edge (1 , 2 , {})
289
+ m = Matching (g )
290
+ assert (isinstance (m .matching_graph , MatchingGraph ))
291
+ assert (m .num_detectors == 3 )
292
+ assert (m .num_fault_ids == 0 )
293
+ assert (m .matching_graph .distance (0 ,2 ) == 1 )
294
+ assert (m .matching_graph .shortest_path (0 ,2 ) == [0 ,2 ])
295
+
296
+ g = rx .PyGraph ()
297
+ g .add_nodes_from ([{} for _ in range (3 )])
298
+ g .add_edge (0 , 1 , dict (weight = 1.5 ))
299
+ g .add_edge (0 , 2 , dict (weight = 1.7 ))
300
+ g .add_edge (1 , 2 , dict (weight = 1.2 ))
301
+ m = Matching (g )
302
+ assert (isinstance (m .matching_graph , MatchingGraph ))
303
+ assert (m .num_detectors == 3 )
304
+ assert (m .num_fault_ids == 0 )
305
+ assert (m .matching_graph .distance (0 ,2 ) == pytest .approx (1.7 ))
306
+ assert (m .matching_graph .shortest_path (0 ,2 ) == [0 ,2 ])
307
+
308
+
203
309
def test_repr ():
204
310
g = nx .Graph ()
205
311
g .add_edge (0 , 1 , fault_ids = 0 )
@@ -253,6 +359,47 @@ def test_qubit_id_accepted_via_networkx():
253
359
assert es == expected_edges
254
360
255
361
362
+ def test_matching_edges_from_retworkx ():
363
+ g = rx .PyGraph ()
364
+ g .add_nodes_from ([{} for _ in range (4 )])
365
+ g .add_edge (0 , 1 , dict (fault_ids = 0 , weight = 1.1 , error_probability = 0.1 ))
366
+ g .add_edge (1 , 2 , dict (fault_ids = 1 , weight = 2.1 , error_probability = 0.2 ))
367
+ g .add_edge (2 , 3 , dict (fault_ids = {2 ,3 }, weight = 0.9 , error_probability = 0.3 ))
368
+ g [0 ]['is_boundary' ] = True
369
+ g [3 ]['is_boundary' ] = True
370
+ g .add_edge (0 , 3 , dict (weight = 0.0 ))
371
+ m = Matching (g )
372
+ es = list (m .edges ())
373
+ expected_edges = [
374
+ (0 ,1 ,{'fault_ids' : {0 }, 'weight' : 1.1 , 'error_probability' : 0.1 }),
375
+ (1 ,2 ,{'fault_ids' : {1 }, 'weight' : 2.1 , 'error_probability' : 0.2 }),
376
+ (2 ,3 ,{'fault_ids' : {2 ,3 }, 'weight' : 0.9 , 'error_probability' : 0.3 }),
377
+ (0 ,3 ,{'fault_ids' : set (), 'weight' : 0.0 , 'error_probability' : - 1.0 }),
378
+ ]
379
+ print (es )
380
+ assert es == expected_edges
381
+
382
+
383
+ def test_qubit_id_accepted_via_retworkx ():
384
+ g = rx .PyGraph ()
385
+ g .add_nodes_from ([{} for _ in range (4 )])
386
+ g .add_edge (0 , 1 , dict (qubit_id = 0 , weight = 1.1 , error_probability = 0.1 ))
387
+ g .add_edge (1 , 2 , dict (qubit_id = 1 , weight = 2.1 , error_probability = 0.2 ))
388
+ g .add_edge (2 , 3 , dict (qubit_id = {2 , 3 }, weight = 0.9 , error_probability = 0.3 ))
389
+ g [0 ]['is_boundary' ] = True
390
+ g [3 ]['is_boundary' ] = True
391
+ g .add_edge (0 , 3 , dict (weight = 0.0 ))
392
+ m = Matching (g )
393
+ es = list (m .edges ())
394
+ expected_edges = [
395
+ (0 , 1 , {'fault_ids' : {0 }, 'weight' : 1.1 , 'error_probability' : 0.1 }),
396
+ (1 , 2 , {'fault_ids' : {1 }, 'weight' : 2.1 , 'error_probability' : 0.2 }),
397
+ (2 , 3 , {'fault_ids' : {2 , 3 }, 'weight' : 0.9 , 'error_probability' : 0.3 }),
398
+ (0 , 3 , {'fault_ids' : set (), 'weight' : 0.0 , 'error_probability' : - 1.0 }),
399
+ ]
400
+ assert es == expected_edges
401
+
402
+
256
403
def test_qubit_id_accepted_using_add_edge ():
257
404
m = Matching ()
258
405
m .add_edge (0 , 1 , qubit_id = 0 )
@@ -304,6 +451,31 @@ def test_matching_to_networkx():
304
451
assert sorted (gedges ) == sorted (g2edges )
305
452
306
453
454
+ def test_matching_to_retworkx ():
455
+ g = rx .PyGraph ()
456
+ g .add_nodes_from ([{} for _ in range (4 )])
457
+ g .add_edge (0 , 1 , dict (fault_ids = {0 }, weight = 1.1 , error_probability = 0.1 ))
458
+ g .add_edge (1 , 2 , dict (fault_ids = {1 }, weight = 2.1 , error_probability = 0.2 ))
459
+ g .add_edge (2 , 3 , dict (fault_ids = {2 ,3 }, weight = 0.9 , error_probability = 0.3 ))
460
+ g [0 ]['is_boundary' ] = True
461
+ g [3 ]['is_boundary' ] = True
462
+ g .add_edge (0 , 3 , dict (weight = 0.0 ))
463
+ m = Matching (g )
464
+
465
+ edge_0_3 = g .get_edge_data (0 , 3 )
466
+ edge_0_3 ['fault_ids' ] = set ()
467
+ edge_0_3 ['error_probability' ] = - 1.0
468
+ g [1 ]['is_boundary' ] = False
469
+ g [2 ]['is_boundary' ] = False
470
+
471
+ g2 = m .to_retworkx ()
472
+
473
+ assert g .node_indices () == g2 .node_indices ()
474
+ gedges = [({s ,t },d ) for (s , t , d ) in g .weighted_edge_list ()]
475
+ g2edges = [({s ,t },d ) for (s , t , d ) in g .weighted_edge_list ()]
476
+ assert sorted (gedges ) == sorted (g2edges )
477
+
478
+
307
479
def test_draw_matching ():
308
480
g = nx .Graph ()
309
481
g .add_edge (0 , 1 , fault_ids = {0 }, weight = 1.1 , error_probability = 0.1 )
0 commit comments