diff --git a/pyproject.toml b/pyproject.toml index 7b15489557..f8e9768800 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ skip_glob = ["qualtran/protos/*"] [tool.pytest.ini_options] filterwarnings = [ 'ignore::DeprecationWarning:quimb.linalg.approx_spectral:', + 'ignore::qualtran.bloqs.bookkeeping.partition.LegacyPartitionWarning', 'ignore:.*standard platformdirs.*:DeprecationWarning:jupyter_client.*' ] # we define classes like TestBloq etc. which pytest tries to collect, diff --git a/qualtran/_infra/data_types.py b/qualtran/_infra/data_types.py index fe570ea7ee..6652b10191 100644 --- a/qualtran/_infra/data_types.py +++ b/qualtran/_infra/data_types.py @@ -358,6 +358,76 @@ def __str__(self) -> str: return 'CBit()' +@attrs.frozen +class _Any(BitEncoding[int]): + """Bag of Qubits of a given bitsize. + + Here (and throughout Qualtran), we use a big-endian bit convention. The most significant + bit is at index 0. + """ + + bitsize: SymbolicInt + + def get_domain(self) -> Iterable[int]: + raise TypeError(f"Ambiguous domain for {self}. Please use a more specific type.") + + def to_bits(self, x: int) -> List[int]: + if is_symbolic(self.bitsize): + raise ValueError(f"Cannot compute bits for symbolic {self.bitsize=}") + if x == 0: + return [0] * int(self.bitsize) + + raise TypeError( + f"Ambiguous encoding for {self} when encoding non zero value {x=}. Please use a more specific type." + ) + + def to_bits_array(self, x_array: NDArray[np.integer]) -> NDArray[np.uint8]: + if is_symbolic(self.bitsize): + raise ValueError(f"Cannot compute bits for symbolic {self.bitsize=}") + + values = np.atleast_1d(x_array) + if values.size == 0: + return np.zeros((values.shape[0], int(self.bitsize)), dtype=np.uint8) + + if not np.all(values == 0): + raise TypeError( + f"Ambiguous encoding for {self} when encoding non zero values {values=}. Please use a more specific type." + ) + + return np.zeros((values.shape[0], int(self.bitsize)), dtype=np.uint8) + + def from_bits(self, bits: Sequence[int]) -> int: + if all(x == 0 for x in bits): + return 0 + + raise TypeError( + f"Ambiguous value for {self} when bits ({bits}) are non zero. Please use a more specific type." + ) + + def from_bits_array(self, bits_array: NDArray[np.uint8]) -> NDArray[np.uint64]: + bitstrings = np.atleast_2d(bits_array) + if bitstrings.shape[1] != self.bitsize: + raise ValueError(f"Input bitsize {bitstrings.shape[1]} does not match {self.bitsize=}") + + if bitstrings.size == 0: + return np.zeros(bitstrings.shape[0], dtype=np.uint64) + + if not np.all(bitstrings == 0): + raise TypeError( + f"Ambiguous value for {self} when bits are non zero ({bits_array}). Please use a more specific type." + ) + + return np.zeros(bitstrings.shape[0], dtype=np.uint64) + + def assert_valid_val(self, val: int, debug_str: str = 'val') -> None: + pass + + def assert_valid_val_array( + self, val_array: NDArray[np.integer], debug_str: str = 'val' + ) -> None: + pass + + @attrs.frozen class QAny(QDType[Any]): """Opaque bag-of-qubits type.""" @@ -366,7 +436,7 @@ class QAny(QDType[Any]): @property def _bit_encoding(self) -> BitEncoding[Any]: - return _UInt(self.bitsize) + return _Any(self.bitsize) def __attrs_post_init__(self): if is_symbolic(self.bitsize): @@ -375,15 +445,6 @@ def __attrs_post_init__(self): if not isinstance(self.bitsize, int): raise ValueError(f"Bad bitsize for QAny: {self.bitsize}") - def get_classical_domain(self) -> Iterable[Any]: - raise TypeError(f"Ambiguous domain for {self}. Please use a more specific type.") - - def assert_valid_classical_val(self, val: Any, debug_str: str = 'val'): - pass - - def assert_valid_classical_val_array(self, val_array: NDArray, debug_str: str = 'val'): - pass - @attrs.frozen class _Int(BitEncoding[int]): diff --git a/qualtran/_infra/data_types_test.py b/qualtran/_infra/data_types_test.py index 355ed49d57..a9ab9cd227 100644 --- a/qualtran/_infra/data_types_test.py +++ b/qualtran/_infra/data_types_test.py @@ -441,10 +441,32 @@ def test_qbit_to_and_from_bits(): assert_to_and_from_bits_array_consistent(QBit(), [0, 1]) -def test_qany_to_and_from_bits(): - assert list(QAny(4).to_bits(10)) == [1, 0, 1, 0] +def test_qany_to_bits(): + with pytest.raises(TypeError, match=r"Ambiguous encoding"): + QAny(4).to_bits(10) - assert_to_and_from_bits_array_consistent(QAny(4), range(16)) + +def test_qany_from_bits_only_all_zeros(): + assert QAny(4).from_bits([0, 0, 0, 0]) == 0 + + with pytest.raises(TypeError, match=r"Ambiguous value"): + QAny(4).from_bits([1, 0, 0, 0]) + + +def test_qany_to_bits_array(): + enc = QAny(4) + assert np.all(enc.to_bits_array(np.array([0, 0])) == 0) + + with pytest.raises(TypeError, match=r"Ambiguous encoding"): + enc.to_bits_array(np.array([1])) + + +def test_qany_from_bits_array(): + enc = QAny(4) + assert np.all(enc.from_bits_array(np.zeros((2, 4), dtype=np.uint8)) == 0) + + with pytest.raises(TypeError, match=r"Ambiguous value"): + enc.from_bits_array(np.array([[1, 0, 0, 0]], dtype=np.uint8)) def test_qintonescomp_to_and_from_bits(): diff --git a/qualtran/bloqs/basic_gates/z_basis.py b/qualtran/bloqs/basic_gates/z_basis.py index 48231b5d50..a2e7c54327 100644 --- a/qualtran/bloqs/basic_gates/z_basis.py +++ b/qualtran/bloqs/basic_gates/z_basis.py @@ -43,9 +43,9 @@ ConnectionT, CtrlSpec, DecomposeTypeError, - QAny, QBit, QDType, + QUInt, Register, Side, Signature, @@ -479,7 +479,7 @@ def check(self, attribute, val): def dtype(self) -> QDType: if self.bitsize == 1: return QBit() - return QAny(self.bitsize) + return QUInt(self.bitsize) @cached_property def signature(self) -> Signature: diff --git a/qualtran/bloqs/bookkeeping/cast.py b/qualtran/bloqs/bookkeeping/cast.py index 9fd7bee565..535a5387b8 100644 --- a/qualtran/bloqs/bookkeeping/cast.py +++ b/qualtran/bloqs/bookkeeping/cast.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings from functools import cached_property from typing import Dict, List, Tuple, TYPE_CHECKING @@ -26,13 +27,16 @@ CompositeBloq, ConnectionT, DecomposeTypeError, + QAny, QCDType, QDType, + QUInt, Register, Side, Signature, ) from qualtran.bloqs.bookkeeping._bookkeeping_bloq import _BookkeepingBloq +from qualtran.bloqs.bookkeeping.partition import LegacyPartitionWarning from qualtran.symbolics import is_symbolic if TYPE_CHECKING: @@ -120,7 +124,23 @@ def my_tensors( ] def on_classical_vals(self, reg: int) -> Dict[str, 'ClassicalValT']: - res = self.out_dtype.from_bits(self.inp_dtype.to_bits(reg)) + if isinstance(self.inp_dtype, QAny) or isinstance(self.out_dtype, QAny): + warnings.warn( + "Doing classical casting with QAny is ambiguous, transforming it as QUInt for legacy purposes", + category=LegacyPartitionWarning, + ) + match (self.inp_dtype, self.out_dtype): + case (QAny(), _): + res = self.out_dtype.from_bits(QUInt(self.inp_dtype.bitsize).to_bits(reg)) + case (_, QAny()): + res = QUInt(self.out_dtype.bitsize).from_bits(self.inp_dtype.to_bits(reg)) + case (QAny(), QAny()): + res = QUInt(self.out_dtype.bitsize).from_bits( + QUInt(self.inp_dtype.bitsize).to_bits(reg) + ) + case _: + res = self.out_dtype.from_bits(self.inp_dtype.to_bits(reg)) + return {'reg': res} def as_cirq_op(self, qubit_manager, reg: 'CirqQuregT') -> Tuple[None, Dict[str, 'CirqQuregT']]: diff --git a/qualtran/bloqs/bookkeeping/join.py b/qualtran/bloqs/bookkeeping/join.py index b194add90d..f2fe52cb93 100644 --- a/qualtran/bloqs/bookkeeping/join.py +++ b/qualtran/bloqs/bookkeeping/join.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings from functools import cached_property from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING @@ -24,6 +25,7 @@ CompositeBloq, ConnectionT, DecomposeTypeError, + QAny, QBit, QDType, QUInt, @@ -32,6 +34,7 @@ Signature, ) from qualtran.bloqs.bookkeeping._bookkeeping_bloq import _BookkeepingBloq +from qualtran.bloqs.bookkeeping.partition import LegacyPartitionWarning from qualtran.drawing import directional_text_box, Text, WireSymbol if TYPE_CHECKING: @@ -99,6 +102,13 @@ def my_tensors( ] def on_classical_vals(self, reg: 'NDArray[np.uint]') -> Dict[str, int]: + if isinstance(self.dtype, QAny): + warnings.warn( + "Doing classical operations with QAny is ambiguous, returning a QUInt for legacy purposes", + category=LegacyPartitionWarning, + ) + return {'reg': QUInt(self.dtype.bitsize).from_bits(reg.tolist())} + return {'reg': self.dtype.from_bits(reg.tolist())} def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol': diff --git a/qualtran/bloqs/bookkeeping/partition.ipynb b/qualtran/bloqs/bookkeeping/partition.ipynb index 65ff47688d..9080931b62 100644 --- a/qualtran/bloqs/bookkeeping/partition.ipynb +++ b/qualtran/bloqs/bookkeeping/partition.ipynb @@ -39,8 +39,9 @@ "Partition a generic index into multiple registers.\n", "\n", "#### Parameters\n", - " - `n`: The total bitsize of the un-partitioned register\n", + " - `n`: The total bit-size of the un-partitioned register. Required if `dtype_in` is None. Deprecated. Kept for backward compatibility. Use `dtype_in` instead whenever possible.\n", " - `regs`: Registers to partition into. The `side` attribute is ignored.\n", + " - `dtype_in`: Type of the un-partitioned register. Required if `n` is None. If None, the type is inferred as `QUInt(n)`.\n", " - `partition`: `False` means un-partition instead. \n", "\n", "#### Registers\n", diff --git a/qualtran/bloqs/bookkeeping/partition.py b/qualtran/bloqs/bookkeeping/partition.py index 8cebbb5809..b0440f25ca 100644 --- a/qualtran/bloqs/bookkeeping/partition.py +++ b/qualtran/bloqs/bookkeeping/partition.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import warnings from functools import cached_property -from typing import Dict, List, Sequence, Tuple, TYPE_CHECKING +from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING import numpy as np import sympy @@ -29,6 +30,7 @@ DecomposeTypeError, QAny, QDType, + QUInt, Register, Side, Signature, @@ -46,16 +48,45 @@ from qualtran.simulation.classical_sim import ClassicalValT +class LegacyPartitionWarning(DeprecationWarning): + """Warnings for legacy Partition usage, when declaring only n.""" + + +def _constrain_qany_reg(reg: Register): + """Changes the dtype of a register to note break legacy code + + This function should be bound to dissapear + """ + if isinstance(reg.dtype, QAny): + warnings.warn( + f"Doing classical casting with QAny ({reg=}) is ambiguous, transforming it as QUInt for legacy purposes", + category=LegacyPartitionWarning, + ) + return evolve(reg, dtype=QUInt(reg.dtype.bitsize)) + return reg + + +def _regs_to_tuple(x): + if x is None: + return None + return x if isinstance(x, tuple) else tuple(x) + + +def _not_none(_inst, attr, value): + if value is None: + raise ValueError(f"{attr.name} cannot be None") + + class _PartitionBase(_BookkeepingBloq, metaclass=abc.ABCMeta): """Generalized paritioning functionality.""" @property @abc.abstractmethod - def n(self) -> SymbolicInt: ... + def n(self) -> Optional[SymbolicInt]: ... - @cached_property - def lumped_dtype(self) -> QDType: - return QAny(bitsize=self.n) + @property + @abc.abstractmethod + def lumped_dtype(self) -> QDType: ... @property @abc.abstractmethod @@ -98,6 +129,8 @@ def my_tensors( ) -> List['qtn.Tensor']: import quimb.tensor as qtn + if self.n is None: + raise DecomposeTypeError(f"cannot compute tensors with unknown n for {self}") if is_symbolic(self.n): raise DecomposeTypeError(f"cannot compute tensors for symbolic {self}") @@ -124,6 +157,7 @@ def _classical_partition(self, x: 'ClassicalValT') -> Dict[str, 'ClassicalValT'] xbits = self.lumped_dtype.to_bits(x) start = 0 for reg in self._regs: + reg = _constrain_qany_reg(reg) size = int(np.prod(reg.shape + (reg.bitsize,))) bits_reg = xbits[start : start + size] if reg.shape == (): @@ -138,6 +172,7 @@ def _classical_partition(self, x: 'ClassicalValT') -> Dict[str, 'ClassicalValT'] def _classical_unpartition_to_bits(self, **vals: 'ClassicalValT') -> NDArray[np.uint8]: out_vals: list[NDArray[np.uint8]] = [] for reg in self._regs: + reg = _constrain_qany_reg(reg) reg_val = np.asanyarray(vals[reg.name]) bitstrings = reg.dtype.to_bits_array(reg_val.ravel()) out_vals.append(bitstrings.ravel()) @@ -166,8 +201,11 @@ class Partition(_PartitionBase): """Partition a generic index into multiple registers. Args: - n: The total bitsize of the un-partitioned register + n: The total bit-size of the un-partitioned register. Required if `dtype_in` is None. + Deprecated. Kept for backward compatibility. Use `dtype_in` instead whenever possible. regs: Registers to partition into. The `side` attribute is ignored. + dtype_in: Type of the un-partitioned register. Required if `n` is None. If None, + the type is inferred as `QUInt(n)`. partition: `False` means un-partition instead. Registers: @@ -175,18 +213,42 @@ class Partition(_PartitionBase): [user spec]: The registers provided by the `regs` argument. RIGHT by default. """ - n: SymbolicInt - regs: Tuple[Register, ...] = field( - converter=lambda x: x if isinstance(x, tuple) else tuple(x), validator=validators.min_len(1) + n: Optional[SymbolicInt] = field(default=None) + regs: Optional[Tuple[Register, ...]] = field( + converter=_regs_to_tuple, validator=(_not_none, validators.min_len(1)), default=None ) - partition: bool = True + dtype_in: Optional[QDType] = field(default=None) + partition: bool = field(default=True) def __attrs_post_init__(self): + if self.n is None and self.dtype_in is None: + raise ValueError(f"Provide exactly n or dtype_in {self.n=}, {self.dtype_in=}") + elif self.n is not None and self.dtype_in is None: + warnings.warn( + "Partition: By not setting dtype_in you could encounter errors when running " + "assert_consistent_classical_action", + category=LegacyPartitionWarning, + ) + elif self.n is None and self.dtype_in is not None: + object.__setattr__(self, "n", self.dtype_in.num_qubits) + elif self.n is not None and self.dtype_in is not None: + if self.n != self.dtype_in.num_qubits: + raise ValueError( + f"{self.dtype_in=} should have size {self.n=}, currently {self.dtype_in.num_qubits=}" + ) + warnings.warn( + "Specifying both n and dtype_in is redundant", category=UserWarning, stacklevel=1 + ) + self._validate() + @property + def lumped_dtype(self) -> QDType: + return QUInt(bitsize=cast(SymbolicInt, self.n)) if self.dtype_in is None else self.dtype_in + @property def _regs(self) -> Sequence[Register]: - return self.regs + return cast(Tuple[Register, ...], self.regs) @cached_property def signature(self) -> 'Signature': @@ -195,11 +257,11 @@ def signature(self) -> 'Signature': return Signature( [Register('x', self.lumped_dtype, side=lumped)] - + [evolve(reg, side=partitioned) for reg in self.regs] + + [evolve(reg, side=partitioned) for reg in self._regs] ) def adjoint(self): - return evolve(self, partition=not self.partition) + return evolve(self, n=None, dtype_in=self.lumped_dtype, partition=not self.partition) @frozen @@ -228,6 +290,10 @@ class Split2(_PartitionBase): def n(self) -> SymbolicInt: return self.n1 + self.n2 + @property + def lumped_dtype(self) -> QDType: + return QUInt(bitsize=self.n) + @property def partition(self) -> bool: return True @@ -289,6 +355,10 @@ class Join2(_PartitionBase): def n(self) -> SymbolicInt: return self.n1 + self.n2 + @property + def lumped_dtype(self) -> QDType: + return QUInt(bitsize=self.n) + @property def partition(self) -> bool: return False diff --git a/qualtran/bloqs/bookkeeping/partition_test.py b/qualtran/bloqs/bookkeeping/partition_test.py index c6c3236332..7be26b864c 100644 --- a/qualtran/bloqs/bookkeeping/partition_test.py +++ b/qualtran/bloqs/bookkeeping/partition_test.py @@ -20,7 +20,7 @@ import pytest from attrs import frozen -from qualtran import Bloq, BloqBuilder, QAny, QGF, Register, Signature, Soquet, SoquetT +from qualtran import Bloq, BloqBuilder, QAny, QGF, QInt, QUInt, Register, Signature, Soquet, SoquetT from qualtran._infra.gate_with_registers import get_named_qubits from qualtran.bloqs.basic_gates import CNOT from qualtran.bloqs.bookkeeping import Partition @@ -37,10 +37,22 @@ def test_partition(bloq_autotester): def test_partition_check(): with pytest.raises(ValueError): _ = Partition(n=0, regs=()) + with pytest.raises(ValueError): + _ = Partition(n=10, regs=None) + with pytest.raises(ValueError): + _ = Partition(dtype_in=QUInt(10)) with pytest.raises(ValueError): _ = Partition(n=1, regs=(Register('x', QAny(2)),)) with pytest.raises(ValueError): _ = Partition(n=4, regs=(Register('x', QAny(1)), Register('x', QAny(3)))) + with pytest.raises(ValueError): + _ = Partition(n=10) + + regs = (Register("xx", QUInt(4)), Register("yy", QInt(6))) + with pytest.raises(ValueError): + _ = Partition(regs=regs) + with pytest.raises(ValueError): + _ = Partition(n=11, regs=regs) @frozen @@ -57,7 +69,7 @@ def signature(self) -> Signature: def build_composite_bloq(self, bb: 'BloqBuilder', test_regs: 'SoquetT') -> Dict[str, 'Soquet']: bloq_regs = self.test_bloq.signature - partition = Partition(self.bitsize, bloq_regs) # type: ignore[arg-type] + partition = Partition(dtype_in=QUInt(self.bitsize), regs=bloq_regs) # type: ignore[arg-type] out_regs = bb.add(partition, x=test_regs) out_regs = bb.add(self.test_bloq, **{reg.name: sp for reg, sp in zip(bloq_regs, out_regs)}) test_regs = bb.add( diff --git a/qualtran/bloqs/bookkeeping/split.py b/qualtran/bloqs/bookkeeping/split.py index ab02e2e6f4..485f2752fa 100644 --- a/qualtran/bloqs/bookkeeping/split.py +++ b/qualtran/bloqs/bookkeeping/split.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from functools import cached_property from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING @@ -26,6 +27,7 @@ CompositeBloq, ConnectionT, DecomposeTypeError, + QAny, QBit, QDType, QUInt, @@ -34,6 +36,7 @@ Signature, ) from qualtran.bloqs.bookkeeping._bookkeeping_bloq import _BookkeepingBloq +from qualtran.bloqs.bookkeeping.partition import LegacyPartitionWarning from qualtran.drawing import directional_text_box, Text, WireSymbol if TYPE_CHECKING: @@ -92,6 +95,13 @@ def as_pl_op(self, wires: 'Wires') -> 'Operation': return None def on_classical_vals(self, reg: int) -> Dict[str, 'ClassicalValT']: + if isinstance(self.dtype, QAny): + warnings.warn( + "Doing classical operations with QAny is ambiguous, returning a QUInt for legacy purposes", + category=LegacyPartitionWarning, + ) + return {'reg': np.asarray(QUInt(self.dtype.bitsize).to_bits(reg))} + return {'reg': np.asarray(self.dtype.to_bits(reg))} def my_tensors( diff --git a/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard_test.py b/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard_test.py index c3307b5f56..6cde7fca4c 100644 --- a/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard_test.py +++ b/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard_test.py @@ -17,7 +17,7 @@ import pytest import qualtran.testing as qlt_testing -from qualtran import QAny, QUInt +from qualtran import QUInt from qualtran.bloqs.chemistry.hubbard_model.qubitization import ( HubbardMajorannaOperator, HubbardSpinUpZ, @@ -124,7 +124,7 @@ def test_hubbard_spin_up_z_classical(): onehot[the_x + M * the_y] = 1 # The bloqs deal with one monolithic target register. - system = QAny(2 * N).from_bits(all_ones + onehot) + system = QUInt(2 * N).from_bits(all_ones + onehot) # Go through all possible x,y selection indices and see if a phase is applied. negative_phases = [] diff --git a/qualtran/bloqs/mcmt/ctrl_spec_and.py b/qualtran/bloqs/mcmt/ctrl_spec_and.py index 3fa7edd03c..138dce22d1 100644 --- a/qualtran/bloqs/mcmt/ctrl_spec_and.py +++ b/qualtran/bloqs/mcmt/ctrl_spec_and.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings from functools import cached_property from typing import Optional, TYPE_CHECKING, Union import numpy as np import sympy -from attrs import frozen +from attrs import evolve, frozen from qualtran import ( Bloq, @@ -35,7 +36,7 @@ Side, Signature, ) -from qualtran.bloqs.bookkeeping.partition import Partition +from qualtran.bloqs.bookkeeping.partition import LegacyPartitionWarning, Partition from qualtran.bloqs.mcmt.and_bloq import And, MultiAnd from qualtran.drawing import directional_text_box, Text, WireSymbol from qualtran.resource_counting.generalizers import ignore_split_join @@ -125,9 +126,18 @@ def _flat_cvs(self) -> Union[tuple[int, ...], HasLength]: return HasLength(self.n_ctrl_qubits) flat_cvs: list[int] = [] - for reg, cv in zip(self.control_registers, self.ctrl_spec.cvs): - assert isinstance(cv, np.ndarray) - flat_cvs.extend(reg.dtype.to_bits_array(cv.ravel()).ravel()) + for reg, cvs in zip(self.control_registers, self.ctrl_spec.cvs): + assert isinstance(cvs, np.ndarray) + + if isinstance(reg.dtype, QAny) and not np.all(cvs == 0): + warnings.warn( + f"Asking for a non zero controll value ({cvs}) for a QAny ({reg=}) is ambiguous, " + "transforming QAny in QUInt for legacy purposes", + category=LegacyPartitionWarning, + ) + reg = evolve(reg, dtype=QUInt(reg.dtype.bitsize)) + + flat_cvs.extend(reg.dtype.to_bits_array(cvs.ravel()).ravel()) return tuple(flat_cvs) def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> dict[str, 'SoquetT']: diff --git a/qualtran/simulation/tensor/_tensor_from_classical.py b/qualtran/simulation/tensor/_tensor_from_classical.py index bf87b44b32..4d4c84681c 100644 --- a/qualtran/simulation/tensor/_tensor_from_classical.py +++ b/qualtran/simulation/tensor/_tensor_from_classical.py @@ -17,10 +17,13 @@ import numpy as np from numpy.typing import NDArray +from qualtran.bloqs.bookkeeping.partition import _constrain_qany_reg + if TYPE_CHECKING: import quimb.tensor as qtn - from qualtran import Bloq, ConnectionT, Register + from qualtran import Bloq, ConnectionT, QAny, QUInt, Register + from qualtran.bloqs.bookkeeping.partition import LegacyPartitionWarning from qualtran.simulation.classical_sim import ClassicalValT @@ -55,7 +58,7 @@ def _bloq_to_dense_via_classical_action(bloq: 'Bloq') -> NDArray: assert np.size(last) == 0 input_kwargs = { - reg.name: _bits_to_classical_reg_data(reg, bits) + reg.name: _bits_to_classical_reg_data(_constrain_qany_reg(reg), bits) for reg, bits in zip(bloq.signature.lefts(), inputs_t) } output_args = bloq.call_classically(**input_kwargs) @@ -63,7 +66,7 @@ def _bloq_to_dense_via_classical_action(bloq: 'Bloq') -> NDArray: if output_args: output_t = np.concatenate( [ - reg.dtype.to_bits_array(np.asarray(vals)).flat + _constrain_qany_reg(reg).dtype.to_bits_array(np.asarray(vals)).flat for reg, vals in zip(bloq.signature.rights(), output_args) ] )