Skip to content

Commit a805abe

Browse files
committed
Don't add path for Steps/DAG artifact inputs
* Refactors _get_artifacts and _get_inputs to with add_missing_path input var, with default = False to match the _get_outputs equivalent Signed-off-by: Elliot Gunton <[email protected]>
1 parent 7a79a66 commit a805abe

File tree

7 files changed

+23
-27
lines changed

7 files changed

+23
-27
lines changed

docs/examples/workflows/experimental/new_dag_decorator_artifacts.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,7 @@
119119
inputs:
120120
artifacts:
121121
- name: artifact_a
122-
path: /tmp/hera-inputs/artifacts/artifact_a
123122
- name: artifact_b
124-
path: /tmp/hera-inputs/artifacts/artifact_b
125123
name: worker
126124
outputs:
127125
artifacts:

examples/workflows/experimental/new-dag-decorator-artifacts.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,7 @@ spec:
5555
inputs:
5656
artifacts:
5757
- name: artifact_a
58-
path: /tmp/hera-inputs/artifacts/artifact_a
5958
- name: artifact_b
60-
path: /tmp/hera-inputs/artifacts/artifact_b
6159
name: worker
6260
outputs:
6361
artifacts:

src/hera/workflows/_meta_mixins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,7 @@ def container_decorator(func: Callable[FuncIns, FuncR]) -> Callable:
698698
if len(func_inputs) >= 1:
699699
input_arg = list(func_inputs.values())[0].annotation
700700
if issubclass(input_arg, (InputV1, InputV2)):
701-
inputs = input_arg._get_inputs()
701+
inputs = input_arg._get_inputs(add_missing_path=True)
702702

703703
func_return = signature.return_annotation
704704
outputs = []

src/hera/workflows/io/_io_mixins.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,19 +106,19 @@ def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Paramet
106106
return parameters
107107

108108
@classmethod
109-
def _get_artifacts(cls) -> List[Artifact]:
109+
def _get_artifacts(cls, add_missing_path: bool = False) -> List[Artifact]:
110110
artifacts = []
111111

112112
for _, _, artifact in _construct_io_from_fields(cls):
113113
if isinstance(artifact, Artifact):
114-
if artifact.path is None:
114+
if add_missing_path and artifact.path is None:
115115
artifact.path = artifact._get_default_inputs_path()
116116
artifacts.append(artifact)
117117
return artifacts
118118

119119
@classmethod
120-
def _get_inputs(cls) -> List[Union[Artifact, Parameter]]:
121-
return cls._get_artifacts() + cls._get_parameters()
120+
def _get_inputs(cls, add_missing_path: bool = False) -> List[Union[Artifact, Parameter]]:
121+
return cls._get_artifacts(add_missing_path) + cls._get_parameters()
122122

123123
@classmethod
124124
def _get_as_templated_arguments(cls) -> Self:

src/hera/workflows/script.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ class will be used as inputs, rather than the class itself.
494494
else:
495495
parameters.extend(input_class._get_parameters())
496496

497-
artifacts.extend(input_class._get_artifacts())
497+
artifacts.extend(input_class._get_artifacts(add_missing_path=True))
498498

499499
elif param_or_artifact := get_workflow_annotation(func_param.annotation):
500500
if param_or_artifact.output:

tests/test_unit/test_decorators.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,26 +36,26 @@ def test_dag_io_declaration():
3636

3737
assert len(model_workflow.spec.templates) == 1
3838

39-
template = model_workflow.spec.templates[0]
39+
dag_template = model_workflow.spec.templates[0]
4040

41-
assert template.inputs
42-
assert len(template.inputs.parameters) == 2
43-
assert template.inputs.parameters == [
41+
assert dag_template.inputs
42+
assert len(dag_template.inputs.parameters) == 2
43+
assert dag_template.inputs.parameters == [
4444
ModelParameter(name="basic_input_parameter"),
4545
ModelParameter(name="my-input-param"),
4646
]
47-
assert len(template.inputs.artifacts) == 1
48-
assert template.inputs.artifacts == [
49-
ModelArtifact(name="my-input-artifact", path="/tmp/hera-inputs/artifacts/my-input-artifact"),
47+
assert len(dag_template.inputs.artifacts) == 1
48+
assert dag_template.inputs.artifacts == [
49+
ModelArtifact(name="my-input-artifact")
5050
]
5151

52-
assert template.outputs
53-
assert len(template.outputs.parameters) == 2
54-
assert template.outputs.parameters == [
52+
assert dag_template.outputs
53+
assert len(dag_template.outputs.parameters) == 2
54+
assert dag_template.outputs.parameters == [
5555
ModelParameter(name="basic_output_parameter"),
5656
ModelParameter(name="my-output-param"),
5757
]
58-
assert template.outputs.artifacts == [
58+
assert dag_template.outputs.artifacts == [
5959
ModelArtifact(name="my-output-artifact"),
6060
]
6161

tests/test_unit/test_io_mixins.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,15 @@ class Foo(Input):
7979
foo: int
8080
bar: str = "a default"
8181

82-
assert Foo._get_artifacts() == []
82+
assert Foo._get_artifacts(add_missing_path=True) == []
8383

8484

8585
def test_get_artifacts_with_pydantic_annotations():
8686
class Foo(Input):
8787
foo: Annotated[int, Field(gt=0)]
8888
bar: Annotated[str, Field(max_length=10)] = "a default"
8989

90-
assert Foo._get_artifacts() == []
90+
assert Foo._get_artifacts(add_missing_path=True) == []
9191

9292

9393
def test_get_artifacts_annotated_with_name():
@@ -96,7 +96,7 @@ class Foo(Input):
9696
bar: Annotated[str, Parameter(name="b_ar")] = "a default"
9797
baz: Annotated[str, Artifact(name="b_az")]
9898

99-
assert Foo._get_artifacts() == [Artifact(name="b_az", path="/tmp/hera-inputs/artifacts/b_az")]
99+
assert Foo._get_artifacts(add_missing_path=True) == [Artifact(name="b_az", path="/tmp/hera-inputs/artifacts/b_az")]
100100

101101

102102
def test_get_artifacts_annotated_with_description():
@@ -105,7 +105,7 @@ class Foo(Input):
105105
bar: Annotated[str, Parameter(description="param bar")] = "a default"
106106
baz: Annotated[str, Artifact(description="artifact baz")]
107107

108-
assert Foo._get_artifacts() == [
108+
assert Foo._get_artifacts(add_missing_path=True) == [
109109
Artifact(name="baz", path="/tmp/hera-inputs/artifacts/baz", description="artifact baz")
110110
]
111111

@@ -114,7 +114,7 @@ def test_get_artifacts_annotated_with_path():
114114
class Foo(Input):
115115
baz: Annotated[str, Artifact(path="/tmp/hera-inputs/artifacts/bishbosh")]
116116

117-
assert Foo._get_artifacts() == [Artifact(name="baz", path="/tmp/hera-inputs/artifacts/bishbosh")]
117+
assert Foo._get_artifacts(add_missing_path=True) == [Artifact(name="baz", path="/tmp/hera-inputs/artifacts/bishbosh")]
118118

119119

120120
def test_get_artifacts_with_multiple_annotations():
@@ -123,7 +123,7 @@ class Foo(Input):
123123
bar: Annotated[str, Field(max_length=10), Parameter(description="param bar")] = "a default"
124124
baz: Annotated[str, Field(max_length=15), Artifact()]
125125

126-
assert Foo._get_artifacts() == [Artifact(name="baz", path="/tmp/hera-inputs/artifacts/baz")]
126+
assert Foo._get_artifacts(add_missing_path=True) == [Artifact(name="baz", path="/tmp/hera-inputs/artifacts/baz")]
127127

128128

129129
def test_get_as_arguments_unannotated():

0 commit comments

Comments
 (0)