Skip to content

Commit a3ba836

Browse files
Liu KeyuLiu Keyu
authored andcommitted
Fix: resolve pre-commit issues and add missing annotations
Fix: resolve pre-commit issues and add missing annotations Fix: resolve pre-commit issues and add missing annotations Remove example_test.py Remove example_test.py
1 parent 78dc1aa commit a3ba836

File tree

4 files changed

+79
-31
lines changed

4 files changed

+79
-31
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ dependencies = [
4646
"numpy>=1.22,<2; sys_platform == 'darwin' and 'x86_64' in platform_machine and python_version < '3.13'", # Restrict numpy v2 for macOS x86 since it is not supported anymore since torch v2.3.0
4747
"torch>=2.2.2,<2.3.0; sys_platform == 'darwin' and 'x86_64' in platform_machine and python_version < '3.13'", # Restrict torch v2.3.0 for macOS x86 since it is not supported anymore.
4848
"typing-extensions>=4.1", # for `assert_never`
49+
"qiskit-ibm-transpiler>=0.2.0",
4950
]
5051

5152
classifiers = [

src/mqt/predictor/rl/example_test.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

src/mqt/predictor/rl/helper.py

Lines changed: 73 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,18 @@
1212

1313
import logging
1414
from pathlib import Path
15-
from typing import TYPE_CHECKING
15+
from typing import TYPE_CHECKING, Optional, Callable, List, Tuple, Dict, Any
1616

1717
import numpy as np
1818
from qiskit import QuantumCircuit
1919
from qiskit.converters import circuit_to_dag, dag_to_circuit
20-
from qiskit.circuit import ClassicalRegister, QuantumRegister
21-
from qiskit.transpiler import PassManager
20+
from qiskit.circuit import ClassicalRegister, QuantumRegister, Instruction
21+
from qiskit.transpiler import PassManager, Target
22+
from qiskit.dagcircuit import DAGCircuit
2223
from qiskit_ibm_transpiler.ai.routing import AIRouting
2324

2425
from mqt.predictor.utils import calc_supermarq_features
26+
from mqt.predictor.rl.actions import Action
2527

2628
if TYPE_CHECKING:
2729
from numpy.random import Generator
@@ -32,7 +34,16 @@
3234

3335
logger = logging.getLogger("mqt-predictor")
3436

35-
def extract_cregs_and_measurements(qc):
37+
def extract_cregs_and_measurements(qc: QuantumCircuit) -> Tuple[List[ClassicalRegister], List[tuple[Instruction, List, List]]]:
38+
"""
39+
Extracts classical registers and measurement operations from a quantum circuit.
40+
41+
Args:
42+
qc: The input QuantumCircuit.
43+
44+
Returns:
45+
A tuple containing a list of classical registers and a list of measurement operations.
46+
"""
3647
cregs = [ClassicalRegister(cr.size, name=cr.name) for cr in qc.cregs]
3748
measurements = [
3849
(item.operation, item.qubits, item.clbits)
@@ -41,7 +52,16 @@ def extract_cregs_and_measurements(qc):
4152
]
4253
return cregs, measurements
4354

44-
def remove_cregs(qc):
55+
def remove_cregs(qc: QuantumCircuit) -> QuantumCircuit:
56+
"""
57+
Removes classical registers and measurement operations from the circuit.
58+
59+
Args:
60+
qc: The input QuantumCircuit.
61+
62+
Returns:
63+
A new QuantumCircuit with only quantum operations (no cregs or measurements).
64+
"""
4565
qregs = [QuantumRegister(qr.size, name=qr.name) for qr in qc.qregs]
4666
new_qc = QuantumCircuit(*qregs)
4767
old_to_new = {}
@@ -55,7 +75,24 @@ def remove_cregs(qc):
5575
new_qc.append(instr, qargs)
5676
return new_qc
5777

58-
def add_cregs_and_measurements(qc, cregs, measurements, qubit_map=None):
78+
def add_cregs_and_measurements(
79+
qc: QuantumCircuit,
80+
cregs: List[ClassicalRegister],
81+
measurements: List[Tuple[Instruction, List, List]],
82+
qubit_map: Optional[Dict] = None,
83+
) -> QuantumCircuit:
84+
"""
85+
Adds classical registers and measurement operations back to the quantum circuit.
86+
87+
Args:
88+
qc: The quantum circuit to which cregs and measurements are added.
89+
cregs: List of ClassicalRegister to add.
90+
measurements: List of measurement instructions as tuples (Instruction, qubits, clbits).
91+
qubit_map: Optional dictionary mapping original qubits to new qubits.
92+
93+
Returns:
94+
The modified QuantumCircuit with cregs and measurements added.
95+
"""
5996
for cr in cregs:
6097
qc.add_register(cr)
6198
for instr, qargs, cargs in measurements:
@@ -68,10 +105,15 @@ def add_cregs_and_measurements(qc, cregs, measurements, qubit_map=None):
68105

69106
class SafeAIRouting(AIRouting):
70107
"""
71-
Remove cregs before AIRouting and add them back afterwards
72-
Necessary because there are cases AIRouting can't handle
108+
Custom AIRouting wrapper that removes classical registers before routing.
109+
110+
This prevents failures in AIRouting when classical bits are present by
111+
temporarily removing classical registers and measurements and restoring
112+
them after routing is completed.
73113
"""
74-
def run(self, dag):
114+
def run(self, dag: DAGCircuit) -> DAGCircuit:
115+
"""Run the routing pass on a DAGCircuit."""
116+
75117
# 1. Convert input dag to circuit
76118
qc_orig = dag_to_circuit(dag)
77119

@@ -101,26 +143,36 @@ def run(self, dag):
101143
else:
102144
try:
103145
idx = qc_routed.qubits.index(phys)
104-
except ValueError:
105-
raise RuntimeError(f"Physical qubit {phys} not found in output circuit!")
146+
except ValueError as err:
147+
raise RuntimeError(f"Physical qubit {phys} not found in output circuit!") from err
106148
qubit_map[virt] = qc_routed.qubits[idx]
107149
# 7. Restore classical registers and measurement instructions
108150
qc_final = add_cregs_and_measurements(qc_routed, cregs, measurements, qubit_map)
109151
# 8. Return as dag
110152
return circuit_to_dag(qc_final)
111153

112154
def best_of_n_passmanager(
113-
action, device, qc, max_iteration=(20,20),
114-
metric_fn=None,
115-
):
155+
action: Action,
156+
device: Target,
157+
qc: QuantumCircuit,
158+
max_iteration: Tuple[int, int] = (20, 20),
159+
metric_fn: Optional[Callable[[QuantumCircuit], float]] = None,
160+
)-> tuple[QuantumCircuit, Dict[str, Any]]:
116161
"""
117162
Runs the given transpile_pass multiple times and keeps the best result.
118-
action: the action dict with a 'transpile_pass' key (lambda/device->[passes])
119-
device: the backend or device
120-
qc: input circuit
121-
max_iteration: number of times to try
122-
metric_fn: function(circ) -> float for scoring
123-
require_layout: skip outputs with missing layouts
163+
164+
Args:
165+
action: The action dictionary with a 'transpile_pass' key
166+
(lambda device -> [passes]).
167+
device: The target backend or device.
168+
qc: The input quantum circuit.
169+
max_iteration: A tuple (layout_trials, routing_trials) specifying
170+
how many times to try.
171+
metric_fn: Optional function to score circuits; defaults to circuit depth.
172+
173+
Returns:
174+
A tuple containing the best transpiled circuit and its corresponding
175+
property set.
124176
"""
125177
best_val = None
126178
best_result = None
@@ -249,7 +301,7 @@ def get_openqasm_gates() -> list[str]:
249301
"rccx",
250302
]
251303

252-
def create_feature_dict(qc: QuantumCircuit, basis_gates: list[str], coupling_map) -> dict[str, int | NDArray[np.float64]]:
304+
def create_feature_dict(qc: QuantumCircuit) -> dict[str, int | NDArray[np.float64]]:
253305
"""Creates a feature dictionary for a given quantum circuit.
254306
255307
Arguments:

src/mqt/predictor/rl/predictorenv.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -316,17 +316,18 @@ def apply_action(self, action_index: int) -> QuantumCircuit | None:
316316
raise ValueError(msg)
317317

318318
def _apply_qiskit_action(self, action: Action, action_index: int) -> QuantumCircuit:
319-
if action.get("stochastic", False):
320-
metric_fn = lambda circ: circ.count_ops().get("swap", 0)
319+
if getattr(action, "stochastic", False):
320+
def metric_fn(circ: QuantumCircuit) -> float:
321+
return float(circ.count_ops().get("swap", 0))
321322
# for stochastic actions, pass the layout/routing trials parameter
322323
max_iteration = self.max_iter
323-
if "Sabre" in action["name"] and "AIRouting" not in action["name"]:
324+
if "Sabre" in action.name and "AIRouting" not in action.name:
324325
# Internal trials for Sabre
325326
transpile_pass = action.transpile_pass(self.device, max_iteration)
326327
pm = PassManager(transpile_pass)
327328
altered_qc = pm.run(self.state)
328329
pm_property_set = dict(pm.property_set)
329-
elif "AIRouting" in action["name"]:
330+
elif "AIRouting" in action.name:
330331
# Run AIRouting in custom loop
331332
altered_qc, pm_property_set = best_of_n_passmanager(
332333
action,

0 commit comments

Comments
 (0)