Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
847c96e
Actual implementation changes
GeigerJ2 May 6, 2025
dabeed4
Add specific tests for filenames argument.
GeigerJ2 May 6, 2025
a2ca536
Add checks for computer existence for filenames
GeigerJ2 May 6, 2025
ee3cbff
fix tests
GeigerJ2 May 8, 2025
996ea7b
Replace None key with src
GeigerJ2 May 9, 2025
2092627
Add expected arguments list for comparison.
GeigerJ2 May 9, 2025
48318a2
Verify with nodes
GeigerJ2 May 9, 2025
5ac98bf
Add minimal CLI interface using typer and rich
GeigerJ2 May 9, 2025
23eeaaa
Merge remote-tracking branch 'upstream/main' into remote-submission
GeigerJ2 May 12, 2025
7c9ed55
Merge remote-tracking branch 'upstream/main' into remote-submission
GeigerJ2 May 12, 2025
9cbac1d
Merge branch 'main' into remote-submission
GeigerJ2 May 22, 2025
5febd4f
Uncomment out previous implementation and duplicate test
GeigerJ2 Jun 2, 2025
00d2eb9
Merge in CLI for easier development
GeigerJ2 Jun 2, 2025
4be222f
Implementation seems to work now
GeigerJ2 Jun 2, 2025
84677c7
.
GeigerJ2 Jun 2, 2025
47e361d
.
GeigerJ2 Jun 2, 2025
7d84dd4
.
GeigerJ2 Jun 2, 2025
31268f0
.
GeigerJ2 Jun 2, 2025
221cfc9
.
GeigerJ2 Jun 2, 2025
a51ab9a
.
GeigerJ2 Jun 2, 2025
c149a10
.
GeigerJ2 Jun 3, 2025
80948e6
.
GeigerJ2 Jun 3, 2025
c230667
hatch fmt and types:check pass
GeigerJ2 Jun 3, 2025
b4ca8ff
.
GeigerJ2 Jun 3, 2025
4807db6
.
GeigerJ2 Jun 3, 2025
6a42fb8
.
GeigerJ2 Jun 3, 2025
3676b14
Allow and properly resolve relative AvailableData src on localhost to…
GeigerJ2 Jun 3, 2025
5237cdf
.
GeigerJ2 Jun 3, 2025
7c3bd68
.
GeigerJ2 Jun 3, 2025
a565fb1
.
GeigerJ2 Jun 3, 2025
83aea88
Remove erroneously introduced output renaming.
GeigerJ2 Jun 3, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ dependencies = [
"termcolor",
"pygraphviz",
"lxml",
"f90nml"
"f90nml",
"aiida-shell>=0.8.1",
]
license = {file = "LICENSE"}

Expand All @@ -46,7 +47,7 @@ Changelog = "https://github.com/C2SM/Sirocco/blob/main/CHANGELOG.md"

[tool.pytest.ini_options]
# Configuration for [pytest](https://docs.pytest.org)
addopts = "--pdbcls=IPython.terminal.debugger:TerminalPdb"
addopts = "-s --pdbcls=IPython.terminal.debugger:TerminalPdb"
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-s disables output capturing, and allows for setting breakpoints in test code

norecursedirs = "tests/cases"
markers = [
"slow: slow integration tests which are not recommended to run locally for normal development",
Expand All @@ -66,7 +67,7 @@ filterwarnings = [
source = ["sirocco"]

[tool.ruff]
include = ["src/*py", "tests/*py"]
include = ["src/*py", "tests/*py"] # PRCOMMENT: Do we want to run Ruff via CI on our test files??
target-version = "py310"

[tool.ruff.lint]
Expand All @@ -75,6 +76,14 @@ ignore = [
"TRY003", # write custom error messages for formatting
]

[tool.ruff.lint.per-file-ignores]
"tests/*py" = [
"SLF001", # Private member accessed
"S101", # Use of assert detected
"T201", # `print` found
"PLR2004", # Magic value used in comparison
]

## Hatch configurations

[tool.hatch.build.targets.sdist]
Expand Down Expand Up @@ -151,3 +160,7 @@ ignore_missing_imports = true
[[tool.mypy.overrides]]
module = ["aiida_workgraph.sockets.builtins"]
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = ["termcolor._types"]
ignore_missing_imports = true
1 change: 1 addition & 0 deletions src/sirocco/core/graph_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def from_config(cls, config: ConfigBaseData, coordinates: dict) -> AvailableData
data_class = AvailableData if isinstance(config, ConfigAvailableData) else GeneratedData
return data_class(
name=config.name,
computer=config.computer,
type=config.type,
src=config.src,
coordinates=coordinates,
Expand Down
11 changes: 10 additions & 1 deletion src/sirocco/parsing/yaml_data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,16 @@ class ConfigShellTaskSpecs:
plugin: ClassVar[Literal["shell"]] = "shell"
port_pattern: ClassVar[re.Pattern] = field(default=re.compile(r"{PORT(\[sep=.+\])?::(.+?)}"), repr=False)
sep_pattern: ClassVar[re.Pattern] = field(default=re.compile(r"\[sep=(.+)\]"), repr=False)
src: Path | None = None
src: Path | None = field(
default=None,
metadata={
"description": (
"If `src` not absolute, this ends up to be relative to the root directory of the config file."
"This should also be solved by registering `Code`s in AiiDA for the required scripts."
"See issues #60 and #127."
)
},
)
command: str
env_source_files: list[str] = field(default_factory=list)

Expand Down
81 changes: 75 additions & 6 deletions src/sirocco/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,15 @@ def _execute(self, engine_process, args=None, kwargs=None, var_kwargs=None): #
# Workaround starts here
# This part is part of the workaround. We need to manually add the outputs from the task.
# Because kwargs are not populated with outputs
default_outputs = {"remote_folder", "remote_stash", "retrieved", "_outputs", "_wait", "stdout", "stderr"}
default_outputs = {
"remote_folder",
"remote_stash",
"retrieved",
"_outputs",
"_wait",
"stdout",
"stderr",
}
task_outputs = set(self.outputs._sockets.keys()) # noqa SLF001 # there so public accessor
task_outputs = task_outputs.union(set(inputs.pop("outputs", [])))
missing_outputs = task_outputs.difference(default_outputs)
Expand Down Expand Up @@ -97,6 +105,7 @@ def __init__(self, core_workflow: core.Workflow):
for task in self._core_workflow.tasks:
if isinstance(task, core.ShellTask):
self._set_shelljob_arguments(task)
self._set_shelljob_filenames(task)

# link wait on to workgraph tasks
for task in self._core_workflow.tasks:
Expand Down Expand Up @@ -184,7 +193,16 @@ def _add_aiida_input_data_node(self, data: core.Data):
except NotExistent as err:
msg = f"Could not find computer {data.computer!r} for input {data}."
raise ValueError(msg) from err
self._aiida_data_nodes[label] = aiida.orm.RemoteData(remote_path=data.src, label=label, computer=computer)
# `remote_path` must be str not PosixPath to be JSON-serializable
Copy link
Collaborator Author

@GeigerJ2 GeigerJ2 Jun 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a bug in the code before this PR that surfaced when actually submitting.

# PRCOMMENT: Hack for now to make the tests pass
if computer.label == "localhost":
self._aiida_data_nodes[label] = aiida.orm.RemoteData(
remote_path=str(data_full_path), label=label, computer=computer
)
else:
self._aiida_data_nodes[label] = aiida.orm.RemoteData(
remote_path=str(data.src), label=label, computer=computer
)
elif data.type == "file":
self._aiida_data_nodes[label] = aiida.orm.SinglefileData(label=label, file=data_full_path)
elif data.type == "dir":
Expand Down Expand Up @@ -229,6 +247,8 @@ def _create_shell_task_node(self, task: core.ShellTask):
]
prepend_text = "\n".join([f"source {env_source_path}" for env_source_path in env_source_paths])
metadata["options"] = {"prepend_text": prepend_text}
# NOTE: Hardcoded for now, possibly make user-facing option
metadata["options"]["use_symlinks"] = True

## computer
if task.computer is not None:
Expand Down Expand Up @@ -283,7 +303,10 @@ def _link_input_node_to_shelltask(self, task: core.ShellTask, input_: core.Data)
socket = getattr(workgraph_task.inputs.nodes, f"{input_label}")
socket.value = self.data_from_core(input_)
elif isinstance(input_, core.GeneratedData):
self._workgraph.add_link(self.socket_from_core(input_), workgraph_task.inputs[f"nodes.{input_label}"])
self._workgraph.add_link(
self.socket_from_core(input_),
workgraph_task.inputs[f"nodes.{input_label}"],
)
else:
raise TypeError

Expand All @@ -293,21 +316,67 @@ def _link_wait_on_to_task(self, task: core.Task):
self.task_from_core(task).wait = [self.task_from_core(wt) for wt in task.wait_on]

def _set_shelljob_arguments(self, task: core.ShellTask):
    """Set AiiDA ShellJob arguments by replacing port placeholders with AiiDA labels.

    For every port of ``task``, each attached input is turned into a
    ``{<aiida label>}`` placeholder. The task command is then resolved with
    these placeholders via ``task.resolve_ports`` and the arguments returned by
    ``split_cmd_arg`` are stored on the ``arguments`` input of the
    corresponding workgraph task.

    Raises:
        ValueError: if the workgraph task has no ``arguments`` input.
    """
    workgraph_task = self.task_from_core(task)
    workgraph_task_arguments: SocketAny = workgraph_task.inputs.arguments

    if workgraph_task_arguments is None:
        msg = (
            f"Workgraph task {workgraph_task.name!r} did not initialize arguments nodes in the workgraph "
            f"before linking. This is a bug in the code, please contact developers."
        )
        raise ValueError(msg)

    # Map each port name to the placeholders of its inputs. The full AiiDA
    # label is used as the placeholder content, wrapped in literal braces.
    input_labels: dict[str, list[str]] = {
        port_name: [f"{{{self.get_aiida_label_from_graph_item(input_)}}}" for input_ in input_list]
        for port_name, input_list in task.inputs.items()
    }

    # Only the arguments are needed here; the first element returned by
    # `split_cmd_arg` is discarded.
    _, arguments = self.split_cmd_arg(task.resolve_ports(input_labels))
    workgraph_task_arguments.value = arguments

def _set_shelljob_filenames(self, task: core.ShellTask):
    """Set AiiDA ShellJob filenames for data entities, including parameterized data.

    Builds a mapping from input key to the filename the input should have and
    stores it on the ``filenames`` input of the workgraph task that belongs to
    ``task``. Nothing is done when that ``filenames`` input is falsy.

    Filename selection per input:

    * ``AvailableData`` where both the task and the input declare a computer:
      the basename of ``src``, keyed by the plain data name.
    * any other input whose name is shared by several inputs of this task:
      the full AiiDA label, so the names do not clash in the working directory.
    * any other input with a unique name: the basename of ``src`` if the input
      has a ``src`` attribute, otherwise the data name.
    """
    workgraph_task = self.task_from_core(task)

    if not workgraph_task.inputs.filenames:
        return

    # Materialize the inputs once and count the names up front instead of
    # re-scanning all inputs for every single input.
    inputs = list(task.input_data_nodes())
    name_counts: dict[str, int] = {}
    for input_ in inputs:
        name_counts[input_.name] = name_counts.get(input_.name, 0) + 1

    filenames: dict[str, str] = {}
    for input_ in inputs:
        input_label = self.get_aiida_label_from_graph_item(input_)

        if task.computer and input_.computer and isinstance(input_, core.AvailableData):
            # Both sides declare a computer (they are not compared for equality
            # here): use just the filename of the source path.
            # NOTE(review): this branch keys by `input_.name` while the branch
            # below keys by the AiiDA label -- confirm both always coincide for
            # AvailableData, otherwise this key should be `input_label` too.
            filenames[input_.name] = Path(input_.src).name
            continue

        if name_counts[input_.name] > 1:
            # Several data nodes share this base name (e.g. parameterized
            # data): the full label keeps filenames unique.
            filename = input_label
        else:
            # Only one data node with this name: a simple filename is enough.
            filename = Path(input_.src).name if hasattr(input_, "src") else input_.name

        # The key is the input label, i.e. what is used in the nodes dict.
        filenames[input_label] = filename

    workgraph_task.inputs.filenames.value = filenames

def run(
self,
inputs: None | dict[str, Any] = None,
Expand Down
10 changes: 7 additions & 3 deletions tests/cases/parameters/config/config.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
start_date: &root_start_date '2026-01-01T00:00'
stop_date: &root_stop_date '2028-01-01T00:00'
start_date: &root_start_date "2026-01-01T00:00"
stop_date: &root_stop_date "2028-01-01T00:00"

cycles:
- bimonthly_tasks:
Expand Down Expand Up @@ -49,13 +49,17 @@ cycles:
inputs:
- analysis_foo_bar:
target_cycle:
lag: ['P0M', 'P6M']
lag: ["P0M", "P6M"]
port: None
outputs: [yearly_analysis]

tasks:
- icon:
plugin: shell
# PRCOMMENT
# `src` is a relative path, which is a problem unless it can be resolved to a registered code.
# We should probably either enforce absolute paths or provide a `code` argument.
# See: https://github.com/C2SM/Sirocco/pull/153
src: scripts/icon.py
command: "icon.py --restart {PORT::restart} --init {PORT::init} --forcing {PORT::forcing}"
parameters: [foo, bar]
Expand Down
14 changes: 9 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, url: str, response: requests.Response):


def download_file(url: str, file_path: pathlib.Path):
response = requests.get(url)
response = requests.get(url) # noqa: S113 request-without-timeout
if not response.ok:
raise DownloadError(url, response)

Expand All @@ -48,7 +48,7 @@ def icon_grid_simple_path(pytestconfig):

@pytest.fixture
def icon_filepath_executable() -> str:
which_icon = subprocess.run(["which", "icon"], capture_output=True, check=False)
which_icon = subprocess.run(["which", "icon"], capture_output=True, check=False) # noqa: S607
if which_icon.returncode:
msg = "Could not find icon executable."
raise FileNotFoundError(msg)
Expand Down Expand Up @@ -87,7 +87,7 @@ def minimal_invert_task_io_config() -> models.ConfigWorkflow:
),
models.ConfigCycleTask(
name="task_a",
inputs=[models.ConfigCycleTaskInput(name="availalble", port="None")],
inputs=[models.ConfigCycleTaskInput(name="available", port="None")],
outputs=[models.ConfigCycleTaskOutput(name="output_a")],
),
],
Expand All @@ -99,7 +99,11 @@ def minimal_invert_task_io_config() -> models.ConfigWorkflow:
],
data=models.ConfigData(
available=[
models.ConfigAvailableData(name="availalble", type=models.DataType.FILE, src=pathlib.Path("foo.txt"))
models.ConfigAvailableData(
name="available",
type=models.DataType.FILE,
src=pathlib.Path("foo.txt"),
)
],
generated=[
models.ConfigGeneratedData(name="output_a", type=models.DataType.DIR, src=pathlib.Path("bar")),
Expand Down Expand Up @@ -154,7 +158,7 @@ def serialize_nml(config_paths: dict[str, pathlib.Path], workflow: workflow.Work

def pytest_configure(config):
if config.getoption("reserialize"):
print("Regenerating serialized references") # noqa: T201 # this is actual UX, not a debug print
print("Regenerating serialized references") # this is actual UX, not a debug print
for config_case in ALL_CONFIG_CASES:
config_paths = generate_config_paths(config_case)
wf = workflow.Workflow.from_config_file(str(config_paths["yml"]))
Expand Down
5 changes: 2 additions & 3 deletions tests/test_wc_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,20 @@ def test_icon():

# configs that are tested for running workgraph
@pytest.mark.slow
@pytest.mark.usefixtures("aiida_localhost")
@pytest.mark.parametrize(
"config_case",
[
"small",
"parameters",
],
)
def test_run_workgraph(config_case, config_paths, aiida_computer): # noqa: ARG001 # config_case is overridden
def test_run_workgraph(config_case, config_paths): # noqa: ARG001 # config_case is overridden
"""Tests end-to-end the parsing from file up to running the workgraph.

Automatically uses the aiida_profile fixture to create a new profile. Note to debug the test with your profile
please run this in a separate file as the profile is deleted after test finishes.
"""
# some configs reference computer "localhost" which we need to create beforehand
aiida_computer("localhost").store()

core_workflow = Workflow.from_config_file(str(config_paths["yml"]))
aiida_workflow = AiidaWorkGraph(core_workflow)
Expand Down
Loading
Loading