diff --git a/glue/sample/src/sinter/_decoding/_decoding.py b/glue/sample/src/sinter/_decoding/_decoding.py index 1e54f87ef..e45aef72b 100644 --- a/glue/sample/src/sinter/_decoding/_decoding.py +++ b/glue/sample/src/sinter/_decoding/_decoding.py @@ -177,13 +177,7 @@ def sample_decode(*, were executed. The detection fraction is the ratio of these two numbers. num_shots: The number of sample shots to take from the circuit. - decoder: The name of the decoder to use. Allowed values are: - "pymatching": - Use pymatching min-weight-perfect-match decoder. - "internal": - Use internal decoder with uncorrelated decoding. - "internal_correlated": - Use internal decoder with correlated decoding. + decoder: The name of the decoder to use. For example, 'pymatching'. tmp_dir: An existing directory that is currently empty where temporary files can be written as part of performing decoding. If set to None, one is created using the tempfile package. diff --git a/glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py b/glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py index 92d8d49dd..93ffa584f 100644 --- a/glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py +++ b/glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py @@ -12,6 +12,7 @@ BUILT_IN_DECODERS: Dict[str, Decoder] = { 'vacuous': VacuousDecoder(), 'pymatching': PyMatchingDecoder(), + 'pymatching-correlated': PyMatchingDecoder(use_correlated_decoding=True), 'fusion_blossom': FusionBlossomDecoder(), # an implementation of (weighted) hypergraph UF decoder (https://arxiv.org/abs/2103.08049) 'hypergraph_union_find': HyperUFDecoder(), diff --git a/glue/sample/src/sinter/_decoding/_decoding_pymatching.py b/glue/sample/src/sinter/_decoding/_decoding_pymatching.py index b57bb32bc..ce933a523 100644 --- a/glue/sample/src/sinter/_decoding/_decoding_pymatching.py +++ b/glue/sample/src/sinter/_decoding/_decoding_pymatching.py @@ -1,26 +1,53 @@ from sinter._decoding._decoding_decoder_class import Decoder, CompiledDecoder +def check_pymatching_version_for_correlated_decoding(pymatching): + v = pymatching.__version__.split('.') + try: + a = int(v[0]) + b = int(v[1] + c = int(''.join(e for e in v[2] if e in '0123456789')) # In case dev version + except ValueError, IndexError: + return # Probably it's the future. + + if (a, b, c) < (2, 3, 1): + if not (int(v[0]) >= 2 or int(v[1]) >= 3 or int(v[2]) >= 1): + raise ValueError( + "Pymatching version must be at least 2.3.1 for correlated decoding.\n" + f"Installed version: {pymatching.__version__}\n" + "To fix this, install a newer version of pymatching into your environment.\n" + "For example, if you are using pip, run `pip install pymatching --upgrade`.\n" + ) + + class PyMatchingCompiledDecoder(CompiledDecoder): - def __init__(self, matcher: 'pymatching.Matching'): + def __init__(self, matcher: 'pymatching.Matching', use_correlated_decoding: bool): self.matcher = matcher + self.use_correlated_decoding = use_correlated_decoding def decode_shots_bit_packed( self, *, bit_packed_detection_event_data: 'np.ndarray', ) -> 'np.ndarray': + kwargs = {} + if self.use_correlated_decoding: + kwargs['enable_correlations'] = True return self.matcher.decode_batch( shots=bit_packed_detection_event_data, bit_packed_shots=True, bit_packed_predictions=True, return_weights=False, + **kwargs, ) class PyMatchingDecoder(Decoder): """Use pymatching to predict observables from detection events.""" + def __init__(self, use_correlated_decoding: bool = False): + self.use_correlated_decoding = use_correlated_decoding + def compile_decoder_for_dem(self, *, dem: 'stim.DetectorErrorModel') -> CompiledDecoder: try: import pymatching @@ -31,7 +58,14 @@ def compile_decoder_for_dem(self, *, dem: 'stim.DetectorErrorModel') -> Compiled "For example, if you are using pip, run `pip install pymatching`.\n" ) from ex - return PyMatchingCompiledDecoder(pymatching.Matching.from_detector_error_model(dem)) + kwargs = {} + if self.use_correlated_decoding: + check_pymatching_version_for_correlated_decoding(pymatching) + kwargs['enable_correlations'] = True + return PyMatchingCompiledDecoder( + pymatching.Matching.from_detector_error_model(dem, **kwargs), + use_correlated_decoding=self.use_correlated_decoding, + ) def decode_via_files(self, *, @@ -60,7 +94,9 @@ def decode_via_files(self, if not hasattr(pymatching, 'cli'): raise ValueError(""" The installed version of pymatching has no `pymatching.cli` method. + sinter requires pymatching 2.1.0 or later. + If you're using pip to install packages, this can be fixed by running ``` @@ -69,13 +105,18 @@ def decode_via_files(self, """) - result = pymatching.cli(command_line_args=[ + args = [ "predict", "--dem", str(dem_path), "--in", str(dets_b8_in_path), "--in_format", "b8", "--out", str(obs_predictions_b8_out_path), "--out_format", "b8", - ]) + ] + if self.use_correlated_decoding: + check_pymatching_version_for_correlated_decoding(pymatching) + args.append("--enable_correlations") + + result = pymatching.cli(command_line_args=args) if result: raise ValueError("pymatching.cli returned a non-zero exit code") diff --git a/glue/sample/src/sinter/_decoding/_decoding_test.py b/glue/sample/src/sinter/_decoding/_decoding_test.py index cd4e28d0d..7dd08f379 100644 --- a/glue/sample/src/sinter/_decoding/_decoding_test.py +++ b/glue/sample/src/sinter/_decoding/_decoding_test.py @@ -233,6 +233,8 @@ def test_no_detectors_with_post_mask(decoder: str, force_streaming: Optional[boo @pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES) def test_post_selection(decoder: str, force_streaming: Optional[bool]): + if decoder == 'pymatching-correlated': + pytest.skip("Correlated matching does not support error probabilities > 0.5 in from_detector_error_model") circuit = stim.Circuit(""" X_ERROR(0.6) 0 M 0 @@ -243,7 +245,7 @@ def test_post_selection(decoder: str, force_streaming: Optional[bool]): M 1 DETECTOR(1, 0, 0) rec[-1] OBSERVABLE_INCLUDE(0) rec[-1] - + X_ERROR(0.1) 2 M 2 OBSERVABLE_INCLUDE(0) rec[-1]