diff --git a/qualtran/__init__.py b/qualtran/__init__.py index a0f2b5e2d2..c56b46dae5 100644 --- a/qualtran/__init__.py +++ b/qualtran/__init__.py @@ -46,8 +46,11 @@ CompositeBloq, BloqBuilder, DidNotFlattenAnythingError, + Soquet, SoquetT, ConnectionT, + QVar, + QVarT, ) from ._infra.data_types import ( @@ -84,14 +87,7 @@ # Internal imports: none # External imports: none -from ._infra.quantum_graph import ( - BloqInstance, - Connection, - DanglingT, - LeftDangle, - RightDangle, - Soquet, -) +from ._infra.quantum_graph import BloqInstance, Connection, DanglingT, LeftDangle, RightDangle from ._infra.gate_with_registers import GateWithRegisters diff --git a/qualtran/_infra/adjoint.py b/qualtran/_infra/adjoint.py index 1ed4b1cb1e..b011268912 100644 --- a/qualtran/_infra/adjoint.py +++ b/qualtran/_infra/adjoint.py @@ -18,7 +18,15 @@ from attrs import frozen -from .composite_bloq import _binst_to_cxns, _cxns_to_soq_dict, _map_soqs, _reg_to_soq, BloqBuilder +from .composite_bloq import ( + _binst_to_cxns, + _cxns_to_soq_dict, + _map_soqs, + _reg_to_soq, + _SoquetT, + BloqBuilder, + QVarT, +) from .gate_with_registers import GateWithRegisters from .quantum_graph import LeftDangle, RightDangle from .registers import Signature @@ -26,12 +34,12 @@ if TYPE_CHECKING: import cirq - from qualtran import Bloq, CompositeBloq, Register, Signature, SoquetT + from qualtran import Bloq, CompositeBloq, Register, Signature from qualtran.drawing import WireSymbol from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator -def _adjoint_final_soqs(cbloq: 'CompositeBloq', new_signature: Signature) -> Dict[str, 'SoquetT']: +def _adjoint_final_soqs(cbloq: 'CompositeBloq', new_signature: Signature) -> Dict[str, '_SoquetT']: """`CompositeBloq.final_soqs()` but backwards.""" if LeftDangle not in cbloq._binst_graph: return {} @@ -57,15 +65,15 @@ def _adjoint_cbloq(cbloq: 'CompositeBloq') -> 'CompositeBloq': # First, we reverse the registers to initialize the BloqBuilder. old_signature = cbloq.signature new_signature = cbloq.signature.adjoint() - old_i_soqs = [_reg_to_soq(RightDangle, reg) for reg in old_signature.rights()] - new_i_soqs = [_reg_to_soq(LeftDangle, reg) for reg in new_signature.lefts()] - soq_map: List[Tuple[SoquetT, SoquetT]] = list(zip(old_i_soqs, new_i_soqs)) # Then we reverse the order of subbloqs bloqnections = reversed(list(cbloq.iter_bloqnections())) # And add subbloq.adjoint() back in for each subbloq. bb, _ = BloqBuilder.from_signature(new_signature) + old_i_soqs = [_reg_to_soq(RightDangle, reg) for reg in old_signature.rights()] + new_i_soqs = [bb._reg_to_qvar(LeftDangle, reg) for reg in new_signature.lefts()] + soq_map: List[Tuple[_SoquetT, QVarT]] = list(zip(old_i_soqs, new_i_soqs)) for binst, preds, succs in bloqnections: # Instead of get_me returning the right element of a predecessor connection, # it's the left element of a successor connection. diff --git a/qualtran/_infra/composite_bloq.ipynb b/qualtran/_infra/composite_bloq.ipynb index 1fe7ddcab5..4e0f4d798c 100644 --- a/qualtran/_infra/composite_bloq.ipynb +++ b/qualtran/_infra/composite_bloq.ipynb @@ -343,7 +343,7 @@ "bb, _ = BloqBuilder.from_signature(cbloq.signature)\n", "\n", "# We'll have to \"map\" the soquets from our template cbloq to our new one\n", - "soq_map: List[Tuple[SoquetT, SoquetT]] = []\n", + "soq_map = bb.initial_soq_map(cbloq.signature.lefts())\n", " \n", "# Iteration yields each bloq instance as well as its input and output soquets.\n", "for binst, in_soqs, old_out_soqs in cbloq.iter_bloqsoqs():\n", @@ -512,8 +512,8 @@ "# Go through and decompose each subbloq\n", "# We'll manually code this up in this notebook since this isn't a useful operation.\n", "bb, _ = BloqBuilder.from_signature(flat_three_p.signature)\n", - "soq_map: List[Tuple[SoquetT, SoquetT]] = []\n", - " \n", + "soq_map: List[Tuple[SoquetT, SoquetT]] = bb.initial_soq_map(flat_three_p.signature.lefts())\n", + "\n", "for binst, in_soqs, old_out_soqs in flat_three_p.iter_bloqsoqs():\n", " in_soqs = bb.map_soqs(in_soqs, soq_map)\n", " \n", diff --git a/qualtran/_infra/composite_bloq.py b/qualtran/_infra/composite_bloq.py index 4a5f117aab..5297456a03 100644 --- a/qualtran/_infra/composite_bloq.py +++ b/qualtran/_infra/composite_bloq.py @@ -13,10 +13,11 @@ # limitations under the License. """Classes for building and manipulating `CompositeBloq`.""" - +import warnings from collections.abc import Hashable from functools import cached_property from typing import ( + _ProtocolMeta, Callable, cast, Dict, @@ -32,6 +33,7 @@ Set, Tuple, TYPE_CHECKING, + TypeAlias, TypeGuard, TypeVar, Union, @@ -46,7 +48,15 @@ from .binst_graph_iterators import greedy_topological_sort from .bloq import Bloq, DecomposeNotImplementedError, DecomposeTypeError from .data_types import check_dtypes_consistent, QAny, QBit, QCDType, QDType -from .quantum_graph import BloqInstance, Connection, DanglingT, LeftDangle, RightDangle, Soquet +from .quantum_graph import ( + _QVar, + _Soquet, + BloqInstance, + Connection, + DanglingT, + LeftDangle, + RightDangle, +) from .registers import Register, Side, Signature if TYPE_CHECKING: @@ -60,29 +70,107 @@ from qualtran.symbolics import SymbolicInt -class SoquetT(Protocol): - """Either a Soquet or an array thereof. +class _NoSoquetIsInstanceMeta(_ProtocolMeta): + def __instancecheck__(cls, instance): + warnings.warn("isinstance(..., Soquet) is deprecated.", DeprecationWarning) + return isinstance(instance, (_Soquet, _QVar)) + # raise TypeError( + # "Do not rely on isinstance(..., Soquet). " + # "To distinguish a single soquet quantum variable " + # "from an n-dimensional array of them, use " + # "`SoquetT.is_single(soq)` and/or `SoquetT.is_ndarray(soq)`." + # ) + + +class Soquet(Protocol, metaclass=_NoSoquetIsInstanceMeta): + """A typing protocol for a soquet or qvar. + + In Qualtran v0.7 and earlier, the immutable `Soquet` object represented both the "quantum + variables" being passed around during bloq building as well as the nodes of the compute graph. + In Qualtran v0.8+, these concerns are separated. The compute graph nodes are frozen dataclasses + of type `_Soquet`, and the quantum variables are *mutable* objects of type `_QVar`. We put + additional helper attributes and methods onto `_QVar` to assist in bloq building. - To narrow objects of this type, use `BloqBuilder.is_single(soq)` and/or - `BloqBuilder.is_ndarray(soqs)`. + For backwards compatibility, the `Soquet` name is now assigned to this class: a + `typing.Protocol` that encapsulates the duck-typing behavior of `_Soquet` and `_QVar`. + Bloqs in the wild should not have to update the type annotations in `build_composite_bloq` + with this backwards compatibilty typing shim. + + `isinstance(..., Soquet)` checks will emit a deprecation warning + and return True for *either* `_Soquet` or `_QVar`. + + If you're using isinstance(soq, Soquet) to determine whether an item is a single object + or an ndarray of those objects, use `BloqBuilder.is_single(x)` or + `BloqBuilder.is_ndarray(x)`. See the documentation in `QVarT` for an example. + + If you're developing library functionality, you can port isinstance checks to either + `_Soquet` or `QVar` as appropriate. + """ + + def __new__(cls, *args, **kwargs): + warnings.warn( + "Constructing a soquet via `Soquet(...)` is deprecated. " + "User code should never construct soquets directly: " + "please use BloqBuilder.", + DeprecationWarning, + ) + return _Soquet(*args, **kwargs) + + @property + def shape(self) -> Tuple[int, ...]: ... + + def item(self, *args) -> _QVar: ... + + @property + def dtype(self) -> 'QCDType': ... + + def __hash__(self): ... + + @property + def reg(self) -> 'Register': ... + + def __getitem__(self, item) -> 'QVarT': ... + + +class _SoquetT(Protocol): + """Either an actual _Soquet or an array thereof.""" + + @property + def shape(self) -> Tuple[int, ...]: ... + + def item(self, *args) -> _Soquet: ... + + +class QVarT(Protocol): + """Either a QVar or an array thereof. + + To narrow objects of this type, use `BloqBuilder.is_single(qvar)` and/or + `BloqBuilder.is_ndarray(qvars)`. Example: - >>> soq_or_soqs: SoquetT - ... if BloqBuilder.is_ndarray(soq_or_soqs): - ... first_soq = soq_or_soqs.reshape(-1).item(0) + >>> qvar_or_qvars: QVarT + ... if BloqBuilder.is_ndarray(qvar_or_qvars): + ... first_soq = qvar_or_qvars.reshape(-1).item(0) ... else: ... # Note: `.item()` raises if not a single item. - ... first_soq = soq_or_soqs.item() + ... first_soq = qvar_or_qvars.item() """ @property def shape(self) -> Tuple[int, ...]: ... - def item(self, *args) -> Soquet: ... + def item(self, *args) -> _QVar: ... + + def __getitem__(self, item) -> 'QVarT': ... + + +# Compatibilities aliases +SoquetT: TypeAlias = QVarT +QVar: TypeAlias = Soquet -SoquetInT = Union[SoquetT, Sequence[SoquetT]] +SoquetInT = Union[QVarT, Sequence[QVarT]] """A soquet or array-like of soquets. This type alias is used for input argument to parts of the library that are more @@ -166,7 +254,7 @@ def _default_bloq_instances(self): } @cached_property - def all_soquets(self) -> FrozenSet[Soquet]: + def all_soquets(self) -> FrozenSet[_Soquet]: """A set of all `Soquet`s present in the compute graph.""" soquets = {cxn.left for cxn in self.connections} soquets |= {cxn.right for cxn in self.connections} @@ -323,7 +411,7 @@ def iter_bloqnections( def iter_bloqsoqs( self, - ) -> Iterator[Tuple[BloqInstance, Dict[str, SoquetT], Tuple[SoquetT, ...]]]: + ) -> Iterator[Tuple[BloqInstance, Dict[str, _SoquetT], Tuple[_SoquetT, ...]]]: """Iterate over bloq instances and their input soquets. This method is helpful for "adding from" this existing composite bloq. You must @@ -336,7 +424,7 @@ def iter_bloqsoqs( >>> from qualtran.bloqs.for_testing.with_decomposition import TestParallelCombo >>> cbloq = TestParallelCombo().decompose_bloq() >>> bb, _ = BloqBuilder.from_signature(cbloq.signature) - >>> soq_map: List[Tuple[SoquetT, SoquetT]] = [] + >>> soq_map = bb.initial_soq_map(cbloq.signature.lefts()) >>> for binst, in_soqs, old_out_soqs in cbloq.iter_bloqsoqs(): ... in_soqs = bb.map_soqs(in_soqs, soq_map) ... new_out_soqs = bb.add_t(binst.bloq, **in_soqs) @@ -362,7 +450,7 @@ def iter_bloqsoqs( out_soqs = tuple(_reg_to_soq(binst, reg) for reg in binst.bloq.signature.rights()) yield binst, in_soqs, out_soqs - def final_soqs(self) -> Dict[str, SoquetT]: + def final_soqs(self) -> Dict[str, _SoquetT]: """Return the final output soquets. This method is helpful for finalizing an "add from" operation, see `iter_bloqsoqs`. @@ -380,10 +468,11 @@ def final_soqs(self) -> Dict[str, SoquetT]: def copy(self) -> 'CompositeBloq': """Create a copy of this composite bloq by re-building it.""" bb, _ = BloqBuilder.from_signature(self.signature) - soq_map: List[Tuple[SoquetT, SoquetT]] = [] + soq_map = bb.initial_soq_map(self.signature.lefts()) + for binst, in_soqs, old_out_soqs in self.iter_bloqsoqs(): - in_soqs = _map_soqs(in_soqs, soq_map) - new_out_soqs = bb.add_t(binst.bloq, **in_soqs) + mapped_in_soqs = _map_soqs(in_soqs, soq_map) + new_out_soqs = bb.add_t(binst.bloq, **mapped_in_soqs) soq_map.extend(zip(old_out_soqs, new_out_soqs)) fsoqs = _map_soqs(self.final_soqs(), soq_map) @@ -426,11 +515,11 @@ def flatten_once( # pylint: disable=protected-access bb._i = max(binst.i for binst in self.bloq_instances) + 1 - flat_soq_map: Dict[Soquet, Soquet] = {} - new_out_soqs: Tuple[SoquetT, ...] + soq_map = bb.initial_soq_map(self.signature.lefts()) + new_out_soqs: Tuple[QVarT, ...] did_work = False - for binst, in_soqs, old_out_soqs in self.iter_bloqsoqs(): - in_soqs = _map_flat_soqs(in_soqs, flat_soq_map) # update `in_soqs` from old to new. + for binst, _in_soqs, old_out_soqs in self.iter_bloqsoqs(): + in_soqs = _map_soqs(_in_soqs, soq_map) # update `in_soqs` from old to new. if pred(binst): try: new_out_soqs = bb.add_from(binst.bloq, **in_soqs) @@ -445,12 +534,12 @@ def flatten_once( # pylint: disable=protected-access new_out_soqs = tuple(soq for _, soq in bb._add_binst(binst, in_soqs=in_soqs)) - _update_flat_soq_map(zip(old_out_soqs, new_out_soqs), flat_soq_map) + soq_map.extend(zip(old_out_soqs, new_out_soqs)) if not did_work: raise DidNotFlattenAnythingError() - fsoqs = _map_flat_soqs(self.final_soqs(), flat_soq_map) + fsoqs = _map_soqs(self.final_soqs(), soq_map) return bb.finalize(**fsoqs) def flatten( @@ -600,7 +689,7 @@ def _get_soquet( idx: Tuple[int, ...] = (), *, binst_graph: nx.DiGraph, -) -> 'Soquet': +) -> '_Soquet': """Retrieve a soquet given identifying information. We can uniquely address a Soquet by the arguments to this function. @@ -632,9 +721,9 @@ def _get_soquet( def _cxns_to_soq_dict( regs: Iterable[Register], cxns: Iterable[Connection], - get_me: Callable[[Connection], Soquet], - get_assign: Callable[[Connection], Soquet], -) -> Dict[str, SoquetT]: + get_me: Callable[[Connection], _Soquet], + get_assign: Callable[[Connection], _Soquet], +) -> Dict[str, '_SoquetT']: """Helper function to get a dictionary of soquets from a list of connections. Args: @@ -651,7 +740,7 @@ def _cxns_to_soq_dict( Returns: soqdict: A dictionary mapping register name to the selected soquets. """ - soqdict: Dict[str, SoquetT] = {} + soqdict: Dict[str, '_SoquetT'] = {} # Initialize multi-dimensional dictionary values. for reg in regs: @@ -673,7 +762,7 @@ def _cxns_to_soq_dict( def _cxns_to_cxn_dict( - regs: Iterable[Register], cxns: Iterable[Connection], get_me: Callable[[Connection], Soquet] + regs: Iterable[Register], cxns: Iterable[Connection], get_me: Callable[[Connection], _Soquet] ) -> Dict[str, ConnectionT]: """Helper function to get a dictionary of connections from a list of connections @@ -707,7 +796,7 @@ def _cxns_to_cxn_dict( return cxndict -def _get_dangling_soquets(signature: Signature, right: bool = True) -> Dict[str, SoquetT]: +def _get_dangling_soquets(signature: Signature, right: bool = True) -> Dict[str, _SoquetT]: """Get instantiated dangling soquets from a `Signature`. Args: @@ -727,13 +816,14 @@ def _get_dangling_soquets(signature: Signature, right: bool = True) -> Dict[str, regs = signature.lefts() dang = LeftDangle - all_soqs: Dict[str, SoquetT] = {} + all_soqs: Dict[str, _SoquetT] = {} + soqs: _SoquetT for reg in regs: all_soqs[reg.name] = _reg_to_soq(dang, reg) return all_soqs -def _flatten_soquet_collection(vals: Iterable[SoquetT]) -> List[Soquet]: +def _flatten_soquet_collection(vals: Iterable[_SoquetT]) -> List[_Soquet]: """Flatten SoquetT into a flat list of Soquet. SoquetT is either a unit Soquet or an ndarray thereof. @@ -748,7 +838,7 @@ def _flatten_soquet_collection(vals: Iterable[SoquetT]) -> List[Soquet]: return soqvals -def _get_flat_dangling_soqs(signature: Signature, right: bool) -> List[Soquet]: +def _get_flat_dangling_soqs(signature: Signature, right: bool) -> List[_Soquet]: """Flatten out the values of the soquet dictionaries from `_get_dangling_soquets`.""" soqdict = _get_dangling_soquets(signature, right=right) return _flatten_soquet_collection(soqdict.values()) @@ -774,19 +864,12 @@ def add(self, x: Hashable): pass -def _reg_to_soq( - binst: Union[BloqInstance, DanglingT], - reg: Register, - available: Union[Set[Soquet], _IgnoreAvailable] = _IgnoreAvailable(), -) -> SoquetT: +def _reg_to_soq(binst: Union[BloqInstance, DanglingT], reg: Register) -> _SoquetT: """Create the soquet or array of soquets for a register. Args: binst: The output soquet's bloq instance. reg: The register - available: By default, don't track the soquets. If a set is provided, we will add - each individual, indexed soquet to it. This is used for bookkeeping - in `BloqBuilder`. Returns: A Soquet or Soquets. For multi-dimensional @@ -796,16 +879,14 @@ def _reg_to_soq( if reg.shape: soqs = np.empty(reg.shape, dtype=object) for ri in reg.all_idxs(): - soq = Soquet(binst, reg, idx=ri) + soq = _Soquet(binst, reg, idx=ri) soqs[ri] = soq - available.add(soq) return soqs # Annoyingly, this must be a special case. # Otherwise, x[i] = thing will nest *array* objects because our ndarray's type is # 'object'. This wouldn't happen--for example--with an integer array. - soq = Soquet(binst, reg) - available.add(soq) + soq = _Soquet(binst, reg) return soq @@ -813,7 +894,7 @@ def _process_soquets( registers: Iterable[Register], in_soqs: Mapping[str, SoquetInT], debug_str: str, - func: Callable[[Soquet, Register, Tuple[int, ...]], None], + func: Callable[[_QVar, Register, Tuple[int, ...]], None], ) -> None: """Process and validate `in_soqs` in the context of `registers`. @@ -845,15 +926,15 @@ def _process_soquets( # this also supports length-zero indexing natively, which is good too. in_soq = np.asarray(in_soqs[reg.name]) except KeyError: - raise BloqError(f"{debug_str} requires a Soquet named `{reg.name}`.") from None + raise BloqError(f"During {debug_str}, we expected a value for '{reg.name}'.") from None unchecked_names.remove(reg.name) # so we can check for surplus arguments. for li in reg.all_idxs(): - idxed_soq = in_soq[li].item() + idxed_soq = in_soq.item(li) func(idxed_soq, reg, li) if not check_dtypes_consistent(idxed_soq.dtype, reg.dtype): - extra_str = f"{idxed_soq.reg.name}: {idxed_soq.dtype} vs {reg.name}: {reg.dtype}" + extra_str = f"{idxed_soq}: {idxed_soq.dtype} vs {reg.name}: {reg.dtype}" raise BloqError( f"{debug_str} register dtypes are not consistent {extra_str}." ) from None @@ -862,8 +943,8 @@ def _process_soquets( def _map_soqs( - soqs: Dict[str, SoquetT], soq_map: Iterable[Tuple[SoquetT, SoquetT]] -) -> Dict[str, SoquetT]: + soqs: Dict[str, _SoquetT], soq_map: Iterable[Tuple[_SoquetT, QVarT]] +) -> Dict[str, QVarT]: """Map `soqs` according to `soq_map`. See `CompositeBloq.iter_bloqsoqs` for example code. The public entry-point @@ -881,11 +962,11 @@ def _map_soqs( """ # First: flatten out any numpy arrays - flat_soq_map: Dict[Soquet, Soquet] = {} + flat_soq_map: Dict[_Soquet, _QVar] = {} for old_soqs, new_soqs in soq_map: if BloqBuilder.is_single(old_soqs): assert BloqBuilder.is_single(new_soqs), new_soqs - flat_soq_map[old_soqs] = new_soqs.item() + flat_soq_map[old_soqs.item()] = new_soqs.item() continue assert isinstance(old_soqs, np.ndarray), old_soqs @@ -895,14 +976,22 @@ def _map_soqs( flat_soq_map[o] = n # Then use vectorize to use the flat mapping. - def _map_soq(soq: Soquet) -> Soquet: + def _map_soq(soq: _Soquet) -> _QVar: # Helper function to map an individual soquet. - return flat_soq_map.get(soq, soq) + if soq in flat_soq_map: + return flat_soq_map[soq] + + warnings.warn( + "You must initialize your `soq_map` with `bb.initial_soq_map`. " + "Using a fallback that will disable all QVar features. " + "See the docstring for `CompositeBloq.copy` for an example of how to structure your code." + ) + return _QVar(soq, bb=None) # type: ignore[arg-type] # Use `vectorize` to call `_map_soq` on each element of the array. vmap = np.vectorize(_map_soq, otypes=[object]) - def _map_soqs(soqs: SoquetT) -> SoquetT: + def _map_soqs(soqs: _SoquetT) -> 'QVarT': if BloqBuilder.is_ndarray(soqs): return vmap(soqs) return _map_soq(soqs.item()) @@ -923,7 +1012,7 @@ def _map_soq(soq: Soquet) -> Soquet: vmap = np.vectorize(_map_soq, otypes=[object]) def _map_soqs(soqs: SoquetT) -> SoquetT: - if isinstance(soqs, Soquet): + if isinstance(soqs, _Soquet): return _map_soq(soqs) return vmap(soqs) @@ -935,8 +1024,8 @@ def _update_flat_soq_map( ): """Flatten SoquetT into a flat_soq_map. This function mutates `flat_soq_map`.""" for old_soqs, new_soqs in soq_map: - if isinstance(old_soqs, Soquet): - assert isinstance(new_soqs, Soquet), new_soqs + if isinstance(old_soqs, _Soquet): + assert isinstance(new_soqs, _Soquet), new_soqs flat_soq_map[old_soqs] = new_soqs continue @@ -1019,14 +1108,14 @@ def __init__(self, add_registers_allowed: bool = True): self._i = 0 # Bookkeeping for linear types; Soquets must be used exactly once. - self._available: Set[Soquet] = set() + self._available: Set[_Soquet] = set() # Whether we can call `add_register` and do non-strict `finalize()`. self.add_register_allowed = add_registers_allowed def add_register_from_dtype( self, reg: Union[str, Register], dtype: Optional[QCDType] = None - ) -> Union[None, SoquetT]: + ) -> Union[None, QVarT]: """Add a new typed register to the composite bloq being built. If this bloq builder was constructed with `add_registers_allowed=False`, @@ -1070,21 +1159,21 @@ def add_register_from_dtype( self._regs.append(reg) if reg.side & Side.LEFT: - return _reg_to_soq(LeftDangle, reg, available=self._available) + return self._reg_to_qvar(LeftDangle, reg, track=True) return None @overload - def add_register(self, reg: Register, bitsize: None = None) -> Union[None, SoquetT]: ... + def add_register(self, reg: Register, bitsize: None = None) -> Union[None, QVarT]: ... @overload - def add_register(self, reg: str, bitsize: 'SymbolicInt') -> SoquetT: ... + def add_register(self, reg: str, bitsize: 'SymbolicInt') -> QVarT: ... @overload - def add_register(self, reg: str, bitsize: 'QCDType') -> SoquetT: ... + def add_register(self, reg: str, bitsize: 'QCDType') -> QVarT: ... def add_register( self, reg: Union[str, Register], bitsize: Union[None, 'QCDType', 'SymbolicInt'] = None - ) -> Union[None, SoquetT]: + ) -> Union[None, QVarT]: """Add a new register to the composite bloq being built. If this bloq builder was constructed with `add_registers_allowed=False`, @@ -1120,7 +1209,7 @@ def add_register( @classmethod def from_signature( cls, signature: Signature, add_registers_allowed: bool = False - ) -> Tuple['BloqBuilder', Dict[str, SoquetT]]: + ) -> Tuple['BloqBuilder', Dict[str, QVarT]]: """Construct a BloqBuilder with a pre-specified signature. This is safer if e.g. you're decomposing an existing Bloq and need the signatures @@ -1129,7 +1218,7 @@ def from_signature( # Initial construction: allow register addition for the following loop. bb = cls(add_registers_allowed=True) - initial_soqs: Dict[str, SoquetT] = {} + initial_soqs: Dict[str, QVarT] = {} for reg in signature: if reg.side & Side.LEFT: register = bb.add_register_from_dtype(reg) @@ -1143,8 +1232,16 @@ def from_signature( return bb, initial_soqs + @overload @staticmethod - def is_single(x: 'SoquetT') -> TypeGuard['Soquet']: + def is_single(x: '_SoquetT') -> TypeGuard['_Soquet']: ... + + @overload + @staticmethod + def is_single(x: 'QVarT') -> TypeGuard['QVar']: ... + + @staticmethod + def is_single(x): """Returns True if `x` is a single soquet (not an ndarray of them). This doesn't use stringent runtime type checking; it uses the SoquetT protocol @@ -1152,8 +1249,16 @@ def is_single(x: 'SoquetT') -> TypeGuard['Soquet']: """ return x.shape == () + @overload + @staticmethod + def is_ndarray(x: '_SoquetT') -> TypeGuard['NDArray']: ... + + @overload + @staticmethod + def is_ndarray(x: 'QVarT') -> TypeGuard['NDArray']: ... + @staticmethod - def is_ndarray(x: 'SoquetT') -> TypeGuard['NDArray']: + def is_ndarray(x): """Returns True if `x` is an ndarray of soquets (not a single one). This doesn't use stringent runtime type checking; it uses the SoquetT protocol @@ -1163,8 +1268,8 @@ def is_ndarray(x: 'SoquetT') -> TypeGuard['NDArray']: @staticmethod def map_soqs( - soqs: Dict[str, SoquetT], soq_map: Iterable[Tuple[SoquetT, SoquetT]] - ) -> Dict[str, SoquetT]: + soqs: Dict[str, _SoquetT], soq_map: Iterable[Tuple[_SoquetT, QVarT]] + ) -> Dict[str, QVarT]: """Map `soqs` according to `soq_map`. See `CompositeBloq.iter_bloqsoqs` for example code. @@ -1186,10 +1291,47 @@ def _new_binst_i(self) -> int: self._i += 1 return i + def _make_qvar( + self, binst: Union[BloqInstance, DanglingT], reg: Register, idx: Tuple[int, ...] = () + ): + return _QVar(_Soquet(binst, reg, idx), bb=self) + + def _reg_to_qvar( + self, binst: Union[BloqInstance, DanglingT], reg: Register, *, track: bool = False + ) -> 'QVarT': + """Create the soquet or array of soquets for a register. + + Args: + binst: The output soquet's bloq instance. + reg: The register + track: Whether this is making new qvars that we need to track to enforce linear logic. + + Returns: + A Soquet or Soquets. For multi-dimensional + registers, the value will be an array of indexed Soquets. For 0-dimensional (normal) + registers, the value will be a `Soquet` object. + """ + if reg.shape: + soqs = np.empty(reg.shape, dtype=object) + for ri in reg.all_idxs(): + soq = _QVar(_Soquet(binst, reg, idx=ri), bb=self) + soqs[ri] = soq + if track: + self._available.add(soq.soquet) + return soqs + + # Annoyingly, this must be a special case. + # Otherwise, x[i] = thing will nest *array* objects because our ndarray's type is + # 'object'. This wouldn't happen--for example--with an integer array. + soq = _QVar(_Soquet(binst, reg), bb=self) + if track: + self._available.add(soq.soquet) + return soq + def _add_cxn( self, binst: Union[BloqInstance, DanglingT], - idxed_soq: Soquet, + idxed_soq: _QVar, reg: Register, idx: Tuple[int, ...], ) -> None: @@ -1199,16 +1341,19 @@ def _add_cxn( `(reg, idx)`. """ try: - self._available.remove(idxed_soq) + self._available.remove(idxed_soq.soquet) except KeyError: bloq = binst if isinstance(binst, DanglingT) else binst.bloq raise BloqError( - f"{idxed_soq} is not an available Soquet for `{bloq}.{reg.name}`." + f"During construction of the bloq, a quantum variable was re-used.\n" + f" When calling: {bloq}\n" + f" Register name: {reg.name}\n" + f" Re-used soquet details: {idxed_soq.soquet}" ) from None - cxn = Connection(idxed_soq, Soquet(binst, reg, idx)) + cxn = Connection(idxed_soq.soquet, self._make_qvar(binst, reg, idx).soquet) self._cxns.append(cxn) - def add_t(self, bloq: Bloq, **in_soqs: SoquetInT) -> Tuple[SoquetT, ...]: + def add_t(self, bloq: Bloq, **in_soqs: SoquetInT) -> Tuple[QVarT, ...]: """Add a new bloq instance to the compute graph and always return a tuple of soquets. This method will always return a tuple of soquets. See `BloqBuilder.add_d(..)` for a @@ -1318,7 +1463,7 @@ def add(self, bloq: Bloq, **in_soqs: SoquetInT): def _add_binst( self, binst: BloqInstance, in_soqs: Mapping[str, SoquetInT] - ) -> Iterator[Tuple[str, SoquetT]]: + ) -> Iterator[Tuple[str, QVarT]]: """Add a bloq instance. Warning! Do not use this function externally! Untold bad things will happen if @@ -1328,19 +1473,32 @@ def _add_binst( bloq = binst.bloq - def _add(idxed_soq: Soquet, reg: Register, idx: Tuple[int, ...]): + def _add(idxed_soq: _QVar, reg: Register, idx: Tuple[int, ...]): # close over `binst` return self._add_cxn(binst, idxed_soq, reg, idx) _process_soquets( - registers=bloq.signature.lefts(), in_soqs=in_soqs, debug_str=str(bloq), func=_add + registers=bloq.signature.lefts(), + in_soqs=in_soqs, + debug_str=f'a call to {bloq}', + func=_add, ) yield from ( - (reg.name, _reg_to_soq(binst, reg, available=self._available)) - for reg in bloq.signature.rights() + (reg.name, self._reg_to_qvar(binst, reg, track=True)) for reg in bloq.signature.rights() ) - def add_from(self, bloq: Bloq, **in_soqs: SoquetInT) -> Tuple[SoquetT, ...]: + def initial_soq_map(self, lefts: Iterable[Register]) -> List[Tuple['_SoquetT', 'QVarT']]: + """The initial mapping from old soquets to new soquets known to this bloq builder. + + This is used in patterns when you plan on calling `BloqBuilder.map_soqs` to add + connections from an "old" composite bloq to the "new" bloq we're currently building. + """ + soq_map: List[Tuple['_SoquetT', 'QVarT']] = [ + (_reg_to_soq(LeftDangle, reg), self._reg_to_qvar(LeftDangle, reg)) for reg in lefts + ] + return soq_map + + def add_from(self, bloq: Bloq, **in_soqs: SoquetInT) -> Tuple['QVarT', ...]: """Add all the sub-bloqs from `bloq` to the compute graph. This is useful for adding multiple bloq instances at once in a "flat" or "unrolled" way. @@ -1355,7 +1513,7 @@ def add_from(self, bloq: Bloq, **in_soqs: SoquetInT) -> Tuple[SoquetT, ...]: `Soquet`s or an array thereof. Returns: - out_soqs: A `SoquetT` for each right (output) register ordered + out_soqs: A `QVarT` for each right (output) register ordered according to `bloq.signature`. The ordering is according to `bloq.signature` and irrespective of the order of `**in_soqs`. """ @@ -1366,22 +1524,22 @@ def add_from(self, bloq: Bloq, **in_soqs: SoquetInT) -> Tuple[SoquetT, ...]: for k, v in in_soqs.items(): in_soqs[k] = np.asarray(v) + in_soqs = cast(Dict[str, QVarT], in_soqs) # Initial mapping of LeftDangle according to user-provided in_soqs. - soq_map: List[Tuple[SoquetT, SoquetT]] = [ - (_reg_to_soq(LeftDangle, reg), cast(SoquetT, in_soqs[reg.name])) - for reg in cbloq.signature.lefts() + soq_map: List[Tuple[_SoquetT, QVarT]] = [ + (_reg_to_soq(LeftDangle, reg), in_soqs[reg.name]) for reg in cbloq.signature.lefts() ] - for binst, in_soqs, old_out_soqs in cbloq.iter_bloqsoqs(): - in_soqs = _map_soqs(in_soqs, soq_map) + for binst, _in_soqs, old_out_soqs in cbloq.iter_bloqsoqs(): + in_soqs = _map_soqs(_in_soqs, soq_map) new_out_soqs = self.add_t(binst.bloq, **in_soqs) soq_map.extend(zip(old_out_soqs, new_out_soqs)) fsoqs = _map_soqs(cbloq.final_soqs(), soq_map) return tuple(fsoqs[reg.name] for reg in cbloq.signature.rights()) - def finalize(self, **final_soqs: SoquetT) -> CompositeBloq: + def finalize(self, **final_soqs: SoquetInT) -> CompositeBloq: """Finish building a CompositeBloq and return the immutable CompositeBloq. This method is similar to calling `add()` but instead of adding a new Bloq, @@ -1403,25 +1561,25 @@ def finalize(self, **final_soqs: SoquetT) -> CompositeBloq: # If items from `final_soqs` don't already exist in `_regs`, add RIGHT registers # for them. Then call `_finalize_strict` where the actual dangling connections are added. - def _infer_reg(name: str, soq: SoquetT) -> Register: - """Go from Soquet -> register, but use a specific name for the register.""" + def _infer_shaped_dtype(soq: SoquetT) -> Tuple['QCDType', Tuple[int, ...]]: + """Extract (dtype, shape) from SoquetT""" if BloqBuilder.is_single(soq): - return Register(name=name, dtype=soq.dtype, side=Side.RIGHT) - assert BloqBuilder.is_ndarray(soq) + return soq.item().dtype, () # Get info from 0th soquet in an ndarray. - return Register( - name=name, dtype=soq.reshape(-1).item(0).dtype, shape=soq.shape, side=Side.RIGHT - ) + assert BloqBuilder.is_ndarray(soq) + dtype = soq.reshape(-1).item(0).dtype + return dtype, soq.shape - right_reg_names = [reg.name for reg in self._regs if reg.side & Side.RIGHT] + existing_right_reg_names = [reg.name for reg in self._regs if reg.side & Side.RIGHT] for name, soq in final_soqs.items(): - if name not in right_reg_names: - self._regs.append(_infer_reg(name, soq)) + if name not in existing_right_reg_names: + dtype, shape = _infer_shaped_dtype(np.asarray(soq)) + self._regs.append(Register(name=name, dtype=dtype, shape=shape, side=Side.RIGHT)) return self._finalize_strict(**final_soqs) - def _finalize_strict(self, **final_soqs: SoquetT) -> CompositeBloq: + def _finalize_strict(self, **final_soqs: SoquetInT) -> CompositeBloq: """Finish building a CompositeBloq and return the immutable CompositeBloq. Args: @@ -1430,7 +1588,7 @@ def _finalize_strict(self, **final_soqs: SoquetT) -> CompositeBloq: """ signature = Signature(self._regs) - def _fin(idxed_soq: Soquet, reg: Register, idx: Tuple[int, ...]): + def _fin(idxed_soq: _QVar, reg: Register, idx: Tuple[int, ...]): # close over `RightDangle` return self._add_cxn(RightDangle, idxed_soq, reg, idx) @@ -1448,14 +1606,14 @@ def _fin(idxed_soq: Soquet, reg: Register, idx: Tuple[int, ...]): def allocate( self, n: Union[int, sympy.Expr] = 1, dtype: Optional[QDType] = None, dirty: bool = False - ) -> Soquet: + ) -> 'QVar': from qualtran.bloqs.bookkeeping import Allocate if dtype is not None: return self.add(Allocate(dtype=dtype, dirty=dirty)) return self.add(Allocate(dtype=(QAny(n)), dirty=dirty)) - def free(self, soq: Soquet, dirty: bool = False) -> None: + def free(self, soq: QVarT, dirty: bool = False) -> None: from qualtran.bloqs.bookkeeping import Free if not BloqBuilder.is_single(soq): @@ -1467,7 +1625,7 @@ def free(self, soq: Soquet, dirty: bool = False) -> None: self.add(Free(dtype=qdtype, dirty=dirty), reg=soq) - def split(self, soq: SoquetInT) -> NDArray[Soquet]: # type: ignore[type-var] + def split(self, soq: QVarT) -> NDArray['QVar']: # type: ignore[type-var] """Add a Split bloq to split up a register.""" from qualtran.bloqs.bookkeeping import Split @@ -1480,16 +1638,16 @@ def split(self, soq: SoquetInT) -> NDArray[Soquet]: # type: ignore[type-var] return self.add(Split(dtype=qdtype), reg=soq) - def join(self, soqs: SoquetInT, dtype: Optional[QDType] = None) -> Soquet: + def join(self, soqs: SoquetInT, dtype: Optional[QDType] = None) -> 'Soquet': from qualtran.bloqs.bookkeeping import Join try: soqs = np.asarray(soqs) (n,) = soqs.shape - except (AttributeError, ValueError): - raise ValueError("`join` expects a 1-d array of input soquets to join.") from None + except (AttributeError, ValueError) as e: + raise ValueError("`join` expects a 1-d array of input soquets to join.") from e - if not all(soq.reg.bitsize == 1 for soq in soqs): + if not all(soq.dtype.num_bits == 1 for soq in soqs): raise ValueError("`join` can only join equal-bitsized soquets, currently only size 1.") if dtype is None: dtype = QAny(n) diff --git a/qualtran/_infra/composite_bloq_test.py b/qualtran/_infra/composite_bloq_test.py index 9395666a77..f9f3953f7a 100644 --- a/qualtran/_infra/composite_bloq_test.py +++ b/qualtran/_infra/composite_bloq_test.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import cached_property -from typing import cast, Dict, List, Tuple +from typing import Any, cast, Dict import attrs import networkx as nx @@ -40,8 +40,14 @@ Soquet, SoquetT, ) -from qualtran._infra.composite_bloq import _create_binst_graph, _get_dangling_soquets, _get_soquet +from qualtran._infra.composite_bloq import ( + _create_binst_graph, + _get_dangling_soquets, + _get_soquet, + _SoquetT, +) from qualtran._infra.data_types import BQUInt, QAny, QBit, QFxp, QUInt +from qualtran._infra.quantum_graph import _QVar, _Soquet from qualtran.bloqs.basic_gates import CNOT, IntEffect, ZeroEffect from qualtran.bloqs.bookkeeping import Join from qualtran.bloqs.for_testing.atom import TestAtom, TestTwoBitOp @@ -59,12 +65,12 @@ def _manually_make_test_cbloq_cxns(): binst2 = BloqInstance(tcn, 2) assert binst1 != binst2 return [ - Connection(Soquet(LeftDangle, q1), Soquet(binst1, control)), - Connection(Soquet(LeftDangle, q2), Soquet(binst1, target)), - Connection(Soquet(binst1, control), Soquet(binst2, target)), - Connection(Soquet(binst1, target), Soquet(binst2, control)), - Connection(Soquet(binst2, control), Soquet(RightDangle, q1)), - Connection(Soquet(binst2, target), Soquet(RightDangle, q2)), + Connection(_Soquet(LeftDangle, q1), _Soquet(binst1, control)), + Connection(_Soquet(LeftDangle, q2), _Soquet(binst1, target)), + Connection(_Soquet(binst1, control), _Soquet(binst2, target)), + Connection(_Soquet(binst1, target), _Soquet(binst2, control)), + Connection(_Soquet(binst2, control), _Soquet(RightDangle, q1)), + Connection(_Soquet(binst2, target), _Soquet(RightDangle, q2)), ], signature @@ -139,27 +145,31 @@ def test_map_soqs(): bb, _ = BloqBuilder.from_signature(cbloq.signature) bb._i = 100 # pylint: disable=protected-access - soq_map: List[Tuple[SoquetT, SoquetT]] = [] + soq_map = bb.initial_soq_map(cbloq.signature.lefts()) + for binst, in_soqs, old_out_soqs in cbloq.iter_bloqsoqs(): if binst.i == 0: - assert in_soqs == bb.map_soqs(in_soqs, soq_map) + assert bb.map_soqs(in_soqs, soq_map) == { + 'ctrl': _QVar(in_soqs['ctrl'].item(), bb=bb), + 'target': _QVar(in_soqs['target'].item(), bb=bb), + } elif binst.i == 1: for k, val in bb.map_soqs(in_soqs, soq_map).items(): - assert isinstance(val, Soquet) - assert isinstance(val.binst, BloqInstance) - assert val.binst.i >= 100 + assert isinstance(val, _QVar) + assert isinstance(val.soquet.binst, BloqInstance) + assert val.soquet.binst.i >= 100 else: raise AssertionError() - in_soqs = bb.map_soqs(in_soqs, soq_map) - new_out_soqs = bb.add_t(binst.bloq, **in_soqs) + mapped_in_soqs = bb.map_soqs(in_soqs, soq_map) + new_out_soqs = bb.add_t(binst.bloq, **mapped_in_soqs) soq_map.extend(zip(old_out_soqs, new_out_soqs)) fsoqs = bb.map_soqs(cbloq.final_soqs(), soq_map) for k, val in fsoqs.items(): - assert isinstance(val, Soquet) - assert isinstance(val.binst, BloqInstance) - assert val.binst.i >= 100 + assert isinstance(val, _QVar) + assert isinstance(val.soquet.binst, BloqInstance) + assert val.soquet.binst.i >= 100 cbloq = bb.finalize(**fsoqs) assert isinstance(cbloq, CompositeBloq) @@ -194,7 +204,18 @@ def test_bloq_builder(): signature = Signature.build(x=1, y=1) x_reg, y_reg = signature bb, initial_soqs = BloqBuilder.from_signature(signature) - assert initial_soqs == {'x': Soquet(LeftDangle, x_reg), 'y': Soquet(LeftDangle, y_reg)} + + # Using deprecated Soquet constructor (to be removed) + assert initial_soqs == { + 'x': _QVar(Soquet(LeftDangle, x_reg), bb=bb), # type: ignore + 'y': _QVar(Soquet(LeftDangle, y_reg), bb=bb), # type: ignore + } + + # Using private constructor + assert initial_soqs == { + 'x': _QVar(_Soquet(LeftDangle, x_reg), bb=bb), + 'y': _QVar(_Soquet(LeftDangle, y_reg), bb=bb), + } x = initial_soqs['x'] y = initial_soqs['y'] @@ -219,17 +240,17 @@ def _get_bb(): def test_wrong_soquet(): bb, x, y = _get_bb() - with pytest.raises(BloqError, match=r'.*is not an available Soquet for .*target.*'): - bad_target_arg = Soquet(BloqInstance(TestTwoBitOp(), i=12), Register('target', QAny(2))) + with pytest.raises(BloqError): + bad_target_arg = bb._make_qvar( + BloqInstance(TestTwoBitOp(), i=12), Register('target', QAny(2)) + ) bb.add(TestTwoBitOp(), ctrl=x, target=bad_target_arg) def test_double_use_1(): bb, x, y = _get_bb() - with pytest.raises( - BloqError, match=r'.*is not an available Soquet for `TestTwoBitOp.*target`.*' - ): + with pytest.raises(BloqError, match=r'.*quantum variable was re\-used.*'): bb.add(TestTwoBitOp(), ctrl=x, target=x) @@ -238,14 +259,16 @@ def test_double_use_2(): x2, y2 = bb.add(TestTwoBitOp(), ctrl=x, target=y) - with pytest.raises(BloqError, match=r'.*is not an available Soquet for `TestTwoBitOp\.ctrl`\.'): + with pytest.raises(BloqError, match=r'.*quantum variable was re\-used.*'): x3, y3 = bb.add(TestTwoBitOp(), ctrl=x, target=y) def test_missing_args(): bb, x, y = _get_bb() - with pytest.raises(BloqError, match=r'.*requires a Soquet named `ctrl`.'): + with pytest.raises( + BloqError, match=r"During a call to TestTwoBitOp, we expected a value for 'ctrl'\." + ): bb.add(TestTwoBitOp(), target=y) @@ -262,15 +285,17 @@ def test_finalize_wrong_soquet(): assert x != x2 assert y != y2 - with pytest.raises(BloqError, match=r'.*is not an available Soquet for .*y.*'): - bb.finalize(x=x2, y=Soquet(BloqInstance(TestTwoBitOp(), i=12), Register('target', QAny(2)))) + with pytest.raises(BloqError, match=r'.*quantum variable was re\-used.*'): + bb.finalize( + x=x2, y=bb._make_qvar(BloqInstance(TestTwoBitOp(), i=12), Register('target', QAny(2))) + ) def test_finalize_double_use_1(): bb, x, y = _get_bb() x2, y2 = bb.add(TestTwoBitOp(), ctrl=x, target=y) - with pytest.raises(BloqError, match=r'.*is not an available Soquet for .*y.*'): + with pytest.raises(BloqError, match=r'.*quantum variable was re\-used.*'): bb.finalize(x=x2, y=x2) @@ -278,7 +303,7 @@ def test_finalize_double_use_2(): bb, x, y = _get_bb() x2, y2 = bb.add(TestTwoBitOp(), ctrl=x, target=y) - with pytest.raises(BloqError, match=r'.*is not an available Soquet for `RightDangle\.x`\.'): + with pytest.raises(BloqError, match=r'.*quantum variable was re\-used.*'): bb.finalize(x=x, y=y2) @@ -286,7 +311,8 @@ def test_finalize_missing_args(): bb, x, y = _get_bb() x2, y2 = bb.add(TestTwoBitOp(), ctrl=x, target=y) - with pytest.raises(BloqError, match=r'Finalizing requires a Soquet named `x`.'): + bb.add_register_allowed = False + with pytest.raises(BloqError, match=r"During Finalizing, we expected a value for 'x'\."): bb.finalize(y=y2) @@ -296,15 +322,15 @@ def test_finalize_strict_too_many_args(): bb.add_register_allowed = False with pytest.raises(BloqError, match=r'Finalizing does not accept Soquets.*z.*'): - bb.finalize(x=x2, y=y2, z=Soquet(RightDangle, Register('asdf', QBit()))) + bb.finalize(x=x2, y=y2, z=_Soquet(RightDangle, Register('asdf', QBit()))) def test_finalize_bad_args(): bb, x, y = _get_bb() x2, y2 = bb.add(TestTwoBitOp(), ctrl=x, target=y) - with pytest.raises(BloqError, match=r'.*is not an available Soquet.*RightDangle\.z.*'): - bb.finalize(x=x2, y=y2, z=Soquet(RightDangle, Register('asdf', QBit()))) + with pytest.raises(BloqError): + bb.finalize(x=x2, y=y2, z=bb._make_qvar(RightDangle, Register('asdf', QBit()))) def test_finalize_alloc(): @@ -320,7 +346,7 @@ def test_get_soquets(): soqs = _get_dangling_soquets(Join(QAny(10)).signature, right=True) assert list(soqs.keys()) == ['reg'] soq = soqs['reg'] - assert isinstance(soq, Soquet) + assert isinstance(soq, _Soquet) assert soq.binst == RightDangle assert soq.reg.bitsize == 10 @@ -438,6 +464,28 @@ def test_copy(cls): assert cbloq.debug_text() == cbloq2.debug_text() +def _copy_bad_init(cbloq): + # Test backwards-compatibility shim where we don't use `bb.initial_soq_map`. + from qualtran._infra.composite_bloq import _map_soqs + + bb, _ = BloqBuilder.from_signature(cbloq.signature) + soq_map: Any = [] # !!! + + for binst, in_soqs, old_out_soqs in cbloq.iter_bloqsoqs(): + mapped_in_soqs = _map_soqs(in_soqs, soq_map) + new_out_soqs = bb.add_t(binst.bloq, **mapped_in_soqs) + soq_map.extend(zip(old_out_soqs, new_out_soqs)) + + fsoqs = _map_soqs(cbloq.final_soqs(), soq_map) + return bb.finalize(**fsoqs) + + +def test_old_copy(): + cbloq = TestParallelCombo().decompose_bloq() + new_cbloq = _copy_bad_init(cbloq) + assert new_cbloq.debug_text() == cbloq.debug_text() + + @pytest.mark.parametrize('call_decompose', [False, True]) def test_add_from(call_decompose): bb = BloqBuilder() @@ -648,14 +696,14 @@ def test_get_soquet(): def test_can_tell_individual_from_ndsoquet(): - s1 = Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(0,)) - s2 = Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(1,)) - s3 = Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(2,)) - s4 = Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(3,)) + s1 = _Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(0,)) + s2 = _Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(1,)) + s3 = _Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(2,)) + s4 = _Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(3,)) # A ndarray of soquet objects should be SoquetT and we can tell by checking its shape. - ndsoq: SoquetT = np.array([s1, s2, s3, s4]) - assert_type(ndsoq, SoquetT) + ndsoq: _SoquetT = np.array([s1, s2, s3, s4]) + assert_type(ndsoq, _SoquetT) assert ndsoq.shape assert ndsoq.shape == (4,) assert ndsoq.item(2) == s3 @@ -663,22 +711,24 @@ def test_can_tell_individual_from_ndsoquet(): _ = ndsoq.item() # A single soquet is still a valid SoquetT, and it has a false-y shape. - single_soq: SoquetT = s1 - assert_type(single_soq, SoquetT) + single_soq: _SoquetT = s1 + assert_type(single_soq, _SoquetT) assert not single_soq.shape assert single_soq.shape == () single_soq_unwarp = single_soq.item() assert single_soq_unwarp == s1 # A single soquet wrapped in a 0-dim ndarray is ok if you call `item()`. - single_soq2: SoquetT = np.asarray(s1) - assert_type(single_soq2, SoquetT) + single_soq2: _SoquetT = np.asarray(s1) + assert_type(single_soq2, _SoquetT) assert not single_soq2.shape assert single_soq2.shape == () single_soq2_unwrap = single_soq2.item() assert hash(single_soq2_unwrap) == hash(s1) assert single_soq2_unwrap == s1 - assert isinstance(single_soq2_unwrap, Soquet) + with pytest.warns(DeprecationWarning, match=r'deprecated'): + assert isinstance(single_soq2_unwrap, Soquet) # type: ignore[misc] + assert isinstance(single_soq2_unwrap, _Soquet) @pytest.mark.notebook diff --git a/qualtran/_infra/controlled.py b/qualtran/_infra/controlled.py index 5d6133b881..1b7ab37cb3 100644 --- a/qualtran/_infra/controlled.py +++ b/qualtran/_infra/controlled.py @@ -642,10 +642,10 @@ def build_composite_bloq( cbloq = self.subbloq.decompose_bloq() ctrl_soqs: List['SoquetT'] = [initial_soqs[creg_name] for creg_name in self.ctrl_reg_names] + soq_map = bb.initial_soq_map(cbloq.signature.lefts()) - soq_map: List[Tuple[SoquetT, SoquetT]] = [] - for binst, in_soqs, old_out_soqs in cbloq.iter_bloqsoqs(): - in_soqs = bb.map_soqs(in_soqs, soq_map) + for binst, _in_soqs, old_out_soqs in cbloq.iter_bloqsoqs(): + in_soqs = bb.map_soqs(_in_soqs, soq_map) new_bloq, adder = binst.bloq.get_ctrl_system(self.ctrl_spec) adder_output = adder(bb, ctrl_soqs=ctrl_soqs, in_soqs=in_soqs) ctrl_soqs = list(adder_output[0]) diff --git a/qualtran/_infra/quantum_graph.py b/qualtran/_infra/quantum_graph.py index f3d4f34cfc..d79a66d1ec 100644 --- a/qualtran/_infra/quantum_graph.py +++ b/qualtran/_infra/quantum_graph.py @@ -13,14 +13,16 @@ # limitations under the License. """Plumbing for bloq-to-bloq `Connection`s.""" - +import warnings from functools import cached_property -from typing import Tuple, TYPE_CHECKING, Union +from typing import Optional, Tuple, TYPE_CHECKING, Union +import attrs +import numpy as np from attrs import field, frozen if TYPE_CHECKING: - from qualtran import Bloq, BloqBuilder, QCDType, Register + from qualtran import Bloq, BloqBuilder, QCDType, QVarT, Register @frozen @@ -76,10 +78,10 @@ def _to_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]: @frozen -class Soquet: +class _Soquet: """One half of a connection. - Users should not construct these directly. They should be marshalled + Users should not construct these directly. They should be marshaled by a `BloqBuilder`. A `Soquet` acts as the node type in our quantum compute graph. It is a particular @@ -117,7 +119,7 @@ def dtype(self) -> 'QCDType': def shape(self) -> Tuple[int, ...]: return () - def item(self, *args) -> 'Soquet': + def item(self, *args) -> '_Soquet': if args: raise ValueError("Tried to index into a single soquet.") return self @@ -132,6 +134,65 @@ def __str__(self) -> str: return f'{self.binst}.{self.pretty()}' +@attrs.mutable +class _QVar: + """A handle to a quantum variable used during bloq building. + + Do not construct these objects directly. Please use the `QVar` `typing.Protocol` for + type annotations. + """ + + soquet: _Soquet + bb: 'BloqBuilder' = field(kw_only=True) + _split_components: Optional['QVarT'] = field(default=None) + ssa_name: Optional[str] = field(default=None, kw_only=True) + + @property + def dtype(self) -> 'QCDType': + return self.soquet.dtype + + @property + def shape(self) -> Tuple[int, ...]: + return () + + def item(self, *args): + if args and args != ((),): + raise ValueError(f"Tried to index {args!r} into a single soquet.") + return self + + @property + def reg(self) -> 'Register': + warnings.warn( + "Accessing the register property of a quantum variable is highly discouraged " + "and will be dis-allowed in the future.", + DeprecationWarning, + ) + return self.soquet.reg + + def __hash__(self): + raise TypeError("QVar objects during bloq building are *not* hashable.") + + def __getitem__(self, item): + if self._split_components is None: + self._split_components = self.bb.split(self) + + return self._split_components[item] + + def __len__(self): + return self.dtype.num_bits + + def __array__(self, dtype=None, copy=None): + # This method is super important -- + # throughout the library, we use np.asarray(soqs) + arr = np.empty(shape=(), dtype=object) + arr[()] = self + if copy is None: + return arr + if copy: + raise NotImplementedError() + return arr + + LeftDangle = DanglingT("LeftDangle") RightDangle = DanglingT("RightDangle") @@ -151,8 +212,8 @@ class Connection: is directed. """ - left: Soquet - right: Soquet + left: _Soquet + right: _Soquet @cached_property def num_qubits(self) -> int: diff --git a/qualtran/_infra/quantum_graph_test.py b/qualtran/_infra/quantum_graph_test.py index b692334d1d..76ac14e87e 100644 --- a/qualtran/_infra/quantum_graph_test.py +++ b/qualtran/_infra/quantum_graph_test.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import no_type_check import pytest from qualtran import BloqInstance, DanglingT, LeftDangle, QAny, Register, RightDangle, Side, Soquet +from qualtran._infra.quantum_graph import _Soquet from qualtran.bloqs.for_testing import TestAtom, TestTwoBitOp @@ -43,7 +45,19 @@ def test_dangling_hash(): def test_soquet(): - soq = Soquet(BloqInstance(TestTwoBitOp(), i=0), Register('x', QAny(10))) + soq = _Soquet(BloqInstance(TestTwoBitOp(), i=0), Register('x', QAny(10))) + assert soq.reg.side is Side.THRU + assert soq.idx == () + assert soq.pretty() == 'x' + + assert soq.item() == soq + assert soq.dtype == QAny(10) + + +@no_type_check +def test_old_construct_soquet(): + with pytest.warns(DeprecationWarning, match=r'deprecated.*'): + soq = Soquet(BloqInstance(TestTwoBitOp(), i=0), Register('x', QAny(10))) assert soq.reg.side is Side.THRU assert soq.idx == () assert soq.pretty() == 'x' @@ -57,16 +71,16 @@ def test_soquet_idxed(): reg = Register('y', QAny(10), shape=(10, 2)) with pytest.raises(ValueError, match=r'Bad index.*'): - _ = Soquet(binst, reg) + _ = _Soquet(binst, reg) with pytest.raises(ValueError, match=r'Bad index.*'): - _ = Soquet(binst, reg, idx=(5,)) + _ = _Soquet(binst, reg, idx=(5,)) - soq = Soquet(binst, reg, idx=(5, 0)) + soq = _Soquet(binst, reg, idx=(5, 0)) assert soq.pretty() == 'y[5, 0]' with pytest.raises(ValueError, match=r'Bad index.*'): - _ = Soquet(binst, reg, idx=(5,)) + _ = _Soquet(binst, reg, idx=(5,)) def test_bloq_instance(): diff --git a/qualtran/bloqs/bookkeeping/free_test.py b/qualtran/bloqs/bookkeeping/free_test.py index 63ffa89314..75a8a2d040 100644 --- a/qualtran/bloqs/bookkeeping/free_test.py +++ b/qualtran/bloqs/bookkeeping/free_test.py @@ -24,11 +24,11 @@ def test_free(bloq_autotester): def test_util_bloqs(): bb = BloqBuilder() qs1 = bb.add(Allocate(QAny(10))) - assert isinstance(qs1, Soquet) + assert isinstance(qs1, Soquet) # type: ignore[misc] qs2 = bb.add(Split(QAny(10)), reg=qs1) assert qs2.shape == (10,) qs3 = bb.add(Join(QAny(10)), reg=qs2) - assert isinstance(qs3, Soquet) + assert isinstance(qs3, Soquet) # type: ignore[misc] no_return = bb.add(Free(QAny(10)), reg=qs3) assert no_return is None assert bb.finalize().tensor_contract() == 1.0 diff --git a/qualtran/bloqs/bookkeeping/split_test.py b/qualtran/bloqs/bookkeeping/split_test.py index 8efb0a620f..a6f47b185d 100644 --- a/qualtran/bloqs/bookkeeping/split_test.py +++ b/qualtran/bloqs/bookkeeping/split_test.py @@ -44,15 +44,15 @@ def test_classical_sim(): cbloq = bb.finalize(y=y) ret, assign = call_cbloq_classically(cbloq.signature, vals={}, binst_graph=cbloq._binst_graph) - assert assign[x] == 0 + assert assign[x.soquet] == 0 # type: ignore[attr-defined] - assert assign[xs[0]] == 0 - assert assign[xs_1_orig] == 0 - assert assign[xs[2]] == 0 - assert assign[xs[3]] == 0 + assert assign[xs[0].soquet] == 0 + assert assign[xs_1_orig.soquet] == 0 + assert assign[xs[2].soquet] == 0 + assert assign[xs[3].soquet] == 0 - assert assign[xs[1]] == 1 - assert assign[y] == 4 + assert assign[xs[1].soquet] == 1 + assert assign[y.soquet] == 4 # type: ignore[attr-defined] assert ret == {'y': 4} diff --git a/qualtran/bloqs/mcmt/and_bloq.py b/qualtran/bloqs/mcmt/and_bloq.py index 611f3228e2..7c043dd92b 100644 --- a/qualtran/bloqs/mcmt/and_bloq.py +++ b/qualtran/bloqs/mcmt/and_bloq.py @@ -42,6 +42,8 @@ DecomposeTypeError, GateWithRegisters, QBit, + QVar, + QVarT, Register, Side, Signature, @@ -93,6 +95,16 @@ def signature(self) -> Signature: def adjoint(self) -> 'And': return attrs.evolve(self, uncompute=not self.uncompute) + @classmethod + def qcall(cls, ctrl: 'QVarT', *, cv1=1, cv2=1, uncompute: bool = False, **maybe_target: 'QVar'): + ctrl = np.asarray(ctrl) + bb = ctrl.item(0).bb + bloq = cls(cv1=cv1, cv2=cv2, uncompute=uncompute) + if uncompute: + return bb.add(bloq, ctrl=ctrl, target=maybe_target['target']) + else: + return bb.add(bloq, ctrl=ctrl) + def decompose_bloq(self) -> 'CompositeBloq': raise DecomposeTypeError(f"{self} is atomic.") diff --git a/qualtran/bloqs/mcmt/and_bloq_test.py b/qualtran/bloqs/mcmt/and_bloq_test.py index 70d853579f..56603141bf 100644 --- a/qualtran/bloqs/mcmt/and_bloq_test.py +++ b/qualtran/bloqs/mcmt/and_bloq_test.py @@ -22,7 +22,7 @@ from attrs import frozen import qualtran.testing as qlt_testing -from qualtran import Bloq, BloqBuilder, Signature, Soquet, SoquetT +from qualtran import Bloq, BloqBuilder, Signature, SoquetT from qualtran.bloqs.basic_gates import OneEffect, OneState, ZeroEffect, ZeroState from qualtran.bloqs.mcmt.and_bloq import _and_bloq, _multi_and, _multi_and_symb, And, MultiAnd from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity @@ -111,8 +111,6 @@ def test_inverse(): bb = BloqBuilder() q0 = bb.add_register('q0', 1) q1 = bb.add_register('q1', 1) - assert isinstance(q0, Soquet) - assert isinstance(q1, Soquet) qs, trg = bb.add(And(), ctrl=[q0, q1]) qs = bb.add(And(uncompute=True), ctrl=qs, target=trg) cbloq = bb.finalize(q0=qs[0], q1=qs[1]) @@ -197,8 +195,8 @@ def signature(self) -> 'Signature': def build_composite_bloq( self, bb: 'BloqBuilder', q0: 'SoquetT', q1: 'SoquetT' ) -> Dict[str, 'SoquetT']: - assert isinstance(q0, Soquet) - assert isinstance(q1, Soquet) + assert BloqBuilder.is_single(q0) + assert BloqBuilder.is_single(q1) qs, trg = bb.add(And(), ctrl=[q0, q1]) q0, q1 = bb.add(And(uncompute=True), ctrl=qs, target=trg) return {'q0': q0, 'q1': q1} @@ -232,9 +230,6 @@ def test_multiand_adjoint(): q0 = bb.add_register('q0', 1) q1 = bb.add_register('q1', 1) q2 = bb.add_register('q2', 1) - assert isinstance(q0, Soquet) - assert isinstance(q1, Soquet) - assert isinstance(q2, Soquet) qs, junk, trg = bb.add(MultiAnd((1, 1, 1)), ctrl=[q0, q1, q2]) qs = bb.add(MultiAnd((1, 1, 1)).adjoint(), ctrl=qs, target=trg, junk=junk) diff --git a/qualtran/cirq_interop/_bloq_to_cirq.py b/qualtran/cirq_interop/_bloq_to_cirq.py index 79eca64cf6..949abc48b9 100644 --- a/qualtran/cirq_interop/_bloq_to_cirq.py +++ b/qualtran/cirq_interop/_bloq_to_cirq.py @@ -31,7 +31,6 @@ RightDangle, Side, Signature, - Soquet, ) from qualtran._infra.binst_graph_iterators import greedy_topological_sort from qualtran._infra.composite_bloq import _binst_to_cxns @@ -40,6 +39,7 @@ merge_qubits, split_qubits, ) +from qualtran._infra.quantum_graph import _Soquet from qualtran.cirq_interop._cirq_to_bloq import _QReg, CirqQuregInT, CirqQuregT from qualtran.cirq_interop._interop_qubit_manager import InteropQubitManager from qualtran.drawing import Circle, LarrowTextBox, ModPlus, RarrowTextBox, TextBox, WireSymbol @@ -186,7 +186,7 @@ def __repr__(self) -> str: return f'BloqAsCirqGate({self.bloq})' -def _track_soq_name_changes(cxns: Iterable[Connection], qvar_to_qreg: Dict[Soquet, _QReg]): +def _track_soq_name_changes(cxns: Iterable[Connection], qvar_to_qreg: Dict[_Soquet, _QReg]): """Track inter-Bloq name changes across the two ends of a connection.""" for cxn in cxns: qvar_to_qreg[cxn.right] = qvar_to_qreg[cxn.left] @@ -197,7 +197,7 @@ def _bloq_to_cirq_op( bloq: Bloq, pred_cxns: Iterable[Connection], succ_cxns: Iterable[Connection], - qvar_to_qreg: Dict[Soquet, _QReg], + qvar_to_qreg: Dict[_Soquet, _QReg], qubit_manager: cirq.QubitManager, ) -> Optional[cirq.Operation]: _track_soq_name_changes(pred_cxns, qvar_to_qreg) @@ -249,8 +249,8 @@ def _cbloq_to_cirq_circuit( k: np.apply_along_axis(_QReg, -1, *(v, signature.get_left(k).dtype)) # type: ignore for k, v in cirq_quregs.items() } - qvar_to_qreg: Dict[Soquet, _QReg] = { - Soquet(LeftDangle, idx=idx, reg=reg): np.asarray(cirq_quregs[reg.name])[idx] + qvar_to_qreg: Dict[_Soquet, _QReg] = { + _Soquet(LeftDangle, idx=idx, reg=reg): np.asarray(cirq_quregs[reg.name]).item(idx) for reg in signature.lefts() for idx in reg.all_idxs() } @@ -271,7 +271,7 @@ def _cbloq_to_cirq_circuit( def _f_quregs(reg: Register) -> CirqQuregT: ret = np.empty(reg.shape + (reg.bitsize,), dtype=object) for idx in reg.all_idxs(): - soq = Soquet(RightDangle, idx=idx, reg=reg) + soq = _Soquet(RightDangle, idx=idx, reg=reg) ret[idx] = qvar_to_qreg[soq].qubits return ret diff --git a/qualtran/cirq_interop/_cirq_to_bloq.py b/qualtran/cirq_interop/_cirq_to_bloq.py index 7dea7daeb1..20c7872e57 100644 --- a/qualtran/cirq_interop/_cirq_to_bloq.py +++ b/qualtran/cirq_interop/_cirq_to_bloq.py @@ -38,6 +38,7 @@ QAny, QBit, QDType, + QVar, Register, Side, Signature, @@ -277,9 +278,7 @@ def __hash__(self): return hash(self.qubits) -def _ensure_in_reg_exists( - bb: BloqBuilder, in_reg: _QReg, qreg_to_qvar: Dict[_QReg, Soquet] -) -> None: +def _ensure_in_reg_exists(bb: BloqBuilder, in_reg: _QReg, qreg_to_qvar: Dict[_QReg, QVar]) -> None: """Takes care of qubit allocations, split and joins to ensure `qreg_to_qvar[in_reg]` exists.""" from qualtran.bloqs.bookkeeping import Cast @@ -298,7 +297,7 @@ def _ensure_in_reg_exists( # a. Split all registers containing at-least one qubit corresponding to `in_reg`. in_reg_qubits = set(in_reg.qubits) - new_qreg_to_qvar: Dict[_QReg, Soquet] = {} + new_qreg_to_qvar: Dict[_QReg, QVar] = {} for qreg, soq in qreg_to_qvar.items(): if len(qreg.qubits) > 1 and any(q in qreg.qubits for q in in_reg_qubits): new_qreg_to_qvar |= { @@ -309,7 +308,7 @@ def _ensure_in_reg_exists( qreg_to_qvar.clear() # b. Join all 1-bit registers, corresponding to individual qubits, that make up `in_reg`. - soqs_to_join: Dict[cirq.Qid, Soquet] = {} + soqs_to_join: Dict[cirq.Qid, QVar] = {} for qreg, soq in new_qreg_to_qvar.items(): if len(in_reg_qubits) > 1 and qreg.qubits and qreg.qubits[0] in in_reg_qubits: assert len(qreg.qubits) == 1, "Individual qubits should have been split by now." @@ -337,13 +336,11 @@ def _ensure_in_reg_exists( def _gather_input_soqs( - bb: BloqBuilder, - op_quregs: Dict[str, NDArray[_QReg]], # type: ignore[type-var] - qreg_to_qvar: Dict[_QReg, Soquet], + bb: BloqBuilder, op_quregs: Dict[str, NDArray[_QReg]], qreg_to_qvar: Dict[_QReg, QVar] # type: ignore[type-var] ) -> Dict[str, NDArray[Soquet]]: # type: ignore[type-var] qvars_in: Dict[str, NDArray[Soquet]] = {} # type: ignore[type-var] for reg_name, quregs in op_quregs.items(): - flat_soqs: List[Soquet] = [] + flat_soqs: List[QVar] = [] for qureg in quregs.flatten(): _ensure_in_reg_exists(bb, qureg, qreg_to_qvar) flat_soqs.append(qreg_to_qvar[qureg]) @@ -521,7 +518,7 @@ def cirq_optree_to_cbloq( bb, initial_soqs = BloqBuilder.from_signature(signature, add_registers_allowed=False) # 1. Compute qreg_to_qvar for input qubits in the LEFT signature. - qreg_to_qvar: Dict[_QReg, Soquet] = {} + qreg_to_qvar: Dict[_QReg, QVar] = {} for reg in signature.lefts(): if reg.name not in in_quregs: raise ValueError(f"Register {reg.name} from signature must be present in in_quregs.") @@ -568,17 +565,17 @@ def cirq_optree_to_cbloq( for q in quregs.flatten(): _ = qreg_to_qvar.pop(q) else: - assert quregs.shape == np.array(qvars_out[reg.name]).shape - qreg_to_qvar |= zip(quregs.flatten(), np.array(qvars_out[reg.name]).flatten()) + assert quregs.shape == np.asarray(qvars_out[reg.name]).shape + qreg_to_qvar |= zip(quregs.flatten(), np.asarray(qvars_out[reg.name]).flatten()) # 4. Combine Soquets to match the right signature. final_soqs_dict = _gather_input_soqs( bb, {reg.name: out_quregs[reg.name] for reg in signature.rights()}, qreg_to_qvar ) - final_soqs_set = set(soq for soqs in final_soqs_dict.values() for soq in soqs.flatten()) + final_soqs_set = set(soq.soquet for soqs in final_soqs_dict.values() for soq in soqs.flatten()) # 5. Free all dangling Soquets which are not part of the final soquets set. for qvar in qreg_to_qvar.values(): - if qvar not in final_soqs_set: + if qvar.soquet not in final_soqs_set: # type: ignore[attr-defined] bb.free(qvar) return bb.finalize(**final_soqs_dict) diff --git a/qualtran/cirq_interop/_cirq_to_bloq_test.py b/qualtran/cirq_interop/_cirq_to_bloq_test.py index d2c49ee57c..8dab0a424e 100644 --- a/qualtran/cirq_interop/_cirq_to_bloq_test.py +++ b/qualtran/cirq_interop/_cirq_to_bloq_test.py @@ -33,7 +33,6 @@ Register, Side, Signature, - Soquet, SoquetT, ) from qualtran._infra.gate_with_registers import get_named_qubits @@ -53,8 +52,8 @@ def signature(self) -> Signature: def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str, 'SoquetT']: ctrl, target = soqs['control'], soqs['target'] - assert isinstance(ctrl, Soquet) - assert isinstance(target, Soquet) + assert BloqBuilder.is_single(ctrl) + assert BloqBuilder.is_single(target) ctrl, target = bb.add(CirqGateAsBloq(cirq.CNOT), q=[ctrl, target]) return {'control': ctrl, 'target': target} diff --git a/qualtran/drawing/graphviz.py b/qualtran/drawing/graphviz.py index 8f03465a0b..7504e7d698 100644 --- a/qualtran/drawing/graphviz.py +++ b/qualtran/drawing/graphviz.py @@ -33,12 +33,12 @@ RightDangle, Side, Signature, - Soquet, ) +from qualtran._infra.quantum_graph import _Soquet def _assign_ids_to_bloqs_and_soqs( - bloq_instances: Iterable[BloqInstance], all_soquets: Iterable[Soquet] + bloq_instances: Iterable[BloqInstance], all_soquets: Iterable[_Soquet] ) -> Dict[Any, str]: """Assign unique identifiers to bloq instances, soquets, and register groups. @@ -86,7 +86,7 @@ def add(item: Any, desired_id: str): def _parition_registers_in_a_group( regs: Iterable[Register], binst: BloqInstance -) -> Tuple[List[Soquet], List[Soquet], List[Soquet]]: +) -> Tuple[List[_Soquet], List[_Soquet], List[_Soquet]]: """Construct and sort the expected Soquets for a given register group. Since we expect the input registers to be in a group, we assert that @@ -99,7 +99,7 @@ def _parition_registers_in_a_group( thrus = [] for reg in regs: for idx in reg.all_idxs(): - soq = Soquet(binst, reg, idx) + soq = _Soquet(binst, reg, idx) if reg.side is Side.LEFT: lefts.append(soq) elif reg.side is Side.RIGHT: @@ -149,7 +149,7 @@ def __init__(self, bloq: Bloq): self.ids = _assign_ids_to_bloqs_and_soqs(self._binsts, self._soquets) - def get_dangle_node(self, soq: Soquet) -> pydot.Node: + def get_dangle_node(self, soq: _Soquet) -> pydot.Node: """Overridable method to create a Node representing dangling Soquets.""" return pydot.Node(self.ids[soq], label=soq.pretty(), shape='plaintext') @@ -170,15 +170,15 @@ def add_dangles( subg = pydot.Subgraph(rank='same') for reg in regs: for idx in reg.all_idxs(): - subg.add_node(self.get_dangle_node(Soquet(dangle, reg, idx=idx))) + subg.add_node(self.get_dangle_node(_Soquet(dangle, reg, idx=idx))) graph.add_subgraph(subg) return graph - def soq_label(self, soq: Soquet) -> str: + def soq_label(self, soq: _Soquet) -> str: """Overridable method for getting label text for a Soquet.""" return soq.pretty() - def get_thru_register(self, thru: Soquet) -> str: + def get_thru_register(self, thru: _Soquet) -> str: """Overridable method for generating a