Skip to content

Commit

Permalink
Ignore func kwargs when running locally
Browse files Browse the repository at this point in the history
* `kwargs` will only contain Task/Step kwargs so if the function is being
called outside of a Hera context or during the decorator setup, we can
just drop the kwargs and call func(*args) (which will only contain a Hera Input)

Signed-off-by: Elliot Gunton <[email protected]>
  • Loading branch information
elliotgunton committed Oct 1, 2024
1 parent 19be598 commit df6819b
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
7 changes: 4 additions & 3 deletions src/hera/workflows/_meta_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,8 @@ def script_call_wrapper(*args, **kwargs) -> Union[FuncR, Step, Task, None]:

if _context.pieces:
return script_template.__call__(*args, **kwargs)
return func(*args, **kwargs)

return func(*args)

# Set the wrapped function to the original function so that we can use it later
script_call_wrapper.wrapped_function = func # type: ignore
Expand Down Expand Up @@ -732,7 +733,7 @@ def container_call_wrapper(*args, **kwargs) -> Union[FuncR, Step, Task, None]:

if _context.pieces:
return container_template.__call__(*args, **kwargs)
return func(*args, **kwargs)
return func(*args)

# Set the template name to the inferred name
container_call_wrapper.template_name = name # type: ignore
Expand Down Expand Up @@ -817,7 +818,7 @@ def call_wrapper(*args, **kwargs):

return self._create_subnode(subnode_name, func, template, *args, **kwargs)

return func(*args, **kwargs)
return func(*args)

call_wrapper.template_name = name # type: ignore

Expand Down
2 changes: 1 addition & 1 deletion tests/test_unit/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_dag_task_auto_depends():
task_b = next(iter([t for t in dag_template.tasks if t.name == "task-b"]), None)
assert task_b.depends == "setup-task"

final_task = next(iter([t for t in dag_template.tasks if t.name == "final-task"]), None)
final_task = next(iter([t for t in dag_template.tasks if t.name == "final-task-name"]), None)
assert final_task.depends == "task-a && task-b"


Expand Down
2 changes: 1 addition & 1 deletion tests/workflow_decorators/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,6 @@ def worker(worker_input: WorkerInput) -> WorkerOutput:
setup_task = setup()
task_a = concat(ConcatInput(word_a=worker_input.value_a, word_b=setup_task.environment_parameter))
task_b = concat(ConcatInput(word_a=worker_input.value_b, word_b=setup_task.result))
final_task = concat(ConcatInput(word_a=task_a.result, word_b=task_b.result))
final_task = concat(ConcatInput(word_a=task_a.result, word_b=task_b.result), name="final-task-name")

return WorkerOutput(value=final_task.result)

0 comments on commit df6819b

Please sign in to comment.