Skip to content

Commit

Permalink
Non e' molto, ma e' un lavoro onesto
Browse files Browse the repository at this point in the history
  • Loading branch information
PietropaoloFrisoni committed Feb 12, 2025
1 parent c1cd09c commit a06cfd2
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 42 deletions.
89 changes: 47 additions & 42 deletions pennylane/transforms/optimization/single_qubit_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,33 @@ def interpret_operation(self, op: Operator):
f"cumulative_angles: {cumulative_angles}, obtained from single_qubit_rot_angles called on current op: {op}"
)
except (NotImplementedError, AttributeError):
# If the operation does not have the single_qubit_rot_angles method, we store it
# because we know we can't fuse it with any other operation.
# We cannot interpret the operation right away, because we need the use this to separate
print(f"single_qubit_rot_angles not available for current_gate: {op}")
print("registro op in previous_ops (rischiando di sovrascrivere) e returno")
for w in op.wires:
self.previous_ops[w] = op
return []

previous_ops_on_wires = set(self.previous_ops.get(w) for w in op.wires)

print(
f"removing previous_ops_on_wires={previous_ops_on_wires} from previous_ops={self.previous_ops}"
)

for o in previous_ops_on_wires:
if o is not None:
for w in o.wires:
self.previous_ops.pop(w)

print(
f"interpreting previous_ops_on_wires (printed above) and op={op} with super().interpret_operation, then returning"
)
res = []

Check notice on line 90 in pennylane/transforms/optimization/single_qubit_fusion.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/transforms/optimization/single_qubit_fusion.py#L90

Access to a protected member _primitive of a client class (protected-access)
for o in previous_ops_on_wires:
res.append(super().interpret_operation(o))

res.append(super().interpret_operation(op))
return res

# previous operation on the same wire
prev_op = self.previous_ops.get(op.wires[0], None)
print(f"prev_op: {prev_op} retrieved from previous_ops")
print(f"prev_op: {prev_op} retrieved from previous_ops={self.previous_ops}")

if prev_op is None:
# We cannot interpret the operation right away, because for example the first operation in the circuit
# has no previous ops stored but it might be able to be fused with the next operation.
Expand All @@ -90,55 +106,41 @@ def interpret_operation(self, op: Operator):
self.previous_ops[w] = qml.Rot._primitive.impl(
*cumulative_angles, wires=op.wires
)
# Rot(*cumulative_angles, wires=op.wires)
print(f"previous_ops: {self.previous_ops}, I just stored Rot")

while prev_op is not None:

print(f"starting while loop with prev_op: {prev_op}")

try:
prev_op_angles = qml.math.stack(prev_op.single_qubit_rot_angles())
print(
f"prev_op angles: {prev_op_angles}, obtained from single_qubit_rot_angles called on prev_op: {prev_op}"
)
except (NotImplementedError, AttributeError):
print(f"single_qubit_rot_angles not available for prev_op: {prev_op}. Break.")
break
print(
f"Stored the current op (transformed in Rot) in previous_ops: {self.previous_ops}. Returning now"
)
return []

cumulative_angles = fuse_rot_angles(cumulative_angles, prev_op_angles)
prev_op_angles = qml.math.stack(prev_op.single_qubit_rot_angles())

prev_op = self.previous_ops.get(op.wires[0], None)
print(f"prev_op: {prev_op}")
# We need to be careful about the order of the operations, as rotations do not commute in general.
cumulative_angles = fuse_rot_angles(prev_op_angles, cumulative_angles)

print(f"SEMO giunti alla fine. interpreto tutti i previous ops sullo stesso wire")
print(f"cumulative_angles after fuse_rot_angles: {cumulative_angles}")

# Putting the operations in a set to avoid applying the same op multiple times
# Using a set causes order to no longer be guaranteed, so the new order of the
# operations might differ from the original order. However, this only impacts
# operators without any shared wires, so correctness will not be impacted.
previous_ops_on_wires = set(self.previous_ops.get(w) for w in op.wires)
for o in previous_ops_on_wires:
if o is not None:
for w in o.wires:
self.previous_ops.pop(w)
for w in op.wires:
self.previous_ops[w] = qml.Rot._primitive.impl(*cumulative_angles, wires=op.wires)

Check notice on line 122 in pennylane/transforms/optimization/single_qubit_fusion.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/transforms/optimization/single_qubit_fusion.py#L122

Access to a protected member _primitive of a client class (protected-access)
print(
f"Stored the current op (transformed in Rot) in previous_ops: {self.previous_ops}. Returning now"
)

res = []
for o in previous_ops_on_wires:
res.append(super().interpret_operation(o))
return res
return []

def interpret_all_previous_ops(self) -> None:

print(f"\ninterpret_all_previous_ops called")

Check notice on line 131 in pennylane/transforms/optimization/single_qubit_fusion.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/transforms/optimization/single_qubit_fusion.py#L131

Missing function or method docstring (missing-function-docstring)

ops_remaining = set(self.previous_ops.values())

Check notice on line 133 in pennylane/transforms/optimization/single_qubit_fusion.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/transforms/optimization/single_qubit_fusion.py#L133

Using an f-string that does not have any interpolated variables (f-string-without-interpolation)
print(f"ops_remaining: {ops_remaining}")
print(f"ops_remaining: {ops_remaining}, which will be interpreted now")
for op in ops_remaining:
super().interpret_operation(op)
print(f"interpreted op: {op} with super().interpret_operation")

all_wires = tuple(self.previous_ops.keys())

print(
f"removing operations on all_wires={all_wires} from previous_ops={self.previous_ops}"
)

for w in all_wires:
self.previous_ops.pop(w)

Expand Down Expand Up @@ -512,6 +514,9 @@ def qfunc(r1, r2):
)
except (NotImplementedError, AttributeError):
break

print(f"cumulative_angles: {cumulative_angles} before fuse_rot_angles")

cumulative_angles = fuse_rot_angles(cumulative_angles, next_gate_angles)
print(
f"cumulative_angles: {cumulative_angles}, obtained from fuse_rot_angles applied to cumulative_angles and next_gate_angles"
Expand Down
93 changes: 93 additions & 0 deletions tests/capture/transforms/test_capture_single_qubit_fusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the ``SingleQubitFusionInterpreter`` class."""

# pylint:disable=wrong-import-position,protected-access
import pytest

import pennylane as qml

jax = pytest.importorskip("jax")

from pennylane.capture.primitives import (
adjoint_transform_prim,
cond_prim,
ctrl_transform_prim,
for_loop_prim,
grad_prim,
jacobian_prim,
qnode_prim,
while_loop_prim,
)
from pennylane.transforms.optimization.single_qubit_fusion import (
SingleQubitFusionInterpreter,
single_qubit_plxpr_to_plxpr,
)

pytestmark = [pytest.mark.jax, pytest.mark.usefixtures("enable_disable_plxpr")]


def check_matrix_equivalence(matrix_expected, matrix_obtained, atol=1e-8):
"""Takes two matrices and checks if multiplying one by the conjugate
transpose of the other gives the identity."""

mat_product = qml.math.dot(qml.math.conj(qml.math.T(matrix_obtained)), matrix_expected)
mat_product = mat_product / mat_product[0, 0]

return qml.math.allclose(mat_product, qml.math.eye(matrix_expected.shape[0]), atol=atol)


def extract_abstract_operator_eqns(jaxpr):
"""Extracts all JAXPR equations that correspond to abstract operators."""
abstract_op_eqns = []

for eqn in jaxpr.eqns:

primitive = eqn.primitive

if getattr(primitive, "prim_type", "") == "operator":

abstract_op_eqns.append(eqn)

return abstract_op_eqns


class TestSingleQubitFusionInterpreter:
"""Unit tests for the SingleQubitFusionInterpreter class"""

def test_single_qubit_full_fusion(self):
"""Test that a sequence of single-qubit gates all fuse."""

def circuit():
qml.RZ(0.3, wires=0)
qml.Hadamard(wires=0)
qml.Rot(0.1, 0.2, 0.3, wires=0)
qml.RX(0.1, wires=0)
qml.SX(wires=0)
qml.T(wires=0)
qml.PauliX(wires=0)

transformed_circuit = SingleQubitFusionInterpreter()(circuit)

jaxpr = jax.make_jaxpr(transformed_circuit)()
cleaned_jaxpr = extract_abstract_operator_eqns(jaxpr)

expected_primitive = {qml.Rot._primitive}
actual_primitives = {cleaned_jaxpr[0].primitive}
assert expected_primitive == actual_primitives

with qml.capture.pause():
transformed_circuit_comparison = qml.transforms.single_qubit_fusion(circuit)

# TODO: find a way to compare the two transformed circuits

0 comments on commit a06cfd2

Please sign in to comment.