Skip to content

Commit a2ca536

Browse files
committed
Add checks for computer existence for filenames
Also, fix hatch
1 parent dabeed4 commit a2ca536

File tree

3 files changed

+32
-36
lines changed

3 files changed

+32
-36
lines changed

src/sirocco/parsing/yaml_data_models.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -279,13 +279,15 @@ class ConfigShellTaskSpecs:
279279
plugin: ClassVar[Literal["shell"]] = "shell"
280280
port_pattern: ClassVar[re.Pattern] = field(default=re.compile(r"{PORT(\[sep=.+\])?::(.+?)}"), repr=False)
281281
sep_pattern: ClassVar[re.Pattern] = field(default=re.compile(r"\[sep=(.+)\]"), repr=False)
282-
src: str | None = Field(
282+
src: str | None = field(
283283
default=None,
284-
description=(
285-
"If `src` not absolute, this ends up to be relative to the root directory of the config file."
286-
"This should also be solved by registering `Code`s in AiiDA for the required scripts."
287-
"See issue #127 and #60"
288-
),
284+
metadata={
285+
"description": (
286+
"If `src` not absolute, this ends up to be relative to the root directory of the config file."
287+
"This should also be solved by registering `Code`s in AiiDA for the required scripts."
288+
"See issues #60 and #127."
289+
)
290+
},
289291
)
290292
command: str
291293
env_source_files: list[str] = field(default_factory=list)

src/sirocco/workgraph.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,9 +335,8 @@ def _set_shelljob_filenames(self, task: core.ShellTask):
335335
"""set AiiDA ShellJob filenames for AvailableData entities"""
336336

337337
filenames = {}
338-
339338
for input_ in task.input_data_nodes():
340-
if isinstance(input_, core.AvailableData):
339+
if isinstance(input_, core.AvailableData) and task.computer and input_.computer:
341340
filenames[input_.name] = Path(input_.src).name
342341

343342
workgraph_task = self.task_from_core(task)

tests/test_workgraph.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
from sirocco.core import Workflow
1+
import pytest
22
from aiida import orm
3+
4+
from sirocco.core import Workflow
35
from sirocco.parsing import yaml_data_models as models
46
from sirocco.workgraph import AiidaWorkGraph
57

68

7-
def test_set_shelljob_filenames(tmp_path, aiida_localhost):
9+
@pytest.mark.usefixtures("aiida_localhost")
10+
def test_set_shelljob_filenames(tmp_path):
811
file_name = "foo.txt"
912
file_path = tmp_path / file_name
1013
# Dummy script, as `src` must be specified due to relative command path
@@ -19,17 +22,13 @@ def test_set_shelljob_filenames(tmp_path, aiida_localhost):
1922
tasks=[
2023
models.ConfigCycleTask(
2124
name="task",
22-
inputs=[
23-
models.ConfigCycleTaskInput(name="my_data", port="unused")
24-
],
25+
inputs=[models.ConfigCycleTaskInput(name="my_data", port="unused")],
2526
),
2627
],
2728
),
2829
],
2930
tasks=[
30-
models.ConfigShellTask(
31-
name="task", command="echo test", src=str(script_path)
32-
),
31+
models.ConfigShellTask(name="task", command="echo test", src=str(script_path)),
3332
],
3433
data=models.ConfigData(
3534
available=[
@@ -46,13 +45,14 @@ def test_set_shelljob_filenames(tmp_path, aiida_localhost):
4645

4746
core_wf = Workflow.from_config_workflow(config_workflow=config_wf)
4847
aiida_wf = AiidaWorkGraph(core_workflow=core_wf)
49-
assert isinstance(
50-
aiida_wf._workgraph.tasks[0].inputs.nodes["my_data"].value, orm.RemoteData
51-
)
52-
assert aiida_wf._workgraph.tasks[0].inputs.filenames.value == {"my_data": "foo.txt"}
48+
remote_data = aiida_wf._workgraph.tasks[0].inputs.nodes["my_data"].value # noqa: SLF001
49+
assert isinstance(remote_data, orm.RemoteData)
50+
filenames = aiida_wf._workgraph.tasks[0].inputs.filenames.value # noqa: SLF001
51+
assert filenames == {"my_data": "foo.txt"}
5352

5453

55-
def test_multiple_inputs_filenames(tmp_path, aiida_localhost):
54+
@pytest.mark.usefixtures("aiida_localhost")
55+
def test_multiple_inputs_filenames(tmp_path):
5656
file_names = ["foo.txt", "bar.txt", "baz.dat"]
5757
for name in file_names:
5858
(tmp_path / name).touch()
@@ -69,19 +69,15 @@ def test_multiple_inputs_filenames(tmp_path, aiida_localhost):
6969
models.ConfigCycleTask(
7070
name="task",
7171
inputs=[
72-
models.ConfigCycleTaskInput(
73-
name=f"data_{i}", port=f"port_{i}"
74-
)
72+
models.ConfigCycleTaskInput(name=f"data_{i}", port=f"port_{i}")
7573
for i in range(len(file_names))
7674
],
7775
),
7876
],
7977
),
8078
],
8179
tasks=[
82-
models.ConfigShellTask(
83-
name="task", command="echo test", src=str(script_path)
84-
),
80+
models.ConfigShellTask(name="task", command="echo test", src=str(script_path)),
8581
],
8682
data=models.ConfigData(
8783
available=[
@@ -101,10 +97,12 @@ def test_multiple_inputs_filenames(tmp_path, aiida_localhost):
10197
aiida_wf = AiidaWorkGraph(core_workflow=core_wf)
10298

10399
expected_filenames = {f"data_{i}": name for i, name in enumerate(file_names)}
104-
assert aiida_wf._workgraph.tasks[0].inputs.filenames.value == expected_filenames
100+
filenames = aiida_wf._workgraph.tasks[0].inputs.filenames.value # noqa: SLF001
101+
assert filenames == expected_filenames
105102

106103

107-
def test_directory_input_filenames(tmp_path, aiida_localhost):
104+
@pytest.mark.usefixtures("aiida_localhost")
105+
def test_directory_input_filenames(tmp_path):
108106
dir_name = "test_dir"
109107
dir_path = tmp_path / dir_name
110108
dir_path.mkdir()
@@ -119,17 +117,13 @@ def test_directory_input_filenames(tmp_path, aiida_localhost):
119117
tasks=[
120118
models.ConfigCycleTask(
121119
name="task",
122-
inputs=[
123-
models.ConfigCycleTaskInput(name="my_dir", port="unused")
124-
],
120+
inputs=[models.ConfigCycleTaskInput(name="my_dir", port="unused")],
125121
),
126122
],
127123
),
128124
],
129125
tasks=[
130-
models.ConfigShellTask(
131-
name="task", command="echo test", src=str(script_path)
132-
),
126+
models.ConfigShellTask(name="task", command="echo test", src=str(script_path)),
133127
],
134128
data=models.ConfigData(
135129
available=[
@@ -147,4 +141,5 @@ def test_directory_input_filenames(tmp_path, aiida_localhost):
147141
core_wf = Workflow.from_config_workflow(config_workflow=config_wf)
148142
aiida_wf = AiidaWorkGraph(core_workflow=core_wf)
149143

150-
assert aiida_wf._workgraph.tasks[0].inputs.filenames.value == {"my_dir": dir_name}
144+
filenames = aiida_wf._workgraph.tasks[0].inputs.filenames.value # noqa: SLF001
145+
assert filenames == {"my_dir": dir_name}

0 commit comments

Comments
 (0)