diff --git a/qualtran/bloqs/multiplexers/selected_majorana_fermion.py b/qualtran/bloqs/multiplexers/selected_majorana_fermion.py index fd5a94952e..ba593ce608 100644 --- a/qualtran/bloqs/multiplexers/selected_majorana_fermion.py +++ b/qualtran/bloqs/multiplexers/selected_majorana_fermion.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import cached_property -from typing import Iterator, Sequence, Tuple, Union +from typing import Dict, Iterator, Sequence, Tuple, Union import attrs import cirq @@ -25,6 +25,7 @@ from qualtran._infra.data_types import BQUInt from qualtran._infra.gate_with_registers import total_bits from qualtran.bloqs.multiplexers.unary_iteration_bloq import UnaryIterationGate +from qualtran.simulation.classical_sim import ClassicalValT @attrs.frozen @@ -137,5 +138,43 @@ def nth_operation( # type: ignore[override] yield self.target_gate(target[target_idx]).controlled_by(control) yield cirq.CZ(*accumulator, target[target_idx]) + def on_classical_vals(self, **vals) -> Dict[str, 'ClassicalValT']: + if self.target_gate != cirq.X and self.target_gate != cirq.Z: + return NotImplemented + if len(self.control_registers) > 1 or len(self.selection_registers) > 1: + return NotImplemented + control_name = self.control_registers[0].name + control = vals[control_name] + selection_name = self.selection_registers[0].name + selection = vals[selection_name] + target = vals['target'] + + # When target_gate == cirq.X, the action is (modulo phase) a single bitflip. + if control and self.target_gate == cirq.X: + max_selection = self.selection_registers[0].dtype.iteration_length_or_zero() - 1 + target = (2 ** (max_selection - selection)) ^ target + # When target_gate == cirq.Z, the action is only in the phase. + + return {control_name: control, selection_name: selection, 'target': target} + + def basis_state_phase(self, **vals) -> Union[complex, None]: + if self.target_gate != cirq.X and self.target_gate != cirq.Z: + return None + if len(self.control_registers) > 1 or len(self.selection_registers) > 1: + return None + control_name = self.control_registers[0].name + control = vals[control_name] + selection_name = self.selection_registers[0].name + selection = vals[selection_name] + target = vals['target'] + if control: + max_selection = self.selection_registers[0].dtype.iteration_length_or_zero() - 1 + if self.target_gate == cirq.X: + num_phases = (target >> (max_selection - selection + 1)).bit_count() + else: + num_phases = (target >> (max_selection - selection)).bit_count() + return 1 if (num_phases % 2) == 0 else -1 + return 1 + def __str__(self): return f'SelectedMajoranaFermion({self.target_gate})' diff --git a/qualtran/bloqs/multiplexers/selected_majorana_fermion_test.py b/qualtran/bloqs/multiplexers/selected_majorana_fermion_test.py index d0b9644207..5017f719dc 100644 --- a/qualtran/bloqs/multiplexers/selected_majorana_fermion_test.py +++ b/qualtran/bloqs/multiplexers/selected_majorana_fermion_test.py @@ -20,7 +20,10 @@ from qualtran._infra.gate_with_registers import get_named_qubits, total_bits from qualtran.bloqs.multiplexers.selected_majorana_fermion import SelectedMajoranaFermion from qualtran.cirq_interop.testing import GateHelper -from qualtran.testing import assert_valid_bloq_decomposition +from qualtran.testing import ( + assert_consistent_phased_classical_action, + assert_valid_bloq_decomposition, +) @pytest.mark.slow @@ -148,3 +151,14 @@ def test_selected_majorana_fermion_gate_make_on(): op = gate.on_registers(**get_named_qubits(gate.signature)) op2 = SelectedMajoranaFermion.make_on(target_gate=cirq.X, **get_named_qubits(gate.signature)) assert op == op2 + + +@pytest.mark.parametrize("selection_bitsize, target_bitsize", [(2, 4), (3, 5)]) +@pytest.mark.parametrize("target_gate", [cirq.X, cirq.Z]) +def test_selected_majorana_fermion_classical_action(selection_bitsize, target_bitsize, target_gate): + gate = SelectedMajoranaFermion( + Register('selection', BQUInt(selection_bitsize, target_bitsize)), target_gate=target_gate + ) + assert_consistent_phased_classical_action( + gate, selection=range(target_bitsize), target=range(2**target_bitsize), control=range(2) + ) diff --git a/qualtran/testing.py b/qualtran/testing.py index e06e8f1f78..3adf726c77 100644 --- a/qualtran/testing.py +++ b/qualtran/testing.py @@ -39,6 +39,7 @@ Side, ) from qualtran._infra.composite_bloq import _get_flat_dangling_soqs +from qualtran.simulation.classical_sim import do_phased_classical_simulation from qualtran.symbolics import is_symbolic if TYPE_CHECKING: @@ -716,3 +717,29 @@ def assert_consistent_classical_action( np.testing.assert_equal( bloq_res, decomposed_res, err_msg=f'{bloq=} {call_with=} {bloq_res=} {decomposed_res=}' ) + + +def assert_consistent_phased_classical_action( + bloq: Bloq, + **parameter_ranges: Union[NDArray, Sequence[int], Sequence[Union[Sequence[int], NDArray]]], +): + """Check that the bloq has a phased classical action consistent with its decomposition. + + Args: + bloq: bloq to test. + parameter_ranges: named arguments giving ranges for each of the registers of the bloq. + """ + cb = bloq.decompose_bloq() + parameter_names = tuple(parameter_ranges.keys()) + for vals in itertools.product(*[parameter_ranges[p] for p in parameter_names]): + call_with = {p: v for p, v in zip(parameter_names, vals)} + bloq_res, bloq_phase = do_phased_classical_simulation(bloq, call_with) + decomposed_res, decomposed_phase = do_phased_classical_simulation(cb, call_with) + np.testing.assert_equal( + bloq_res, decomposed_res, err_msg=f'{bloq=} {call_with=} {bloq_res=} {decomposed_res=}' + ) + np.testing.assert_equal( + bloq_phase, + decomposed_phase, + err_msg=f'{bloq=} {call_with=} {bloq_phase=} {decomposed_phase=}', + )