Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions src/aiida_shell/calculations/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,10 +346,39 @@ def handle_remote_data_nodes(
instructions = []

for key, node in remote_data_nodes.items():
remote_path = node.get_remote_path()

authinfo = node.get_authinfo()
transport = authinfo.get_transport()

with transport:
if not transport.path_exists(path=remote_path):
msg = f"Remote path `{remote_path}` for node with key `{key}` does not exist"
raise ValueError(msg)

remote_is_file = transport.isfile(remote_path)

# Resolve filename
if key in filenames:
instructions.append((computer.uuid, node.get_remote_path(), filenames[key]))
filename = filenames[key]
# For directory paths with explicit filenames, preserve original behavior
# (copy the entire directory, not just its contents)
if not remote_is_file:
source_path = remote_path
instructions.append((computer.uuid, source_path, filename))
continue
elif remote_is_file:
filename = pathlib.Path(remote_path).name
else:
filename = '.'

# Resolve source_path
if remote_is_file:
source_path = remote_path
else:
instructions.append((computer.uuid, f'{node.get_remote_path()}/*', '.'))
source_path = f"{remote_path}/*"

instructions.append((computer.uuid, source_path, filename))

if use_symlinks:
return [], instructions
Expand Down Expand Up @@ -480,7 +509,18 @@ def prepare_filenames(self, nodes: dict[str, SinglefileData], filenames: dict[st
f'node `{key}` contains the file `{f}` which overlaps with a reserved output filename.'
)
elif isinstance(node, RemoteData):
remote_path = node.get_remote_path()
filename = filenames.get(key, None)

authinfo = node.get_authinfo()
transport = authinfo.get_transport()

with transport:
remote_is_file = transport.isfile(remote_path)

if not filename and remote_is_file:
filename = pathlib.Path(remote_path).name
filenames[key] = filename
else:
continue

Expand Down
19 changes: 18 additions & 1 deletion tests/calculations/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,17 +115,32 @@ def test_nodes_remote_data_filename(generate_calc_job, generate_code, tmp_path,
remote_path_b.mkdir()
(remote_path_a / 'file_a.txt').write_text('content a')
(remote_path_b / 'file_b.txt').write_text('content b')

remote_path_c = tmp_path / 'remote_c' / 'file_c.txt'
remote_path_d = tmp_path / 'remote_d' / 'file_d.txt'
remote_path_c.parent.mkdir()
remote_path_d.parent.mkdir()
remote_path_c.write_text('content c')
remote_path_d.write_text('content d')

remote_data_a = RemoteData(remote_path=str(remote_path_a.absolute()), computer=aiida_localhost)
remote_data_b = RemoteData(remote_path=str(remote_path_b.absolute()), computer=aiida_localhost)
remote_data_c = RemoteData(remote_path=str(remote_path_c.absolute()), computer=aiida_localhost)
remote_data_d = RemoteData(remote_path=str(remote_path_d.absolute()), computer=aiida_localhost)

inputs = {
'code': generate_code(),
'arguments': ['{remote_a}'],
'nodes': {
'remote_a': remote_data_a,
'remote_b': remote_data_b,
'remote_c': remote_data_c,
'remote_d': remote_data_d,
},
'filenames': {
'remote_a': 'target_remote',
'remote_c': 'target_remote_file',
},
'filenames': {'remote_a': 'target_remote'},
}
dirpath, calc_info = generate_calc_job('core.shell', inputs)

Expand All @@ -136,6 +151,8 @@ def test_nodes_remote_data_filename(generate_calc_job, generate_code, tmp_path,
assert sorted(calc_info.remote_copy_list) == [
(aiida_localhost.uuid, str(remote_path_a), 'target_remote'),
(aiida_localhost.uuid, str(remote_path_b / '*'), '.'),
(aiida_localhost.uuid, str(remote_path_c), 'target_remote_file'),
(aiida_localhost.uuid, str(remote_path_d), 'file_d.txt'),
]
assert sorted(p.name for p in dirpath.iterdir()) == []

Expand Down
21 changes: 21 additions & 0 deletions tests/test_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def test_nodes_remote_data(tmp_path, aiida_localhost, use_symlinks):

def test_nodes_remote_data_filename(tmp_path_factory, aiida_localhost):
"""Test copying contents of a ``RemoteData`` to specific subdirectory."""
# For `RemoteData` pointing to a directory
dirpath_remote = tmp_path_factory.mktemp('remote')
dirpath_source = dirpath_remote / 'source'
dirpath_source.mkdir()
Expand All @@ -185,6 +186,26 @@ def test_nodes_remote_data_filename(tmp_path_factory, aiida_localhost):
assert (dirpath_working / 'sub_directory' / 'source').is_dir()
assert (dirpath_working / 'sub_directory' / 'source' / 'file.txt').is_file()

# For `RemoteData` pointing to a file
dirpath_remote = tmp_path_factory.mktemp('remote_file')
dirpath_source = dirpath_remote / 'source'
dirpath_source.mkdir()
filepath_source = (dirpath_source / 'file.txt')
filepath_source.touch()
remote_data = RemoteData(remote_path=str(filepath_source), computer=aiida_localhost)

results, node = launch_shell_job(
'echo',
arguments=['{remote}'],
nodes={'remote': remote_data},
filenames={'remote': 'sub_file'},
)
assert node.is_finished_ok
assert results['stdout'].get_content().strip() == 'sub_file'
dirpath_working = pathlib.Path(node.outputs.remote_folder.get_remote_path())
assert (dirpath_working).is_dir()
assert (dirpath_working / 'sub_file').is_file()


def test_nodes_base_types():
"""Test a shellfunction that specifies positional CLI arguments that are interpolated by the ``kwargs``."""
Expand Down