diff --git a/tensorflow_quantum/core/serialize/serializer.py b/tensorflow_quantum/core/serialize/serializer.py index 323e70025..b38f199b3 100644 --- a/tensorflow_quantum/core/serialize/serializer.py +++ b/tensorflow_quantum/core/serialize/serializer.py @@ -30,6 +30,17 @@ _CONSTANT_TRUE = lambda x: True +def _single_qubit_channel_check(x): + """Check that a noise channel operates on exactly one qubit.""" + if len(x.qubits) != 1: + raise ValueError( + "Multi-qubit noise channels are not supported in TFQ. " + f"Got {x} acting on {len(x.qubits)} qubits. " + "Consider decomposing into single-qubit channels applied " + "via .on_each().") + return True + + def _round(x): return np.round(x, 6) if isinstance(x, float) else x @@ -196,7 +207,7 @@ def _asymmetric_depolarize_serializer(): gate_type=cirq.AsymmetricDepolarizingChannel, serialized_gate_id="ADP", args=args, - can_serialize_predicate=_CONSTANT_TRUE) + can_serialize_predicate=_single_qubit_channel_check) def _asymmetric_depolarize_deserializer(): @@ -234,7 +245,7 @@ def _depolarize_channel_serializer(): gate_type=cirq.DepolarizingChannel, serialized_gate_id="DP", args=args, - can_serialize_predicate=_CONSTANT_TRUE) + can_serialize_predicate=_single_qubit_channel_check) def _depolarize_channel_deserializer(): @@ -272,7 +283,7 @@ def _gad_channel_serializer(): gate_type=cirq.GeneralizedAmplitudeDampingChannel, serialized_gate_id="GAD", args=args, - can_serialize_predicate=_CONSTANT_TRUE) + can_serialize_predicate=_single_qubit_channel_check) def _gad_channel_deserializer(): @@ -309,7 +320,7 @@ def _amplitude_damp_channel_serializer(): gate_type=cirq.AmplitudeDampingChannel, serialized_gate_id="AD", args=args, - can_serialize_predicate=_CONSTANT_TRUE) + can_serialize_predicate=_single_qubit_channel_check) def _amplitude_damp_channel_deserializer(): @@ -341,7 +352,7 @@ def _reset_channel_serializer(): gate_type=cirq.ResetChannel, serialized_gate_id="RST", args=args, - can_serialize_predicate=_CONSTANT_TRUE) + can_serialize_predicate=_single_qubit_channel_check) def _reset_channel_deserializer(): @@ -370,7 +381,7 @@ def _phase_damp_channel_serializer(): gate_type=cirq.PhaseDampingChannel, serialized_gate_id="PD", args=args, - can_serialize_predicate=_CONSTANT_TRUE) + can_serialize_predicate=_single_qubit_channel_check) def _phase_damp_channel_deserializer(): @@ -403,7 +414,7 @@ def _phase_flip_channel_serializer(): gate_type=cirq.PhaseFlipChannel, serialized_gate_id="PF", args=args, - can_serialize_predicate=_CONSTANT_TRUE) + can_serialize_predicate=_single_qubit_channel_check) def _phase_flip_channel_deserializer(): @@ -437,7 +448,7 @@ def _bit_flip_channel_serializer(): gate_type=cirq.BitFlipChannel, serialized_gate_id="BF", args=args, - can_serialize_predicate=_CONSTANT_TRUE) + can_serialize_predicate=_single_qubit_channel_check) def _bit_flip_channel_deserializer(): diff --git a/tensorflow_quantum/core/serialize/serializer_test.py b/tensorflow_quantum/core/serialize/serializer_test.py index 3a89a03ea..e0d2e4a36 100644 --- a/tensorflow_quantum/core/serialize/serializer_test.py +++ b/tensorflow_quantum/core/serialize/serializer_test.py @@ -720,6 +720,16 @@ def test_serialize_noise_channel_unsupported_value(self): with self.assertRaises(ValueError): serializer.serialize_circuit(simple_circuit) + def test_serialize_multi_qubit_noise_channel(self): + """Ensure multi-qubit noise channels are rejected with clear error.""" + q0 = cirq.GridQubit(0, 0) + q1 = cirq.GridQubit(0, 1) + multi_qubit_circuit = cirq.Circuit( + cirq.DepolarizingChannel(p=0.1, n_qubits=2)(q0, q1)) + with self.assertRaisesRegex(ValueError, + "Multi-qubit noise channels"): + serializer.serialize_circuit(multi_qubit_circuit) + @parameterized.parameters([{'inp': v} for v in ['wrong', 1.0, None, []]]) def test_serialize_circuit_wrong_type(self, inp): """Attempt to serialize invalid objects types."""