1
1
#!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
2
# Copyright (c) 2024 Oracle and/or its affiliates.
4
3
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
"""AQUA utils and constants."""
5
+
6
6
import asyncio
7
7
import base64
8
8
import json
19
19
import oci
20
20
from oci .data_science .models import JobRun , Model
21
21
22
- from ads .aqua .common .enums import RqsAdditionalDetails
22
+ from ads .aqua .common .enums import (
23
+ InferenceContainerParamType ,
24
+ InferenceContainerType ,
25
+ RqsAdditionalDetails ,
26
+ )
23
27
from ads .aqua .common .errors import (
24
28
AquaFileNotFoundError ,
25
29
AquaRuntimeError ,
26
30
AquaValueError ,
27
31
)
28
- from ads .aqua .constants import *
32
+ from ads .aqua .constants import (
33
+ AQUA_GA_LIST ,
34
+ COMPARTMENT_MAPPING_KEY ,
35
+ CONSOLE_LINK_RESOURCE_TYPE_MAPPING ,
36
+ CONTAINER_INDEX ,
37
+ MAXIMUM_ALLOWED_DATASET_IN_BYTE ,
38
+ MODEL_BY_REFERENCE_OSS_PATH_KEY ,
39
+ SERVICE_MANAGED_CONTAINER_URI_SCHEME ,
40
+ SUPPORTED_FILE_FORMATS ,
41
+ TGI_INFERENCE_RESTRICTED_PARAMS ,
42
+ UNKNOWN ,
43
+ UNKNOWN_JSON_STR ,
44
+ VLLM_INFERENCE_RESTRICTED_PARAMS ,
45
+ )
29
46
from ads .aqua .data import AquaResourceIdentifier
30
47
from ads .common .auth import default_signer
31
48
from ads .common .decorator .threaded import threaded
@@ -74,15 +91,15 @@ def get_status(evaluation_status: str, job_run_status: str = None):
74
91
75
92
status = LifecycleStatus .UNKNOWN
76
93
if evaluation_status == Model .LIFECYCLE_STATE_ACTIVE :
77
- if (
78
- job_run_status == JobRun .LIFECYCLE_STATE_IN_PROGRESS
79
- or job_run_status == JobRun .LIFECYCLE_STATE_ACCEPTED
80
- ) :
94
+ if job_run_status in {
95
+ JobRun .LIFECYCLE_STATE_IN_PROGRESS ,
96
+ JobRun .LIFECYCLE_STATE_ACCEPTED ,
97
+ } :
81
98
status = JobRun .LIFECYCLE_STATE_IN_PROGRESS
82
- elif (
83
- job_run_status == JobRun .LIFECYCLE_STATE_FAILED
84
- or job_run_status == JobRun .LIFECYCLE_STATE_NEEDS_ATTENTION
85
- ) :
99
+ elif job_run_status in {
100
+ JobRun .LIFECYCLE_STATE_FAILED ,
101
+ JobRun .LIFECYCLE_STATE_NEEDS_ATTENTION ,
102
+ } :
86
103
status = JobRun .LIFECYCLE_STATE_FAILED
87
104
else :
88
105
status = job_run_status
@@ -199,10 +216,7 @@ def read_file(file_path: str, **kwargs) -> str:
199
216
@threaded ()
200
217
def load_config (file_path : str , config_file_name : str , ** kwargs ) -> dict :
201
218
artifact_path = f"{ file_path .rstrip ('/' )} /{ config_file_name } "
202
- if artifact_path .startswith ("oci://" ):
203
- signer = default_signer ()
204
- else :
205
- signer = {}
219
+ signer = default_signer () if artifact_path .startswith ("oci://" ) else {}
206
220
config = json .loads (
207
221
read_file (file_path = artifact_path , auth = signer , ** kwargs ) or UNKNOWN_JSON_STR
208
222
)
@@ -448,7 +462,7 @@ def _build_resource_identifier(
448
462
449
463
450
464
def _get_experiment_info (
451
- model : Union [oci .resource_search .models .ResourceSummary , DataScienceModel ]
465
+ model : Union [oci .resource_search .models .ResourceSummary , DataScienceModel ],
452
466
) -> tuple :
453
467
"""Returns ocid and name of the experiment."""
454
468
return (
@@ -609,7 +623,7 @@ def extract_id_and_name_from_tag(tag: str):
609
623
base_model_name = UNKNOWN
610
624
try :
611
625
base_model_ocid , base_model_name = tag .split ("#" )
612
- except :
626
+ except Exception :
613
627
pass
614
628
615
629
if not (is_valid_ocid (base_model_ocid ) and base_model_name ):
@@ -646,7 +660,7 @@ def get_resource_name(ocid: str) -> str:
646
660
try :
647
661
resource = query_resource (ocid , return_all = False )
648
662
name = resource .display_name if resource else UNKNOWN
649
- except :
663
+ except Exception :
650
664
name = UNKNOWN
651
665
return name
652
666
@@ -670,8 +684,8 @@ def get_model_by_reference_paths(model_file_description: dict):
670
684
671
685
if not models :
672
686
raise AquaValueError (
673
- f "Model path is not available in the model json artifact. "
674
- f "Please check if the model created by reference has the correct artifact."
687
+ "Model path is not available in the model json artifact. "
688
+ "Please check if the model created by reference has the correct artifact."
675
689
)
676
690
677
691
if len (models ) > 0 :
@@ -848,3 +862,46 @@ def copy_model_config(artifact_path: str, os_path: str, auth: dict = None):
848
862
except Exception as ex :
849
863
logger .debug (ex )
850
864
logger .debug (f"Failed to copy config folder from { artifact_path } to { os_path } ." )
865
+
866
+
867
+ def get_container_params_type (container_type_name : str ) -> str :
868
+ """The utility function accepts the deployment container type name and returns the corresponding params name.
869
+ Parameters
870
+ ----------
871
+ container_type_name: str
872
+ type of deployment container, like odsc-vllm-serving or odsc-tgi-serving.
873
+
874
+ Returns
875
+ -------
876
+ InferenceContainerParamType value
877
+
878
+ """
879
+ # check substring instead of direct match in case container_type_name changes in the future
880
+ if InferenceContainerType .CONTAINER_TYPE_VLLM in container_type_name .lower ():
881
+ return InferenceContainerParamType .PARAM_TYPE_VLLM
882
+ elif InferenceContainerType .CONTAINER_TYPE_TGI in container_type_name .lower ():
883
+ return InferenceContainerParamType .PARAM_TYPE_TGI
884
+ else :
885
+ return UNKNOWN
886
+
887
+
888
+ def get_restricted_params_by_container (container_type_name : str ) -> set :
889
+ """The utility function accepts the deployment container type name and returns a set of restricted params
890
+ for that container.
891
+ Parameters
892
+ ----------
893
+ container_type_name: str
894
+ type of deployment container, like odsc-vllm-serving or odsc-tgi-serving.
895
+
896
+ Returns
897
+ -------
898
+ Set of restricted params based on container type
899
+
900
+ """
901
+ # check substring instead of direct match in case container_type_name changes in the future
902
+ if InferenceContainerType .CONTAINER_TYPE_VLLM in container_type_name .lower ():
903
+ return VLLM_INFERENCE_RESTRICTED_PARAMS
904
+ elif InferenceContainerType .CONTAINER_TYPE_TGI in container_type_name .lower ():
905
+ return TGI_INFERENCE_RESTRICTED_PARAMS
906
+ else :
907
+ return set ()
0 commit comments