Commit 90f1547

🔥 Removal of Pre-Trained Model Support (#300)
This PR removes support for pre-trained models from the MQT Predictor framework. Up until mqt.predictor v2.0.0, pre-trained models were provided. However, this is no longer feasible due to the increasing number of devices and figures of merit. Instead, we now provide detailed documentation on how to train and set up the MQT Predictor framework.

Signed-off-by: Nils Quetschlich <[email protected]>
Co-authored-by: Lukas Burgholzer <[email protected]>
1 parent 003df9e commit 90f1547

File tree

12 files changed

+104
-284
lines changed

.github/workflows/pretrained_model.yml

Lines changed: 0 additions & 62 deletions
This file was deleted.

README.md

Lines changed: 8 additions & 2 deletions
@@ -60,7 +60,7 @@ from mqt.predictor import qcompile
6060
from mqt.bench import get_benchmark
6161

6262
# get a benchmark circuit on algorithmic level representing the GHZ state with 5 qubits from [MQT Bench](https://github.com/cda-tum/mqt-bench)
63-
qc_uncompiled = get_benchmark(benchmark_name="dj", level="alg", circuit_size=5)
63+
qc_uncompiled = get_benchmark(benchmark_name="ghz", level="alg", circuit_size=5)
6464

6565
# compile it using the MQT Predictor
6666
qc_compiled, compilation_information, quantum_device = qcompile(qc_uncompiled)
@@ -72,7 +72,13 @@ print(quantum_device, compilation_information)
7272
print(qc_compiled.draw())
7373
```
7474

75-
**Detailed documentation and examples are available at [ReadTheDocs](https://mqt.readthedocs.io/projects/predictor).**
75+
> [!NOTE]
76+
> To execute the code, the respective machine learning models must be trained beforehand.
77+
> Up until mqt.predictor v2.0.0, pre-trained models were provided. However, this is no longer feasible due to the
78+
> increasing number of devices and figures of merit. Instead, we now provide detailed documentation on how to train
79+
> and set up the MQT Predictor framework.
80+
81+
**Further documentation and examples are available at [ReadTheDocs](https://mqt.readthedocs.io/projects/predictor).**
7682

7783
## References
7884

docs/Compilation.rst

Lines changed: 1 addition & 1 deletion
@@ -52,4 +52,4 @@ We trained one RL model for each currently :ref:`supported quantum device <suppo
5252
Training Data
5353
-------------
5454
To train the model, sufficient training data must be provided as qasm files in the `respective directory <https://github.com/cda-tum/mqt-predictor/tree/main/src/mqt/predictor/rl/training_data/training_circuits>`_.
55-
We provide the training data used for the pre-trained models which are stored `here <https://github.com/cda-tum/mqt-predictor/tree/main/src/mqt/predictor/rl/training_data/trained_model>`_.
55+
We provide the training data used for the initial performance evaluation of this framework, which is stored `here <https://github.com/cda-tum/mqt-predictor/tree/main/src/mqt/predictor/rl/training_data/trained_model>`_.
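
Before starting a training run, it can help to confirm that the training directory actually contains circuits. The following is a minimal sketch using only the standard library. The helper name `list_training_circuits` and its directory argument are hypothetical and not part of mqt.predictor; the only assumption taken from the documentation above is that training circuits are stored as `.qasm` files.

```python
from __future__ import annotations

from pathlib import Path


def list_training_circuits(training_dir: Path) -> list[Path]:
    """Return all .qasm files directly inside training_dir, sorted by path.

    Hypothetical helper for illustration; not part of mqt.predictor.
    """
    if not training_dir.is_dir():
        msg = f"Training directory does not exist: {training_dir}"
        raise FileNotFoundError(msg)
    # Only regular files with the .qasm suffix count as training circuits.
    return sorted(
        p for p in training_dir.iterdir() if p.is_file() and p.suffix == ".qasm"
    )
```

An empty result would indicate that no training circuits are available yet, so training would have nothing to learn from.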

docs/DeviceSelection.rst

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ Training Data
6363
-------------
6464

6565
To train the model, sufficient training data must be provided as qasm files in the `respective directory <https://github.com/cda-tum/mqt-predictor/tree/main/src/mqt/predictor/ml/training_data/training_circuits>`_.
66-
We provide the training data used for the pre-trained model.
66+
We provide the training data used in the initial performance evaluation of this framework.
6767

6868
After the adjustment is finished, the following methods need to be called to generate the training data:
6969

docs/Usage.rst

Lines changed: 71 additions & 0 deletions
@@ -37,3 +37,74 @@ For that, the repository must be cloned and installed:
3737
pip install .
3838
3939
Afterwards, the package can be used as described :ref:`above <pip_usage>`.
40+
41+
MQT Predictor Framework Setup
42+
=============================
43+
To run ``qcompile``, the MQT Predictor framework must be set up. How this is properly done is described next.
44+
45+
First, the to-be-considered quantum devices must be included in the framework.
46+
Currently, all devices supported by `MQT Bench <https://github.com/cda-tum/mqt-bench>`_ are natively supported.
47+
If another device should be considered, it can be added in a format similar to the one used in MQT Bench. It does not
48+
need to be added to the MQT Bench repository, since it can be added directly to the MQT Predictor framework as follows:
49+
50+
- In `mqt/predictor/rl/predictorenv.py <https://github.com/cda-tum/mqt-predictor/tree/main/src/mqt/predictor/rl/predictorenv.py>`_, modify the line where ``mqt.bench.devices.get_device_by_name`` is used.
51+
- In `mqt/predictor/ml/predictor.py <https://github.com/cda-tum/mqt-predictor/tree/main/src/mqt/predictor/ml/predictor.py>`_, modify the lines where ``mqt.bench.devices.*`` is used.
52+
- Follow the same data format as defined in `mqt.bench.devices.device.py <https://github.com/cda-tum/mqt-bench/tree/main/src/mqt/bench/devices/device.py>`_.
53+
54+
Second, for each supported device, a respective reinforcement learning model must be trained. This is done by running
55+
the following command based on the training data in the form of quantum circuits provided as qasm files in
56+
`mqt/predictor/rl/training_data/training_circuits <https://github.com/cda-tum/mqt-predictor/tree/main/src/mqt/predictor/rl/training_data/training_circuits>`_:
57+
58+
.. code-block:: python
59+
60+
import mqt.predictor
61+
62+
rl_pred = mqt.predictor.rl.Predictor(
63+
figure_of_merit="expected_fidelity", device_name="ibm_washington"
64+
)
65+
rl_pred.train_model(timesteps=100000, model_name="sample_model_rl")
66+
67+
This will train a reinforcement learning model for the ``ibm_washington`` device with the expected fidelity as figure of merit.
68+
In addition to the expected fidelity, critical depth is provided as a second figure of merit.
69+
Further figures of merit can be added in `mqt.predictor.reward.py <https://github.com/cda-tum/mqt-predictor/tree/main/src/mqt/predictor/reward.py>`_.
70+
71+
Third, after the reinforcement learning models that are used for the respective compilations are trained, the
72+
supervised machine learning model to predict the device selection must be trained.
73+
This is done by first creating the necessary training data (based on the training data in the form of quantum circuits provided as qasm files in
74+
`mqt/predictor/ml/training_data/training_circuits <https://github.com/cda-tum/mqt-predictor/tree/main/src/mqt/predictor/ml/training_data/training_circuits>`_) and then running the following command:
75+
76+
.. code-block:: python
77+
78+
ml_pred = mqt.predictor.ml.Predictor()
79+
ml_pred.generate_compiled_circuits(timeout=600) # timeout in seconds
80+
training_data, name_list, scores_list = ml_pred.generate_trainingdata_from_qasm_files(
81+
figure_of_merit="expected_fidelity"
82+
)
83+
mqt.predictor.ml.helper.save_training_data(
84+
training_data, name_list, scores_list, figure_of_merit="expected_fidelity"
85+
)
86+
87+
This will compile all provided uncompiled training circuits for all available devices and figures of merit.
88+
Afterwards, the training data is generated individually for a figure of merit.
89+
This training data can then be saved and used to train the supervised machine learning model:
90+
91+
.. code-block:: python
92+
93+
ml_pred.train_random_forest_classifier(figure_of_merit="expected_fidelity")
94+
95+
Finally, the MQT Predictor framework is fully set up and can be used to predict the most
96+
suitable device for a given quantum circuit using supervised machine learning and compile
97+
the circuit for the predicted device using reinforcement learning by running:
98+
99+
.. code-block:: python
100+
101+
from mqt.predictor import qcompile
102+
from mqt.bench import get_benchmark
103+
104+
qc_uncompiled = get_benchmark(benchmark_name="ghz", level="alg", circuit_size=5)
105+
compiled_qc, compilation_information, device = qcompile(
106+
qc_uncompiled, figure_of_merit="expected_fidelity"
107+
)
108+
109+
110+
This returns the compiled quantum circuit for the predicted device together with additional information about the compilation procedure.
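
Since this commit makes the model loaders raise ``FileNotFoundError`` when no trained model is present, calling code may want to turn that error into a setup hint. The sketch below is one hedged way to do so. The wrapper name `compile_or_explain` is hypothetical, the compile function is passed in as a parameter so that the example runs without mqt.predictor installed, and it is an assumption that ``qcompile`` propagates the ``FileNotFoundError`` unchanged.

```python
from __future__ import annotations

from typing import Any, Callable


def compile_or_explain(
    compile_fn: Callable[..., Any],
    circuit: Any,
    figure_of_merit: str = "expected_fidelity",
) -> Any | None:
    """Call compile_fn and return its result, or None if models are missing.

    Hypothetical wrapper for illustration; compile_fn stands in for
    mqt.predictor.qcompile, which is assumed to let FileNotFoundError through.
    """
    try:
        return compile_fn(circuit, figure_of_merit=figure_of_merit)
    except FileNotFoundError as err:
        # Raised when the RL or ML model has not been trained yet.
        print(f"Setup incomplete: {err}")
        print("Train the RL models and the ML classifier first, then retry.")
        return None
```

With the real package, usage would presumably look like `compile_or_explain(qcompile, qc_uncompiled)`, followed by a check for `None` before unpacking the result.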

docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ The MQT Predictor framework is based on two main components:
2020
- A :doc:`Device-Specific Circuit Compilation <Compilation>` component that compiles a given quantum circuit for a given device.
2121

2222
Combining these two components, the framework can be used to automatically compile a given quantum circuit for the most suitable device optimizing a :doc:`customizable figure of merit<FigureOfMerit>`.
23+
How to use the framework is described in the :doc:`Usage <Usage>` section.
2324

2425
If you are interested in the theory behind MQT Predictor, have a look at the publications in the :doc:`references list <References>`.
2526

src/mqt/predictor/ml/predictor.py

Lines changed: 2 additions & 74 deletions
@@ -40,79 +40,6 @@ def set_classifier(self, clf: RandomForestClassifier) -> None:
4040
"""Sets the classifier to the given classifier."""
4141
self.clf = clf
4242

43-
def compile_all_circuits_circuitwise(
44-
self,
45-
figure_of_merit: reward.figure_of_merit,
46-
timeout: int,
47-
source_path: Path | None = None,
48-
target_path: Path | None = None,
49-
logger_level: int = logging.INFO,
50-
) -> None:
51-
"""Compiles all circuits in the given directory with the given timeout and saves them in the given directory.
52-
53-
Arguments:
54-
figure_of_merit: The figure of merit to be used for compilation.
55-
timeout: The timeout in seconds for the compilation of a single circuit.
56-
source_path: The path to the directory containing the circuits to be compiled. Defaults to None.
57-
target_path: The path to the directory where the compiled circuits should be saved. Defaults to None.
58-
logger_level: The level of the logger. Defaults to logging.INFO.
59-
60-
"""
61-
logger.setLevel(logger_level)
62-
63-
if source_path is None:
64-
source_path = ml.helper.get_path_training_circuits()
65-
66-
if target_path is None:
67-
target_path = ml.helper.get_path_training_circuits_compiled()
68-
69-
Parallel(n_jobs=-1, verbose=100)(
70-
delayed(self.generate_compiled_circuits_for_single_training_circuit)(
71-
filename, timeout, source_path, target_path, figure_of_merit
72-
)
73-
for filename in source_path.iterdir()
74-
)
75-
76-
def generate_compiled_circuits_for_single_training_circuit(
77-
self,
78-
filename: Path,
79-
timeout: int,
80-
source_path: Path,
81-
target_path: Path,
82-
figure_of_merit: reward.figure_of_merit,
83-
) -> None:
84-
"""Compiles a single circuit with the given timeout and saves it in the given directory.
85-
86-
Arguments:
87-
filename: The path to the circuit to be compiled.
88-
timeout: The timeout in seconds for the compilation of the circuit.
89-
source_path: The path to the directory containing the circuit to be compiled.
90-
target_path: The path to the directory where the compiled circuit should be saved.
91-
figure_of_merit: The figure of merit to be used for compilation.
92-
93-
"""
94-
try:
95-
qc = QuantumCircuit.from_qasm_file(Path(source_path) / filename)
96-
if filename.suffix != ".qasm":
97-
return
98-
99-
for i, dev in enumerate(self.devices):
100-
target_filename = str(filename).split("/")[-1].split(".qasm")[0] + "_" + figure_of_merit + "_" + str(i)
101-
if (Path(target_path) / (target_filename + ".qasm")).exists() or qc.num_qubits > dev.num_qubits:
102-
continue
103-
try:
104-
res = utils.timeout_watcher(rl.qcompile, [qc, figure_of_merit, dev.name], timeout)
105-
if isinstance(res, tuple):
106-
compiled_qc = res[0]
107-
with Path(target_path / (target_filename + ".qasm")).open("w", encoding="utf-8") as f:
108-
dump(compiled_qc, f)
109-
110-
except Exception as e:
111-
print(e, filename, "inner")
112-
113-
except Exception as e:
114-
print(e, filename, "outer")
115-
11643
def compile_all_circuits_devicewise(
11744
self,
11845
device_name: str,
@@ -570,7 +497,8 @@ def predict_probs(self, qc: Path | QuantumCircuit, figure_of_merit: reward.figur
570497
self.clf = load(path)
571498

572499
if self.clf is None:
573-
error_msg = "Classifier is neither trained nor saved."
500+
error_msg = "The ML model is not trained yet. Please train the model before using it."
501+
logger.error(error_msg)
574502
raise FileNotFoundError(error_msg)
575503

576504
feature_dict = ml.helper.create_feature_dict(qc) # type: ignore[unreachable]

src/mqt/predictor/rl/helper.py

Lines changed: 5 additions & 64 deletions
@@ -3,15 +3,12 @@
33
from __future__ import annotations
44

55
import logging
6-
import os
7-
import sys
86
from pathlib import Path
97
from typing import TYPE_CHECKING, Any
108

119
import numpy as np
1210
import requests
1311
from bqskit import MachineModel
14-
from packaging import version
1512
from pytket.architecture import Architecture
1613
from pytket.circuit import Circuit, Node, Qubit
1714
from pytket.passes import (
@@ -67,14 +64,9 @@
6764
from mqt.bench.devices import Device
6865

6966

70-
if TYPE_CHECKING or sys.version_info >= (3, 10, 0):
71-
from importlib import metadata, resources
72-
else:
73-
import importlib_metadata as metadata
74-
import importlib_resources as resources
75-
7667
import operator
7768
import zipfile
69+
from importlib import resources
7870

7971
from bqskit import compile as bqskit_compile
8072
from bqskit.ir import gates
@@ -429,63 +421,12 @@ def load_model(model_name: str) -> MaskablePPO:
429421
The loaded model.
430422
"""
431423
path = get_path_trained_model()
432-
433-
if Path(path / (model_name + ".zip")).exists():
424+
if Path(path / (model_name + ".zip")).is_file():
434425
return MaskablePPO.load(path / (model_name + ".zip"))
435-
logger.info("Model does not exist. Try to retrieve suitable Model from GitHub...")
436-
try:
437-
mqtpredictor_module_version = metadata.version("mqt.predictor")
438-
except ModuleNotFoundError:
439-
error_msg = (
440-
"Could not retrieve version of mqt.predictor. Please run 'pip install . or pip install mqt.predictor'."
441-
)
442-
raise RuntimeError(error_msg) from None
443-
444-
headers = None
445-
if "GITHUB_TOKEN" in os.environ:
446-
headers = {"Authorization": f"token {os.environ['GITHUB_TOKEN']}"}
447-
448-
version_found = False
449-
response = requests.get("https://api.github.com/repos/cda-tum/mqt-predictor/tags", headers=headers)
450-
451-
if not response:
452-
error_msg = "Querying the GitHub API failed. One reasons could be that the limit of 60 API calls per hour and IP address is exceeded."
453-
raise RuntimeError(error_msg)
454-
455-
available_versions = [elem["name"] for elem in response.json()]
456-
457-
for possible_version in available_versions:
458-
if version.parse(mqtpredictor_module_version) >= version.parse(possible_version):
459-
url = "https://api.github.com/repos/cda-tum/mqt-predictor/releases/tags/" + possible_version
460-
response = requests.get(url, headers=headers)
461-
if not response:
462-
error_msg = "Suitable trained models cannot be downloaded since the GitHub API failed. One reasons could be that the limit of 60 API calls per hour and IP address is exceeded."
463-
raise RuntimeError(error_msg)
464-
465-
response_json = response.json()
466-
if "assets" in response_json:
467-
assets = response_json["assets"]
468-
elif "asset" in response_json:
469-
assets = [response_json["asset"]]
470-
else:
471-
assets = []
472-
473-
for asset in assets:
474-
if model_name in asset["name"]:
475-
version_found = True
476-
download_url = asset["browser_download_url"]
477-
logger.info("Downloading model from: " + download_url)
478-
handle_downloading_model(download_url, model_name)
479-
break
480-
481-
if version_found:
482-
break
483-
484-
if not version_found:
485-
error_msg = "No suitable model found on GitHub. Please update your mqt.predictor package using 'pip install -U mqt.predictor'."
486-
raise RuntimeError(error_msg) from None
487426

488-
return MaskablePPO.load(path / model_name)
427+
error_msg = "The RL model is not trained yet. Please train the model before using it."
428+
logger.error(error_msg)
429+
raise FileNotFoundError(error_msg)
489430

490431

491432
def handle_downloading_model(download_url: str, model_name: str) -> None:
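
The simplified lookup in ``load_model`` can be illustrated in isolation. The following is a sketch of the same control flow using only the standard library: it returns the path of the model archive instead of calling ``MaskablePPO.load``, and the function name `resolve_trained_model` and its `model_dir` parameter are hypothetical. It also shows why ``is_file()`` is a slightly stricter check than ``exists()``: a directory that happens to be named like the archive is rejected.

```python
from __future__ import annotations

import logging
from pathlib import Path

logger = logging.getLogger("model_lookup_sketch")


def resolve_trained_model(model_dir: Path, model_name: str) -> Path:
    """Return the path to <model_name>.zip, or raise if it was never trained.

    Hypothetical stand-in for the lookup part of load_model; the real
    function loads the archive with MaskablePPO.load instead of returning it.
    """
    candidate = model_dir / f"{model_name}.zip"
    # is_file() is False for directories, unlike exists().
    if candidate.is_file():
        return candidate

    error_msg = "The RL model is not trained yet. Please train the model before using it."
    logger.error(error_msg)
    raise FileNotFoundError(error_msg)
```

There is no longer a network fallback in this flow, so a missing archive fails immediately and locally rather than depending on the GitHub API.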
