Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions glue/sample/src/sinter/_decoding/_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,7 @@ def sample_decode(*,
were executed. The detection fraction is the ratio of these two
numbers.
num_shots: The number of sample shots to take from the circuit.
decoder: The name of the decoder to use. Allowed values are:
"pymatching":
Use pymatching min-weight-perfect-match decoder.
"internal":
Use internal decoder with uncorrelated decoding.
"internal_correlated":
Use internal decoder with correlated decoding.
decoder: The name of the decoder to use. For example, 'pymatching'.
tmp_dir: An existing directory that is currently empty where temporary
files can be written as part of performing decoding. If set to
None, one is created using the tempfile package.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
BUILT_IN_DECODERS: Dict[str, Decoder] = {
'vacuous': VacuousDecoder(),
'pymatching': PyMatchingDecoder(),
'pymatching-correlated': PyMatchingDecoder(use_correlated_decoding=True),
'fusion_blossom': FusionBlossomDecoder(),
# an implementation of (weighted) hypergraph UF decoder (https://arxiv.org/abs/2103.08049)
'hypergraph_union_find': HyperUFDecoder(),
Expand Down
49 changes: 45 additions & 4 deletions glue/sample/src/sinter/_decoding/_decoding_pymatching.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,53 @@
from sinter._decoding._decoding_decoder_class import Decoder, CompiledDecoder


def check_pymatching_version_for_correlated_decoding(pymatching):
v = pymatching.__version__.split('.')
try:
a = int(v[0])
b = int(v[1]
c = int(''.join(e for e in v[2] if e in '0123456789')) # In case dev version
except ValueError, IndexError:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... I made a syntax error

return # Probably it's the future.

if (a, b, c) < (2, 3, 1):
if not (int(v[0]) >= 2 or int(v[1]) >= 3 or int(v[2]) >= 1):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drop redundant if here

raise ValueError(
"Pymatching version must be at least 2.3.1 for correlated decoding.\n"
f"Installed version: {pymatching.__version__}\n"
"To fix this, install a newer version of pymatching into your environment.\n"
"For example, if you are using pip, run `pip install pymatching --upgrade`.\n"
)


class PyMatchingCompiledDecoder(CompiledDecoder):
def __init__(self, matcher: 'pymatching.Matching'):
def __init__(self, matcher: 'pymatching.Matching', use_correlated_decoding: bool):
self.matcher = matcher
self.use_correlated_decoding = use_correlated_decoding

def decode_shots_bit_packed(
self,
*,
bit_packed_detection_event_data: 'np.ndarray',
) -> 'np.ndarray':
kwargs = {}
if self.use_correlated_decoding:
kwargs['enable_correlations'] = True
return self.matcher.decode_batch(
shots=bit_packed_detection_event_data,
bit_packed_shots=True,
bit_packed_predictions=True,
return_weights=False,
**kwargs,
)


class PyMatchingDecoder(Decoder):
"""Use pymatching to predict observables from detection events."""

def __init__(self, use_correlated_decoding: bool = False):
self.use_correlated_decoding = use_correlated_decoding

def compile_decoder_for_dem(self, *, dem: 'stim.DetectorErrorModel') -> CompiledDecoder:
try:
import pymatching
Expand All @@ -31,7 +58,14 @@ def compile_decoder_for_dem(self, *, dem: 'stim.DetectorErrorModel') -> Compiled
"For example, if you are using pip, run `pip install pymatching`.\n"
) from ex

return PyMatchingCompiledDecoder(pymatching.Matching.from_detector_error_model(dem))
kwargs = {}
if self.use_correlated_decoding:
check_pymatching_version_for_correlated_decoding(pymatching)
kwargs['enable_correlations'] = True
return PyMatchingCompiledDecoder(
pymatching.Matching.from_detector_error_model(dem, **kwargs),
use_correlated_decoding=self.use_correlated_decoding,
)

def decode_via_files(self,
*,
Expand Down Expand Up @@ -60,7 +94,9 @@ def decode_via_files(self,
if not hasattr(pymatching, 'cli'):
raise ValueError("""
The installed version of pymatching has no `pymatching.cli` method.

sinter requires pymatching 2.1.0 or later.

If you're using pip to install packages, this can be fixed by running

```
Expand All @@ -69,13 +105,18 @@ def decode_via_files(self,

""")

result = pymatching.cli(command_line_args=[
args = [
"predict",
"--dem", str(dem_path),
"--in", str(dets_b8_in_path),
"--in_format", "b8",
"--out", str(obs_predictions_b8_out_path),
"--out_format", "b8",
])
]
if self.use_correlated_decoding:
check_pymatching_version_for_correlated_decoding(pymatching)
args.append("--enable_correlations")

result = pymatching.cli(command_line_args=args)
if result:
raise ValueError("pymatching.cli returned a non-zero exit code")
4 changes: 3 additions & 1 deletion glue/sample/src/sinter/_decoding/_decoding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ def test_no_detectors_with_post_mask(decoder: str, force_streaming: Optional[boo

@pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES)
def test_post_selection(decoder: str, force_streaming: Optional[bool]):
if decoder == 'pymatching-correlated':
pytest.skip("Correlated matching does not support error probabilities > 0.5 in from_detector_error_model")
circuit = stim.Circuit("""
X_ERROR(0.6) 0
M 0
Expand All @@ -243,7 +245,7 @@ def test_post_selection(decoder: str, force_streaming: Optional[bool]):
M 1
DETECTOR(1, 0, 0) rec[-1]
OBSERVABLE_INCLUDE(0) rec[-1]

X_ERROR(0.1) 2
M 2
OBSERVABLE_INCLUDE(0) rec[-1]
Expand Down
Loading