Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions glue/cirq/stimcirq/_stim_to_cirq.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
)
Expand Down Expand Up @@ -64,19 +65,25 @@ def _proper_transform_circuit_qubits(circuit: cirq.AbstractCircuit, remap: Dict[


class CircuitTranslationTracker:
def __init__(self, flatten: bool):
def __init__(self, flatten: bool, single_measure_key: Optional[str] = None):
self.qubit_coords: Dict[int, cirq.Qid] = {}
self.origin: DefaultDict[float] = collections.defaultdict(float)
self.num_measurements_seen = 0
self.full_circuit = cirq.Circuit()
self.tick_circuit = cirq.Circuit()
self.flatten = flatten
self.have_seen_loop = False
self.single_measure_key = single_measure_key

def get_next_measure_id(self) -> int:
self.num_measurements_seen += 1
return self.num_measurements_seen - 1

def get_next_measure_key(self) -> str:
if self.single_measure_key is None:
return str(self.get_next_measure_id())
return self.single_measure_key

def append_operation(self, op: cirq.Operation) -> None:
self.tick_circuit.append(op, strategy=cirq.InsertStrategy.INLINE)

Expand Down Expand Up @@ -186,7 +193,7 @@ def process_measurement_instruction(
for t in targets:
if not t.is_qubit_target:
raise NotImplementedError(f"instruction={instruction!r}")
key = str(self.get_next_measure_id())
key = self.get_next_measure_key()
self.append_operation(
MeasureAndOrResetGate(
measure=measure,
Expand Down Expand Up @@ -248,7 +255,7 @@ def process_mpp(self, instruction: stim.CircuitInstruction) -> None:

obs = _stim_targets_to_dense_pauli_string(group)
qubits = [cirq.LineQubit(t.value) for t in group]
key = str(self.get_next_measure_id())
key = self.get_next_measure_key()
self.append_operation(cirq.PauliMeasurementGate(obs, key=key).on(*qubits).with_tags(*tags))

def process_spp_dag(self, instruction: stim.CircuitInstruction) -> None:
Expand Down Expand Up @@ -290,7 +297,7 @@ def process_m_pair(self, instruction: stim.CircuitInstruction, basis: str) -> No
if targets[0].is_inverted_result_target ^ targets[1].is_inverted_result_target:
obs *= -1
qubits = [cirq.LineQubit(targets[0].value), cirq.LineQubit(targets[1].value)]
key = str(self.get_next_measure_id())
key = self.get_next_measure_key()
self.append_operation(cirq.PauliMeasurementGate(obs, key=key).on(*qubits).with_tags(*tags))

def process_mxx(self, instruction: stim.CircuitInstruction) -> None:
Expand All @@ -309,7 +316,7 @@ def process_mpad(self, instruction: stim.CircuitInstruction) -> None:
if t.value == 1:
obs *= -1
qubits = []
key = str(self.get_next_measure_id())
key = self.get_next_measure_key()
self.append_operation(cirq.PauliMeasurementGate(obs, key=key).on(*qubits))

def process_correlated_error(self, instruction: stim.CircuitInstruction) -> None:
Expand Down Expand Up @@ -632,12 +639,17 @@ def handler(
}


def stim_circuit_to_cirq_circuit(circuit: stim.Circuit, *, flatten: bool = False) -> cirq.Circuit:
def stim_circuit_to_cirq_circuit(
circuit: stim.Circuit,
*,
flatten: bool = False,
single_measure_key: Optional[str] = None,
) -> cirq.Circuit:
"""Converts a stim circuit into an equivalent cirq circuit.

Qubit indices are turned into cirq.LineQubit instances. Measurements are
keyed by their ordering (e.g. the first measurement is keyed "0", the second
is keyed "1", etc).
is keyed "1", etc) unless a fixed measure_key is provided.

Not all circuits can be converted:
- ELSE_CORRELATED_ERROR instructions are not supported.
Expand All @@ -652,6 +664,8 @@ def stim_circuit_to_cirq_circuit(circuit: stim.Circuit, *, flatten: bool = False
explicitly repeating their instructions multiple times. Also,
SHIFT_COORDS instructions are removed by appropriately adjusting the
coordinate metadata of later instructions.
single_measure_key: Defaults to None. If provided, all measurements are
keyed with this string instead of sequentially generated numbers.

Returns:
The converted circuit.
Expand All @@ -671,6 +685,8 @@ def stim_circuit_to_cirq_circuit(circuit: stim.Circuit, *, flatten: bool = False
1: ───────X──────────────────!M('0')───
"""
tracker = CircuitTranslationTracker(flatten=flatten)
tracker = CircuitTranslationTracker(
flatten=flatten, single_measure_key=single_measure_key
)
tracker.process_circuit(repetitions=1, circuit=circuit)
return tracker.output()
48 changes: 47 additions & 1 deletion glue/cirq/stimcirq/_stim_to_cirq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,4 +778,50 @@ def test_round_trip_with_pauli_obs():
""")
cirq_circuit = stimcirq.stim_circuit_to_cirq_circuit(stim_circuit)
restored_circuit = stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit)
assert restored_circuit == stim_circuit
assert restored_circuit == stim_circuit


def test_single_measure_key_order():
stim_circuits = [
stim.Circuit(
"""
X 1
X 1 3
X 1 3
X 1 3 2
M 1
M 3
M 2
M 0
"""
),
stim.Circuit(
"""
X 1
X 1
X 1
X 1
M 1 3
X 2
M 2 0
"""
)
]
measure_key = "m"
for stim_circuit in stim_circuits:
cirq_circuit = stimcirq.stim_circuit_to_cirq_circuit(
stim_circuit, single_measure_key=measure_key
)
qubits = cirq.LineQubit.range(4)
expected_order = [
qubits[targ.qubit_value]
for inst in stim_circuit if inst.name == "M"
for targ in inst.targets_copy()
]
actual_order = []
for op in cirq_circuit.all_operations():
if isinstance(op.gate, cirq.MeasurementGate):
assert op.gate.key == measure_key
assert len(op.qubits) == 1
actual_order.append(op.qubits[0])
assert expected_order == actual_order
Loading