Skip to content

Commit 90f2d6a

Browse files
authored
Add support of icon task in workgraph module (#143)
Adds a new test case `small-icon` and renames the test case `small` to `small-shell`. The `small-icon` test case is a copy of `small` that replaces the dummy icon scripts with real icon. Because the `small` test case increases the coverage of `ShellTask`, it is kept.
- Add dispatching for `IconTask` in multiple functions responsible for creating the WorkGraph.
- Pass the `computer` argument to `core.Task` since it is required for the creation of an `IconCalculation`. In the `small-shell` config the usage of `computer` has been removed temporarily until PR #136 fixes the usage.
- Changes in `core.AvailableData`:
  - Use `config_rootdir` to resolve the location of the data if it is relative.
  - The `src` member is now compulsory and validated. This change required implementing the `from_config` constructor for both `AvailableData` and `GeneratedData`, as the validation does not happen for `GeneratedData`.
- Fix how `is_restart` is determined: it only checked for the existence of the restart key in the input items and did not consider that the data structure still lists the input item with an empty list when the `when` keyword is used. Now it correctly evaluates to `False` in that case.
- In tests we now use `pytest.mark.usefixtures` when a fixture is not directly used but is required to be executed.
- Exclude `tests/cases/*` from the type check as it contains dummy python scripts.
- Update the default options of `hatch test` to run without icon.
- Rename the pytest fixture `icon_grid_simple_path` to `icon_grid_path`.
1 parent c65aff7 commit 90f2d6a

39 files changed

+616
-251
lines changed

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dependencies = [
3131
"pydantic",
3232
"ruamel.yaml",
3333
"aiida-core>=2.5",
34+
"aiida-icon>=0.4.0",
3435
"aiida-workgraph==0.5.2",
3536
"termcolor",
3637
"pygraphviz",
@@ -95,7 +96,7 @@ extra-dependencies = [
9596
"ipdb"
9697
]
9798
default-args = []
98-
extra-args = ["--doctest-modules", '-m not slow', '-m not requires_icon']
99+
extra-args = ["--doctest-modules", "-m", "not slow and not requires_icon"]
99100

100101
[[tool.hatch.envs.hatch-test.matrix]]
101102
python = ["3.12"]
@@ -135,7 +136,7 @@ extra-dependencies = [
135136
]
136137

137138
[tool.hatch.envs.types.scripts]
138-
check = "mypy --no-incremental {args:.}"
139+
check = "mypy --exclude 'tests/cases/*' --no-incremental {args:.}"
139140

140141
[[tool.mypy.overrides]]
141142
module = ["isoduration", "isoduration.*"]

src/sirocco/core/_tasks/icon_task.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ class IconTask(models.ConfigIconTaskSpecs, Task):
2222

2323
def __post_init__(self):
2424
super().__post_init__()
25-
2625
# detect master namelist
2726
master_namelist = None
2827
for namelist in self.namelists:
@@ -33,6 +32,7 @@ def __post_init__(self):
3332
msg = f"Failed to read master namelists. Could not find {self._MASTER_NAMELIST_NAME!r} in namelists {self.namelists}"
3433
raise ValueError(msg)
3534
self._master_namelist = master_namelist
35+
self.src = self._validate_src(self.src, self.config_rootdir)
3636

3737
# retrieve model namelist name from master namelist
3838
if (master_model_nml := self._master_namelist.namelist.get(self._MASTER_MODEL_NML_SECTION, None)) is None:
@@ -64,7 +64,8 @@ def model_namelist(self) -> NamelistFile:
6464
@property
6565
def is_restart(self) -> bool:
6666
"""Check if the icon task starts from the restart file."""
67-
return self._AIIDA_ICON_RESTART_FILE_PORT_NAME in self.inputs
67+
# restart port must be present and nonempty
68+
return bool(self.inputs.get(self._AIIDA_ICON_RESTART_FILE_PORT_NAME, False))
6869

6970
def update_icon_namelists_from_workflow(self):
7071
if not isinstance(self.cycle_point, DateCyclePoint):
@@ -115,3 +116,18 @@ def build_from_config(cls: type[Self], config: models.ConfigTask, **kwargs: Any)
115116
)
116117
self.update_icon_namelists_from_workflow()
117118
return self
119+
120+
@staticmethod
121+
def _validate_src(config_src: Path, config_rootdir: Path | None = None) -> Path:
122+
if config_rootdir is None and not config_src.is_absolute():
123+
msg = f"Cannot specify relative path {config_src} for namelist while the rootdir is None"
124+
raise ValueError(msg)
125+
126+
src = config_src if config_rootdir is None else (config_rootdir / config_src)
127+
if not src.exists():
128+
msg = f"Icon executable in path {src} does not exist."
129+
raise FileNotFoundError(msg)
130+
if not src.is_file():
131+
msg = f"Icon executable in path {src} is not a file."
132+
raise OSError(msg)
133+
return src

src/sirocco/core/graph_items.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,23 +45,53 @@ class Data(ConfigBaseDataSpecs, GraphItem):
4545
color: ClassVar[str] = field(default="light_blue", repr=False)
4646

4747
@classmethod
48-
def from_config(cls, config: ConfigBaseData, coordinates: dict) -> AvailableData | GeneratedData:
48+
def from_config(
49+
cls, config: ConfigBaseData, config_rootdir: Path, coordinates: dict
50+
) -> AvailableData | GeneratedData:
4951
data_class = AvailableData if isinstance(config, ConfigAvailableData) else GeneratedData
50-
return data_class(
52+
53+
return data_class.from_config(
54+
config=config,
55+
config_rootdir=config_rootdir,
56+
coordinates=coordinates,
57+
)
58+
59+
60+
class AvailableData(Data):
61+
src: Path
62+
63+
@classmethod
64+
def from_config(cls, config: ConfigBaseData, config_rootdir: Path, coordinates: dict) -> Self:
65+
src = cls._validate_src(config.src, config_rootdir)
66+
return cls(
5167
name=config.name,
5268
computer=config.computer,
5369
type=config.type,
54-
src=config.src,
70+
src=src,
5571
coordinates=coordinates,
5672
)
5773

74+
@staticmethod
75+
def _validate_src(config_src: Path | None, config_rootdir: Path | None = None) -> Path | None:
76+
if config_src is None:
77+
return None
78+
if config_rootdir is None and not config_src.is_absolute():
79+
msg = f"Cannot specify relative path {config_src} for namelist while the rootdir is None"
80+
raise ValueError(msg)
5881

59-
class AvailableData(Data):
60-
pass
82+
return config_src if config_rootdir is None else (config_rootdir / config_src)
6183

6284

6385
class GeneratedData(Data):
64-
pass
86+
@classmethod
def from_config(cls, config: ConfigBaseData, config_rootdir: Path, coordinates: dict) -> Self:  # noqa: ARG003 # we need to keep same signature as for AvailableData
    """Build a generated-data node from its config entry.

    ``config_rootdir`` is accepted but not used; ``config.src`` is passed on
    unchanged, without any validation or path resolution.
    """
    init_kwargs = {
        "name": config.name,
        "computer": config.computer,
        "type": config.type,
        "src": config.src,
        "coordinates": coordinates,
    }
    return cls(**init_kwargs)
6595

6696

6797
@dataclass(kw_only=True)
@@ -92,9 +122,15 @@ def __init_subclass__(cls, **kwargs):
92122
def input_data_nodes(self) -> Iterator[Data]:
    """Yield every input data node of this task, going through the ports in order."""
    for nodes in self.inputs.values():
        yield from nodes
94124

125+
def input_data_items(self) -> Iterator[tuple[str, Data]]:
    """Yield a ``(port, data)`` pair for every input data node of this task.

    A port whose list of data nodes is empty contributes no pair.
    """
    for port, nodes in self.inputs.items():
        for node in nodes:
            yield port, node
127+
95128
def output_data_nodes(self) -> Iterator[Data]:
    """Yield every output data node of this task, going through the ports in order."""
    per_port_nodes = self.outputs.values()
    yield from chain.from_iterable(per_port_nodes)
97130

131+
def output_data_items(self) -> Iterator[tuple[str | None, Data]]:
    """Yield a ``(port, data)`` pair for every output data node of this task.

    A port whose list of data nodes is empty contributes no pair.
    """
    for port, nodes in self.outputs.items():
        for node in nodes:
            yield port, node
133+
98134
@classmethod
99135
def from_config(
100136
cls: type[Self],

src/sirocco/core/workflow.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@ def iter_coordinates(cycle_point: CyclePoint, param_refs: list[str]) -> Iterator
5656
# 1 - create available data nodes
5757
for available_data_config in config_data.available:
5858
for coordinates in iter_coordinates(OneOffPoint(), available_data_config.parameters):
59-
self.data.add(Data.from_config(config=available_data_config, coordinates=coordinates))
59+
self.data.add(
60+
Data.from_config(
61+
config=available_data_config, config_rootdir=config_rootdir, coordinates=coordinates
62+
)
63+
)
6064

6165
# 2 - create output data nodes
6266
for cycle_config in config_cycles:
@@ -65,7 +69,11 @@ def iter_coordinates(cycle_point: CyclePoint, param_refs: list[str]) -> Iterator
6569
for data_ref in task_ref.outputs:
6670
data_config = config_data_dict[data_ref.name]
6771
for coordinates in iter_coordinates(cycle_point, data_config.parameters):
68-
self.data.add(Data.from_config(config=data_config, coordinates=coordinates))
72+
self.data.add(
73+
Data.from_config(
74+
config=data_config, config_rootdir=config_rootdir, coordinates=coordinates
75+
)
76+
)
6977

7078
# 3 - create cycles and tasks
7179
for cycle_config in config_cycles:

src/sirocco/parsing/yaml_data_models.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -441,9 +441,13 @@ def merge_path_key(cls, data: Any) -> dict[str, Any]:
441441
@dataclass(kw_only=True)
442442
class ConfigIconTaskSpecs:
443443
plugin: ClassVar[Literal["icon"]] = "icon"
444+
# PRCOMMENT this is later resolved to an absolute path so we cannot use it for serialization
445+
# as the tests would fail. Since the main purpose of the repr is for regression tests
446+
# I disable its printing here
447+
src: Path = field(repr=False)
444448

445449

446-
class ConfigIconTask(ConfigBaseTask):
450+
class ConfigIconTask(ConfigBaseTask, ConfigIconTaskSpecs):
447451
"""Class representing an ICON task configuration from a workflow file
448452
449453
Examples:
@@ -460,6 +464,7 @@ class ConfigIconTask(ConfigBaseTask):
460464
... - path/to/case_nml:
461465
... block_1:
462466
... param_name: param_value
467+
... src: path/to/icon
463468
... '''
464469
... )
465470
>>> icon_task_cfg = validate_yaml_content(ConfigIconTask, snippet)
@@ -487,7 +492,7 @@ class DataType(enum.StrEnum):
487492
@dataclass(kw_only=True)
488493
class ConfigBaseDataSpecs:
489494
type: DataType
490-
src: Path
495+
src: Path | None = None
491496
format: str | None = None
492497
computer: str | None = None
493498

@@ -522,7 +527,7 @@ class ConfigBaseData(_NamedBaseModel, ConfigBaseDataSpecs):
522527

523528

524529
class ConfigAvailableData(ConfigBaseData):
525-
pass
530+
src: Path
526531

527532

528533
class ConfigGeneratedData(ConfigBaseData):

src/sirocco/workgraph.py

Lines changed: 80 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

33
import functools
4+
import io
5+
import uuid
46
from pathlib import Path
57
from typing import TYPE_CHECKING, Any, TypeAlias
68

@@ -9,6 +11,7 @@
911
import aiida_workgraph # type: ignore[import-untyped] # does not have proper typing and stubs
1012
import aiida_workgraph.tasks.factory.shelljob_task # type: ignore[import-untyped] # is only for a workaround
1113
from aiida.common.exceptions import NotExistent
14+
from aiida_icon.calculations import IconCalculation
1215

1316
from sirocco import core
1417

@@ -93,13 +96,13 @@ def __init__(self, core_workflow: core.Workflow):
9396
for task in self._core_workflow.tasks:
9497
self.create_task_node(task)
9598
# Create and link corresponding output sockets
96-
for output in task.output_data_nodes():
97-
self._link_output_node_to_task(task, output)
99+
for port, output in task.output_data_items():
100+
self._link_output_node_to_task(task, port, output)
98101

99102
# link input nodes to workgraph tasks
100103
for task in self._core_workflow.tasks:
101-
for input_ in task.input_data_nodes():
102-
self._link_input_node_to_task(task, input_)
104+
for port, input_ in task.input_data_items():
105+
self._link_input_node_to_task(task, port, input_)
103106

104107
# set shelljob arguments
105108
for task in self._core_workflow.tasks:
@@ -180,7 +183,7 @@ def _add_available_data(self):
180183
if isinstance(data, core.AvailableData):
181184
self._add_aiida_input_data_node(data)
182185

183-
def _add_aiida_input_data_node(self, data: core.Data):
186+
def _add_aiida_input_data_node(self, data: core.AvailableData):
184187
"""
185188
Create an `aiida.orm.Data` instance from the provided graph item.
186189
"""
@@ -267,23 +270,76 @@ def _create_shell_task_node(self, task: core.ShellTask):
267270

268271
self._aiida_task_nodes[label] = workgraph_task
269272

270-
def _link_output_node_to_task(self, task: core.Task, output: core.Data):
273+
@create_task_node.register
def _create_icon_task_node(self, task: core.IconTask):
    """Add an ``IconCalculation`` task for ``task`` to the workgraph.

    Looks up the AiiDA computer named by ``task.computer``, stores a fresh
    ``InstalledCode`` pointing at ``task.src``, serialises the master and model
    namelists into ``SinglefileData`` builder inputs and registers the resulting
    workgraph task under the task's AiiDA label.

    Raises:
        ValueError: No computer with label ``task.computer`` exists in the
            AiiDA database.
    """

    def namelist_to_singlefile(namelist_file):
        # Write the namelist into an in-memory text buffer and rewind it so
        # that the content can be read from the start.
        with io.StringIO() as stream:
            namelist_file.namelist.write(stream)
            stream.seek(0)
            return aiida.orm.SinglefileData(stream, namelist_file.name)

    aiida_label = self.get_aiida_label_from_graph_item(task)

    try:
        # PRCOMMENT move to parsing? But then it has aiida logic
        computer = aiida.orm.Computer.collection.get(label=task.computer)
    except NotExistent as err:
        msg = f"Could not find computer {task.computer!r} in AiiDA database. One needs to create and configure the computer before running a workflow."
        raise ValueError(msg) from err

    # A random suffix is put into the label of every code created here.
    unique_suffix = str(uuid.uuid4())
    unstored_code = aiida.orm.InstalledCode(
        label=f"icon-{unique_suffix}",
        description="aiida_icon",
        default_calc_job_plugin="icon.icon",
        computer=computer,
        filepath_executable=str(task.src),
        with_mpi=False,
    )
    icon_code = unstored_code.store()

    builder = IconCalculation.get_builder()
    builder.code = icon_code

    task.update_icon_namelists_from_workflow()

    builder.master_namelist = namelist_to_singlefile(task.master_namelist)
    builder.model_namelist = namelist_to_singlefile(task.model_namelist)

    self._aiida_task_nodes[aiida_label] = self._workgraph.add_task(builder)
310+
311+
@functools.singledispatchmethod
def _link_output_node_to_task(self, task: core.Task, port: str, output: core.Data):  # noqa: ARG002
    """Dispatch linking an output to a task based on the task type.

    This is the fallback used when no implementation is registered for the
    type of ``task``.

    Raises:
        NotImplementedError: Always, naming the unsupported task type.
    """

    msg = f"method not implemented for task type {type(task)}"
    raise NotImplementedError(msg)
317+
318+
@_link_output_node_to_task.register
def _link_output_node_to_shell_task(self, task: core.ShellTask, _: str, output: core.Data):
    """Create an output socket on the shell workgraph task and register it.

    The port name is ignored; the new socket is named after ``str(output.src)``
    and stored under the AiiDA label of ``output``.
    """

    shell_node = self.task_from_core(task)
    socket_key = self.get_aiida_label_from_graph_item(output)
    self._aiida_socket_nodes[socket_key] = shell_node.add_output("workgraph.any", str(output.src))
277326

327+
@_link_output_node_to_task.register
def _link_output_node_to_icon_task(self, task: core.IconTask, port: str, output: core.Data):
    """Register the output socket named ``port`` of the icon workgraph task.

    Unlike the shell variant, no socket is created here: the socket is looked
    up among the task's already existing outputs and stored under the AiiDA
    label of ``output``.

    NOTE(review): the lookup uses ``.get``, so an unknown ``port`` appears to
    store ``None`` instead of raising - confirm this is intended.
    """
    workgraph_task = self.task_from_core(task)
    output_label = self.get_aiida_label_from_graph_item(output)
    # Private attribute access is deliberate: there is no public accessor.
    output_socket = workgraph_task.outputs._sockets.get(port)  # noqa: SLF001
    self._aiida_socket_nodes[output_label] = output_socket
333+
278334
@functools.singledispatchmethod
279-
def _link_input_node_to_task(self, task: core.Task, input_: core.Data): # noqa: ARG002
335+
def _link_input_node_to_task(self, task: core.Task, port: str, input_: core.Data): # noqa: ARG002
280336
""" "Dispatch linking input to task based on task type"""
281337

282338
msg = f"method not implemented for task type {type(task)}"
283339
raise NotImplementedError(msg)
284340

285341
@_link_input_node_to_task.register
286-
def _link_input_node_to_shelltask(self, task: core.ShellTask, input_: core.Data):
342+
def _link_input_node_to_shell_task(self, task: core.ShellTask, _: str, input_: core.Data):
287343
"""Links the input to the workgraph shell task."""
288344

289345
workgraph_task = self.task_from_core(task)
@@ -305,6 +361,20 @@ def _link_input_node_to_shelltask(self, task: core.ShellTask, input_: core.Data)
305361
else:
306362
raise TypeError
307363

364+
@_link_input_node_to_task.register
def _link_input_node_to_icon_task(self, task: core.IconTask, port: str, input_: core.Data):
    """Links the input to the workgraph icon task.

    The value assigned to the task input named ``port`` is the stored data node
    for ``AvailableData`` and the producing task's output socket for
    ``GeneratedData``.

    Raises:
        TypeError: ``input_`` is neither ``AvailableData`` nor ``GeneratedData``.
    """

    workgraph_task = self.task_from_core(task)

    # resolve data
    if isinstance(input_, core.AvailableData):
        resolved = self.data_from_core(input_)
    elif isinstance(input_, core.GeneratedData):
        resolved = self.socket_from_core(input_)
    else:
        msg = f"Cannot link input of type {type(input_)} to an icon task, expected AvailableData or GeneratedData"
        raise TypeError(msg)

    setattr(workgraph_task.inputs, port, resolved)
377+
308378
def _link_wait_on_to_task(self, task: core.Task):
309379
"""link wait on tasks to workgraph task"""
310380

@@ -349,7 +419,7 @@ def _set_shelljob_filenames(self, task: core.ShellTask):
349419

350420
if task.computer and input_.computer and isinstance(input_, core.AvailableData):
351421
# For RemoteData on the same computer, use just the filename
352-
filename = Path(input_.src).name
422+
filename = input_.src.name
353423
filenames[input_.name] = filename
354424
else:
355425
# For other cases (including GeneratedData), we need to handle parameterized data
@@ -368,7 +438,7 @@ def _set_shelljob_filenames(self, task: core.ShellTask):
368438
filename = input_label
369439
else:
370440
# Single data node with this name - can use simple filename
371-
filename = Path(input_.src).name if hasattr(input_, "src") else input_.name
441+
filename = input_.src.name if input_.src is not None else input_.name
372442

373443
# The key in filenames dict should be the input label (what's used in nodes dict)
374444
filenames[input_label] = filename

tests/cases/large/config/config.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ tasks:
114114
mount_point: runtime/mount/point
115115
- icon:
116116
plugin: icon
117+
src: ./ICON/bin/icon
117118
nodes: 40
118119
walltime: 23:59:59
119120
namelists:

tests/cases/large/data/ICON_namelists/icon_master.namelist_2025-01-01_00:00:00

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
&master_nml
2-
lrestart = .true.
3-
read_restart_namelists = .true.
2+
lrestart = .false.
3+
read_restart_namelists = .false.
44
/
55

66
&master_time_control_nml

0 commit comments

Comments
 (0)