2
2
from pathlib import Path
3
3
from typing import Any , Dict , Generic , List , TypeVar
4
4
5
+ import kfp .dsl
5
6
from loguru import logger
6
7
from pydantic import Field , ValidationError , computed_field , model_validator
7
8
from pydantic .functional_validators import ModelWrapValidatorHandler
8
9
from pydantic_core import PydanticCustomError
9
- from typing_extensions import Annotated
10
+ from typing_extensions import Annotated , _AnnotatedAlias
11
+
12
+ try :
13
+ from kfp .dsl import graph_component # since 2.1
14
+ except ImportError :
15
+ from kfp .components import graph_component # until 2.0.1
10
16
11
17
from deployer .constants import TEMP_LOCAL_PACKAGE_PATH
12
18
from deployer .pipeline_deployer import VertexPipelineDeployer
13
19
from deployer .utils .config import list_config_filepaths , load_config
14
20
from deployer .utils .exceptions import BadConfigError
15
21
from deployer .utils .logging import disable_logger
16
- from deployer .utils .models import CustomBaseModel , create_model_from_pipeline
22
+ from deployer .utils .models import CustomBaseModel , create_model_from_func
17
23
from deployer .utils .utils import import_pipeline_from_dir
18
24
19
25
PipelineConfigT = TypeVar ("PipelineConfigT" )
@@ -63,7 +69,7 @@ def populate_config_names(cls, data: Any) -> Any:
63
69
return data
64
70
65
71
@computed_field
66
- def pipeline (self ) -> Any :
72
+ def pipeline (self ) -> graph_component . GraphComponent :
67
73
"""Import pipeline"""
68
74
if getattr (self , "_pipeline" , None ) is None :
69
75
with disable_logger ("deployer.utils.utils" ):
@@ -101,7 +107,9 @@ def compile_pipeline(self):
101
107
def validate_configs (self ):
102
108
"""Validate configs against pipeline parameters definition"""
103
109
logger .debug (f"Validating configs for pipeline { self .pipeline_name } " )
104
- PipelineDynamicModel = create_model_from_pipeline (self .pipeline )
110
+ PipelineDynamicModel = create_model_from_func (
111
+ self .pipeline .pipeline_func , type_converter = _convert_artifact_type_to_str
112
+ )
105
113
ConfigsModel = ConfigsDynamicModel [PipelineDynamicModel ]
106
114
ConfigsModel .model_validate (
107
115
{"configs" : {x .name : {"config_path" : x } for x in self .config_paths }}
@@ -127,3 +135,16 @@ def _init_remove_temp_directory(self, handler: ModelWrapValidatorHandler) -> Any
127
135
shutil .rmtree (TEMP_LOCAL_PACKAGE_PATH )
128
136
129
137
return validated_self
138
+
139
+
140
+ def _convert_artifact_type_to_str (annotation : type ) -> type :
141
+ """Convert a kfp.dsl.Artifact type to a string.
142
+
143
+ This is mandatory for type checking, as kfp.dsl.Artifact types should be passed as strings
144
+ to VertexAI. See https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.PipelineJob
145
+ for details.
146
+ """ # noqa: E501
147
+ if isinstance (annotation , _AnnotatedAlias ):
148
+ if issubclass (annotation .__origin__ , kfp .dsl .Artifact ):
149
+ return str
150
+ return annotation
0 commit comments