Skip to content

Commit 3c7592b

Browse files
Liu KeyuLiu Keyu
authored andcommitted
Fix: resolve pre-commit issues and add missing annotations
1 parent f71fb29 commit 3c7592b

File tree

3 files changed

+18
-7
lines changed

3 files changed

+18
-7
lines changed

src/mqt/predictor/rl/actions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
)
4747
from qiskit.passmanager import ConditionalController
4848
from qiskit.transpiler import (
49-
CouplingMap,
49+
CouplingMap, Target
5050
)
5151
from qiskit.transpiler.passes import (
5252
ApplyLayout,
@@ -126,6 +126,7 @@ class Action:
126126
transpile_pass: (
127127
list[qiskit_BasePass | tket_BasePass]
128128
| Callable[..., list[qiskit_BasePass | tket_BasePass]]
129+
| Callable[[Target, tuple[int, int]], list[qiskit_BasePass | tket_BasePass]]
129130
| Callable[
130131
...,
131132
Callable[..., tuple[Any, ...] | Circuit],

src/mqt/predictor/rl/helper.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,16 @@
2020
from qiskit.circuit import ClassicalRegister, QuantumRegister, Instruction, Qubit
2121
from qiskit.transpiler import PassManager, Target
2222
from qiskit.dagcircuit import DAGCircuit
23-
from qiskit_ibm_transpiler.ai.routing import AIRouting
2423

2524
from mqt.predictor.utils import calc_supermarq_features
2625
from mqt.predictor.rl.actions import Action
2726

2827
if TYPE_CHECKING:
2928
from numpy.random import Generator
3029
from numpy.typing import NDArray
30+
from qiskit_ibm_transpiler.ai.routing import AIRouting
31+
else:
32+
AIRouting = object # type: ignore[misc]
3133

3234
import zipfile
3335
from importlib import resources
@@ -103,7 +105,7 @@ def add_cregs_and_measurements(
103105
qc.append(instr, new_qargs, cargs)
104106
return qc
105107

106-
class SafeAIRouting(AIRouting):
108+
class SafeAIRouting(AIRouting): # type: ignore[misc]
107109
"""
108110
Custom AIRouting wrapper that removes classical registers before routing.
109111
@@ -137,9 +139,15 @@ def run(self, dag: DAGCircuit) -> DAGCircuit:
137139

138140
qubit_map = {}
139141
for virt in qc_orig.qubits:
140-
phys = final_layout[virt]
142+
try:
143+
phys = final_layout[virt] # This is now safe due to above check
144+
except KeyError as err:
145+
raise RuntimeError(f"Virtual qubit {virt} not found in final layout!") from err
141146
if isinstance(phys, int):
142-
qubit_map[virt] = qc_routed.qubits[phys]
147+
try:
148+
qubit_map[virt] = qc_routed.qubits[phys]
149+
except IndexError as err:
150+
raise RuntimeError(f"Physical index {phys} is out of range in routed circuit!") from err
143151
else:
144152
try:
145153
idx = qc_routed.qubits.index(phys)
@@ -183,6 +191,8 @@ def best_of_n_passmanager(
183191
else:
184192
all_passes = action.transpile_pass(device)
185193

194+
assert isinstance(all_passes, list)
195+
186196
layout_passes = all_passes[:-1]
187197
routing_pass = all_passes[-1:]
188198

src/mqt/predictor/rl/predictorenv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def metric_fn(circ: QuantumCircuit) -> float:
361361
return altered_qc
362362

363363
def _handle_qiskit_layout_postprocessing(
364-
self, action: Action, pm_property_set: dict, altered_qc: QuantumCircuit,
364+
self, action: Action, pm_property_set: dict[str, any], altered_qc: QuantumCircuit,
365365
) -> QuantumCircuit:
366366
if action.name == "VF2PostLayout":
367367
assert pm_property_set["VF2PostLayout_stop_reason"] is not None
@@ -379,7 +379,7 @@ def _handle_qiskit_layout_postprocessing(
379379
_output_qubit_list=altered_qc.qubits,
380380
_input_qubit_count=self.num_qubits_uncompiled_circuit,
381381
)
382-
if pm_property_set["final_layout"]:
382+
if self.layout is not None and pm_property_set["final_layout"]:
383383
self.layout.final_layout = pm_property_set["final_layout"]
384384
return altered_qc
385385

0 commit comments

Comments
 (0)