Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions tensorflow_quantum/core/serialize/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@
_CONSTANT_TRUE = lambda x: True


def _single_qubit_channel_check(x):
"""Check that a noise channel operates on exactly one qubit."""
if len(x.qubits) != 1:
raise ValueError(
"Multi-qubit noise channels are not supported in TFQ. "
f"Got {x} acting on {len(x.qubits)} qubits. "
"Consider decomposing into single-qubit channels applied "
"via .on_each().")
return True


def _round(x):
return np.round(x, 6) if isinstance(x, float) else x

Expand Down Expand Up @@ -196,7 +207,7 @@ def _asymmetric_depolarize_serializer():
gate_type=cirq.AsymmetricDepolarizingChannel,
serialized_gate_id="ADP",
args=args,
can_serialize_predicate=_CONSTANT_TRUE)
can_serialize_predicate=_single_qubit_channel_check)


def _asymmetric_depolarize_deserializer():
Expand Down Expand Up @@ -234,7 +245,7 @@ def _depolarize_channel_serializer():
gate_type=cirq.DepolarizingChannel,
serialized_gate_id="DP",
args=args,
can_serialize_predicate=_CONSTANT_TRUE)
can_serialize_predicate=_single_qubit_channel_check)


def _depolarize_channel_deserializer():
Expand Down Expand Up @@ -272,7 +283,7 @@ def _gad_channel_serializer():
gate_type=cirq.GeneralizedAmplitudeDampingChannel,
serialized_gate_id="GAD",
args=args,
can_serialize_predicate=_CONSTANT_TRUE)
can_serialize_predicate=_single_qubit_channel_check)


def _gad_channel_deserializer():
Expand Down Expand Up @@ -309,7 +320,7 @@ def _amplitude_damp_channel_serializer():
gate_type=cirq.AmplitudeDampingChannel,
serialized_gate_id="AD",
args=args,
can_serialize_predicate=_CONSTANT_TRUE)
can_serialize_predicate=_single_qubit_channel_check)


def _amplitude_damp_channel_deserializer():
Expand Down Expand Up @@ -341,7 +352,7 @@ def _reset_channel_serializer():
gate_type=cirq.ResetChannel,
serialized_gate_id="RST",
args=args,
can_serialize_predicate=_CONSTANT_TRUE)
can_serialize_predicate=_single_qubit_channel_check)


def _reset_channel_deserializer():
Expand Down Expand Up @@ -370,7 +381,7 @@ def _phase_damp_channel_serializer():
gate_type=cirq.PhaseDampingChannel,
serialized_gate_id="PD",
args=args,
can_serialize_predicate=_CONSTANT_TRUE)
can_serialize_predicate=_single_qubit_channel_check)


def _phase_damp_channel_deserializer():
Expand Down Expand Up @@ -403,7 +414,7 @@ def _phase_flip_channel_serializer():
gate_type=cirq.PhaseFlipChannel,
serialized_gate_id="PF",
args=args,
can_serialize_predicate=_CONSTANT_TRUE)
can_serialize_predicate=_single_qubit_channel_check)


def _phase_flip_channel_deserializer():
Expand Down Expand Up @@ -437,7 +448,7 @@ def _bit_flip_channel_serializer():
gate_type=cirq.BitFlipChannel,
serialized_gate_id="BF",
args=args,
can_serialize_predicate=_CONSTANT_TRUE)
can_serialize_predicate=_single_qubit_channel_check)


def _bit_flip_channel_deserializer():
Expand Down
10 changes: 10 additions & 0 deletions tensorflow_quantum/core/serialize/serializer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,16 @@ def test_serialize_noise_channel_unsupported_value(self):
with self.assertRaises(ValueError):
serializer.serialize_circuit(simple_circuit)

def test_serialize_multi_qubit_noise_channel(self):
"""Ensure multi-qubit noise channels are rejected with clear error."""
q0 = cirq.GridQubit(0, 0)
q1 = cirq.GridQubit(0, 1)
multi_qubit_circuit = cirq.Circuit(
cirq.DepolarizingChannel(p=0.1, n_qubits=2)(q0, q1))
with self.assertRaisesRegex(ValueError,
"Multi-qubit noise channels"):
serializer.serialize_circuit(multi_qubit_circuit)
Comment on lines +723 to +731

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Great job on adding a test for this new validation! To make it more robust, I suggest expanding it to also test cirq.asymmetric_depolarize, which was the motivating example from the issue description. This will ensure that channels defined by error_probabilities with multi-qubit Pauli strings are also correctly rejected.

Additionally, the n_qubits argument in cirq.DepolarizingChannel is deprecated. It's better to omit it and let cirq infer the number of qubits from the on() call.

Suggested change
def test_serialize_multi_qubit_noise_channel(self):
"""Ensure multi-qubit noise channels are rejected with clear error."""
q0 = cirq.GridQubit(0, 0)
q1 = cirq.GridQubit(0, 1)
multi_qubit_circuit = cirq.Circuit(
cirq.DepolarizingChannel(p=0.1, n_qubits=2)(q0, q1))
with self.assertRaisesRegex(ValueError,
"Multi-qubit noise channels"):
serializer.serialize_circuit(multi_qubit_circuit)
def test_serialize_multi_qubit_noise_channel(self):
"""Ensure multi-qubit noise channels are rejected with clear error."""
q0 = cirq.GridQubit(0, 0)
q1 = cirq.GridQubit(0, 1)
depol_circuit = cirq.Circuit(
cirq.DepolarizingChannel(p=0.1)(q0, q1))
with self.assertRaisesRegex(ValueError, "Multi-qubit noise channels"):
serializer.serialize_circuit(depol_circuit)
asym_depol_circuit = cirq.Circuit(
cirq.asymmetric_depolarize(error_probabilities={'XX': 0.1})(q0, q1))
with self.assertRaisesRegex(ValueError, "Multi-qubit noise channels"):
serializer.serialize_circuit(asym_depol_circuit)


@parameterized.parameters([{'inp': v} for v in ['wrong', 1.0, None, []]])
def test_serialize_circuit_wrong_type(self, inp):
"""Attempt to serialize invalid objects types."""
Expand Down
Loading