Skip to content

Commit fe11f36

Browse files
nquetschlichburgholzerpre-commit-ci[bot]
authored
Fix training data handling (#58)
- fixed the issue reported in #56 - improves unzipping if no to be compiled qasm files are present Co-authored-by: Lukas Burgholzer <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ea9e8c6 commit fe11f36

File tree

2 files changed

+12
-11
lines changed

2 files changed

+12
-11
lines changed

src/mqt/predictor/ml/Predictor.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -141,22 +141,22 @@ def generate_compiled_circuits(
141141
if target_path is None:
142142
target_path = str(ml.helper.get_path_training_circuits_compiled())
143143

144-
source_circuits_list = []
145-
146-
for file in Path(source_path).iterdir():
147-
if "qasm" in str(file):
148-
source_circuits_list.append(str(file))
149-
150144
path_zip = Path(source_path) / "mqtbench_training_samples.zip"
151-
if len(source_circuits_list) == 0 and path_zip.exists():
145+
if (
146+
not any(file.suffix == ".qasm" for file in Path(source_path).iterdir())
147+
and path_zip.exists()
148+
):
152149
path_zip = str(path_zip)
153150
import zipfile
154151

155152
with zipfile.ZipFile(path_zip, "r") as zip_ref:
156153
zip_ref.extractall(source_path)
157154

158-
if not Path(source_path).is_dir():
159-
Path(source_path).mkdir()
155+
Path(target_path).mkdir(exist_ok=True)
156+
157+
source_circuits_list = [
158+
file.name for file in Path(source_path).iterdir() if file.suffix == ".qasm"
159+
]
160160

161161
Parallel(n_jobs=-1, verbose=100)(
162162
delayed(self.compile_all_circuits_for_qc)(
@@ -174,7 +174,7 @@ def generate_trainingdata_from_qasm_files(
174174
175175
Keyword arguments:
176176
source_path -- path to file
177-
target_directory -- path to directory for compiled circuit
177+
target_path -- path to directory for compiled circuit
178178
179179
Return values:
180180
training_data_ML_aggregated -- training data
@@ -194,7 +194,7 @@ def generate_trainingdata_from_qasm_files(
194194

195195
results = Parallel(n_jobs=-1, verbose=100)(
196196
delayed(self.generate_training_sample)(
197-
str(filename), source_path, target_path
197+
str(filename.name), source_path, target_path
198198
)
199199
for filename in Path(source_path).iterdir()
200200
)

tests/ml/test_predictor_ml.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def test_generate_compiled_circuits():
8181
qasm_path = Path("compiled_test.qasm")
8282
qc.qasm(filename=str(qasm_path))
8383
predictor.generate_compiled_circuits(source_path, str(target_path))
84+
assert any(file.suffix == ".qasm" for file in target_path.iterdir())
8485

8586
training_sample, circuit_name, scores = predictor.generate_training_sample(
8687
str(qasm_path), source_path, target_path

0 commit comments

Comments
 (0)