Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functionalty for determining whether pairs of moments commute #6679

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 90 additions & 16 deletions cirq-core/cirq/circuits/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,17 @@
from typing_extensions import Self

import numpy as np
from scipy.cluster.hierarchy import DisjointSet

from cirq import protocols, ops, qis, _compat
from cirq._import import LazyLoader
from cirq.ops import raw_types, op_tree
from cirq.protocols import circuit_diagram_info_protocol
from cirq.protocols import (
circuit_diagram_info_protocol,
apply_unitary,
ApplyUnitaryArgs,
definitely_commutes,
)
from cirq.type_workarounds import NotImplementedType

if TYPE_CHECKING:
Expand Down Expand Up @@ -657,10 +663,11 @@ def cleanup_key(key: Any) -> Any:
return diagram.render()

def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplementedType]:
"""Determines whether Moment commutes with the Operation.
"""Determines whether Moment commutes with either another Moment or
an Operation.

Args:
other: An Operation object. Other types are not implemented yet.
other: An Operation or Moment object. Other types are not implemented yet.
In case a different type is specified, NotImplemented is
returned.
atol: Absolute error tolerance. If all entries in v1@v2 - v2@v1
Expand All @@ -669,25 +676,92 @@ def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplem

Returns:
True: The Moment and Operation commute OR they don't have shared
quibits.
qubits.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you fix the formatting?

False: The two values do not commute.
NotImplemented: In case we don't know how to check this, e.g.
the parameter type is not supported yet.
"""
if not isinstance(other, ops.Operation):
return NotImplemented

other_qubits = set(other.qubits)
for op in self.operations:
if not other_qubits.intersection(set(op.qubits)):
continue

commutes = protocols.commutes(op, other, atol=atol, default=NotImplemented)
if isinstance(other, ops.Operation):
# If an Operation is provided, convert this to a Moment consisting only
# of the given Operation
return self._commutes_(Moment(other), atol=atol)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not preserve the old code? ... perhaps in a def _commutes_with_op


if isinstance(other, Moment):
# Check if sets of qubits overlap. If not, then no need to go any further.
if not set(self.qubits) & set(other.qubits):
return True

# Check pairwise commuting between all pairs of
# operations. If they all commute then no
# need to go any further
if all(
definitely_commutes(op_1, op_2, atol=atol)
for op_1, op_2 in itertools.product(self.operations, other.operations)
):
return True

# Decompose into disjoint overlapping sets of qubits
qubit_subsets = [list(op.qubits) for op in self.operations + other.operations]
disjoint_set = DisjointSet(itertools.chain.from_iterable(qubit_subsets))
for subset in qubit_subsets:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens if a qubit appears in both moments in operations with different qubits?
so in m1: op1 acting on (q1, q0) in m2 op2 acting on q3, q4, q0)

if len(subset) < 2:
continue
for k in range(len(subset) - 1):
disjoint_set.merge(subset[k], subset[k + 1])
disjoint_qubit_subsets = disjoint_set.subsets()

# Decompose both moments onto each disjoint set of qubits and
# check for commutation using the unitary representation
if all(
definitely_commutes(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would this code raise an error if the operations are not unitary?

self[disjoint_set]._unitary_on_qubits(list(disjoint_set)),
other[disjoint_set]._unitary_on_qubits(list(disjoint_set)),
atol=atol,
)
for disjoint_set in disjoint_qubit_subsets
):
return True

return False

return NotImplemented

def _unitary_on_qubits(self, target_qubits: list['cirq.Qid']) -> np.ndarray:
"""Returns the unitary representation of the given moment when acting
on the target qubits.

.. note::

The :code:`target_qubits` must contain all the qubits that the
moment acts on.

if not commutes or commutes is NotImplemented:
return commutes
Args:
moment: The moment to decompose.
target_basis: The target qubits.

return True
Returns:
np.ndarray: The unitary representation of the Moment on the
target qubits.
"""
# Check moment has support on subset of target qubits and that there
# are no duplicates
current_qubits = self.qubits
assert all(qubit in target_qubits for qubit in current_qubits)
assert len(set(target_qubits)) == len(target_qubits)
# Define dims
total_qubits = len(target_qubits)
dim = 2**total_qubits
# Get the indices of the target qubit that the moment has support on
qubit_indices = [target_qubits.index(qubit) for qubit in current_qubits]

# Get the tensor operation corresponding to the moment acting on the
# target qubits.
id_tensor = qis.eye_tensor((2,) * total_qubits, dtype=np.complex128)
unitary = apply_unitary(
self, args=ApplyUnitaryArgs(id_tensor, np.empty_like(id_tensor), qubit_indices)
)
# Reshape into a square unitary matrix
return unitary.reshape(dim, dim)


class _SortByValFallbackToType:
Expand Down
12 changes: 12 additions & 0 deletions cirq-core/cirq/circuits/moment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,18 @@ def test_commutes():
assert not cirq.commutes(moment, cirq.X(c))


def test_commutes_multiqubit_gates():
a = cirq.NamedQubit('a')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test case is very simple and doesn't cover the more complex cases ... see above

b = cirq.NamedQubit('b')
c = cirq.NamedQubit("c")

moment = cirq.Moment([cirq.Z(a), cirq.Z(b)])
assert cirq.commutes(moment, cirq.XXPowGate(exponent=1 / 2)(a, b))

moment = cirq.Moment([cirq.XXPowGate(exponent=1 / 2)(a, b), cirq.Z(c)])
assert not cirq.commutes(moment, cirq.Z(b))


def test_transform_qubits():
a, b = cirq.LineQubit.range(2)
x, y = cirq.GridQubit.rect(2, 1, 10, 20)
Expand Down