Skip to content

Commit d16e396

Browse files
committed
Actual implementation changes
1 parent 807a6fb commit d16e396

File tree

4 files changed

+51
-4
lines changed

4 files changed

+51
-4
lines changed

pyproject.toml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ dependencies = [
3535
"termcolor",
3636
"pygraphviz",
3737
"lxml",
38-
"f90nml"
38+
"f90nml",
39+
"aiida-shell @ git+https://github.com/sphuber/aiida-shell.git@fix/105/handle-remote-data-argument-placeholders",
3940
]
4041
license = {file = "LICENSE"}
4142

@@ -76,6 +77,9 @@ ignore = [
7677

7778
## Hatch configurations
7879

80+
[tool.hatch.metadata]
81+
allow-direct-references = true
82+
7983
[tool.hatch.build.targets.sdist]
8084
include = [
8185
"src/sirocco/",
@@ -148,3 +152,7 @@ ignore_missing_imports = true
148152
[[tool.mypy.overrides]]
149153
module = ["aiida_workgraph.sockets.builtins"]
150154
ignore_missing_imports = true
155+
156+
[[tool.mypy.overrides]]
157+
module = ["termcolor._types"]
158+
ignore_missing_imports = true

src/sirocco/core/graph_items.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def from_config(cls, config: ConfigBaseData, coordinates: dict) -> AvailableData
5151
data_class = AvailableData if isinstance(config, ConfigAvailableData) else GeneratedData
5252
return data_class(
5353
name=config.name,
54+
computer=config.computer,
5455
type=config.type,
5556
src=config.src,
5657
coordinates=coordinates,

src/sirocco/parsing/yaml_data_models.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,14 @@ class ConfigShellTaskSpecs:
279279
plugin: ClassVar[Literal["shell"]] = "shell"
280280
port_pattern: ClassVar[re.Pattern] = field(default=re.compile(r"{PORT(\[sep=.+\])?::(.+?)}"), repr=False)
281281
sep_pattern: ClassVar[re.Pattern] = field(default=re.compile(r"\[sep=(.+)\]"), repr=False)
282-
src: str | None = None
282+
src: str | None = Field(
283+
default=None,
284+
description=(
285+
"If `src` not absolute, this ends up to be relative to the root directory of the config file."
286+
"This should also be solved by registering `Code`s in AiiDA for the required scripts."
287+
"See issue #127 and #60"
288+
),
289+
)
283290
command: str
284291
env_source_files: list[str] = field(default_factory=list)
285292

src/sirocco/workgraph.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,15 @@ def _prepare_for_shell_task(task: dict, inputs: dict) -> dict:
5656
# Workaround starts here
5757
# This part is part of the workaround. We need to manually add the outputs from the task.
5858
# Because kwargs are not populated with outputs
59-
default_outputs = {"remote_folder", "remote_stash", "retrieved", "_outputs", "_wait", "stdout", "stderr"}
59+
default_outputs = {
60+
"remote_folder",
61+
"remote_stash",
62+
"retrieved",
63+
"_outputs",
64+
"_wait",
65+
"stdout",
66+
"stderr",
67+
}
6068
task_outputs = set(task["outputs"].keys())
6169
task_outputs = task_outputs.union(set(inputs.pop("outputs", [])))
6270
missing_outputs = task_outputs.difference(default_outputs)
@@ -105,6 +113,7 @@ def __init__(self, core_workflow: core.Workflow):
105113
for task in self._core_workflow.tasks:
106114
if isinstance(task, core.ShellTask):
107115
self._set_shelljob_arguments(task)
116+
self._set_shelljob_filenames(task)
108117

109118
# link wait on to workgraph tasks
110119
for task in self._core_workflow.tasks:
@@ -238,6 +247,8 @@ def _create_shell_task_node(self, task: core.ShellTask):
238247
]
239248
prepend_text = "\n".join([f"source {env_source_path}" for env_source_path in env_source_paths])
240249
metadata["options"] = {"prepend_text": prepend_text}
250+
# NOTE: Hardcoded for now, possibly make user-facing option
251+
metadata["options"]["use_symlinks"] = True
241252

242253
## computer
243254
if task.computer is not None:
@@ -292,7 +303,10 @@ def _link_input_node_to_shelltask(self, task: core.ShellTask, input_: core.Data)
292303
socket = getattr(workgraph_task.inputs.nodes, f"{input_label}")
293304
socket.value = self.data_from_core(input_)
294305
elif isinstance(input_, core.GeneratedData):
295-
self._workgraph.add_link(self.socket_from_core(input_), workgraph_task.inputs[f"nodes.{input_label}"])
306+
self._workgraph.add_link(
307+
self.socket_from_core(input_),
308+
workgraph_task.inputs[f"nodes.{input_label}"],
309+
)
296310
else:
297311
raise TypeError
298312

@@ -317,6 +331,23 @@ def _set_shelljob_arguments(self, task: core.ShellTask):
317331
_, arguments = self.split_cmd_arg(task.resolve_ports(input_labels))
318332
workgraph_task_arguments.value = arguments
319333

334+
def _set_shelljob_filenames(self, task: core.ShellTask):
335+
"""set AiiDA ShellJob filenames for AvailableData entities"""
336+
337+
filenames = {}
338+
339+
for input_ in task.input_data_nodes():
340+
# Some empty lists appear here?
341+
if not input_:
342+
continue
343+
344+
core_input = input_[0]
345+
if isinstance(core_input, core.AvailableData):
346+
filenames[core_input.name] = Path(core_input.src).name
347+
348+
workgraph_task = self.task_from_core(task)
349+
workgraph_task.inputs.filenames.value = filenames
350+
320351
def run(
321352
self,
322353
inputs: None | dict[str, Any] = None,

0 commit comments

Comments
 (0)