Skip to content

Commit 6e0c223

Browse files
Merge branch 'main' into v2.11.15
2 parents a7a9271 + c32e074 commit 6e0c223

File tree

10 files changed

+294
-263
lines changed

10 files changed

+294
-263
lines changed

ads/aqua/common/utils.py

Lines changed: 77 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54
"""AQUA utils and constants."""
5+
66
import asyncio
77
import base64
88
import json
@@ -19,13 +19,30 @@
1919
import oci
2020
from oci.data_science.models import JobRun, Model
2121

22-
from ads.aqua.common.enums import RqsAdditionalDetails
22+
from ads.aqua.common.enums import (
23+
InferenceContainerParamType,
24+
InferenceContainerType,
25+
RqsAdditionalDetails,
26+
)
2327
from ads.aqua.common.errors import (
2428
AquaFileNotFoundError,
2529
AquaRuntimeError,
2630
AquaValueError,
2731
)
28-
from ads.aqua.constants import *
32+
from ads.aqua.constants import (
33+
AQUA_GA_LIST,
34+
COMPARTMENT_MAPPING_KEY,
35+
CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
36+
CONTAINER_INDEX,
37+
MAXIMUM_ALLOWED_DATASET_IN_BYTE,
38+
MODEL_BY_REFERENCE_OSS_PATH_KEY,
39+
SERVICE_MANAGED_CONTAINER_URI_SCHEME,
40+
SUPPORTED_FILE_FORMATS,
41+
TGI_INFERENCE_RESTRICTED_PARAMS,
42+
UNKNOWN,
43+
UNKNOWN_JSON_STR,
44+
VLLM_INFERENCE_RESTRICTED_PARAMS,
45+
)
2946
from ads.aqua.data import AquaResourceIdentifier
3047
from ads.common.auth import default_signer
3148
from ads.common.decorator.threaded import threaded
@@ -74,15 +91,15 @@ def get_status(evaluation_status: str, job_run_status: str = None):
7491

7592
status = LifecycleStatus.UNKNOWN
7693
if evaluation_status == Model.LIFECYCLE_STATE_ACTIVE:
77-
if (
78-
job_run_status == JobRun.LIFECYCLE_STATE_IN_PROGRESS
79-
or job_run_status == JobRun.LIFECYCLE_STATE_ACCEPTED
80-
):
94+
if job_run_status in {
95+
JobRun.LIFECYCLE_STATE_IN_PROGRESS,
96+
JobRun.LIFECYCLE_STATE_ACCEPTED,
97+
}:
8198
status = JobRun.LIFECYCLE_STATE_IN_PROGRESS
82-
elif (
83-
job_run_status == JobRun.LIFECYCLE_STATE_FAILED
84-
or job_run_status == JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION
85-
):
99+
elif job_run_status in {
100+
JobRun.LIFECYCLE_STATE_FAILED,
101+
JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION,
102+
}:
86103
status = JobRun.LIFECYCLE_STATE_FAILED
87104
else:
88105
status = job_run_status
@@ -199,10 +216,7 @@ def read_file(file_path: str, **kwargs) -> str:
199216
@threaded()
200217
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
201218
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
202-
if artifact_path.startswith("oci://"):
203-
signer = default_signer()
204-
else:
205-
signer = {}
219+
signer = default_signer() if artifact_path.startswith("oci://") else {}
206220
config = json.loads(
207221
read_file(file_path=artifact_path, auth=signer, **kwargs) or UNKNOWN_JSON_STR
208222
)
@@ -448,7 +462,7 @@ def _build_resource_identifier(
448462

449463

450464
def _get_experiment_info(
451-
model: Union[oci.resource_search.models.ResourceSummary, DataScienceModel]
465+
model: Union[oci.resource_search.models.ResourceSummary, DataScienceModel],
452466
) -> tuple:
453467
"""Returns ocid and name of the experiment."""
454468
return (
@@ -609,7 +623,7 @@ def extract_id_and_name_from_tag(tag: str):
609623
base_model_name = UNKNOWN
610624
try:
611625
base_model_ocid, base_model_name = tag.split("#")
612-
except:
626+
except Exception:
613627
pass
614628

615629
if not (is_valid_ocid(base_model_ocid) and base_model_name):
@@ -646,7 +660,7 @@ def get_resource_name(ocid: str) -> str:
646660
try:
647661
resource = query_resource(ocid, return_all=False)
648662
name = resource.display_name if resource else UNKNOWN
649-
except:
663+
except Exception:
650664
name = UNKNOWN
651665
return name
652666

@@ -670,8 +684,8 @@ def get_model_by_reference_paths(model_file_description: dict):
670684

671685
if not models:
672686
raise AquaValueError(
673-
f"Model path is not available in the model json artifact. "
674-
f"Please check if the model created by reference has the correct artifact."
687+
"Model path is not available in the model json artifact. "
688+
"Please check if the model created by reference has the correct artifact."
675689
)
676690

677691
if len(models) > 0:
@@ -848,3 +862,46 @@ def copy_model_config(artifact_path: str, os_path: str, auth: dict = None):
848862
except Exception as ex:
849863
logger.debug(ex)
850864
logger.debug(f"Failed to copy config folder from {artifact_path} to {os_path}.")
865+
866+
867+
def get_container_params_type(container_type_name: str) -> str:
868+
"""The utility function accepts the deployment container type name and returns the corresponding params name.
869+
Parameters
870+
----------
871+
container_type_name: str
872+
type of deployment container, like odsc-vllm-serving or odsc-tgi-serving.
873+
874+
Returns
875+
-------
876+
InferenceContainerParamType value
877+
878+
"""
879+
# check substring instead of direct match in case container_type_name changes in the future
880+
if InferenceContainerType.CONTAINER_TYPE_VLLM in container_type_name.lower():
881+
return InferenceContainerParamType.PARAM_TYPE_VLLM
882+
elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_name.lower():
883+
return InferenceContainerParamType.PARAM_TYPE_TGI
884+
else:
885+
return UNKNOWN
886+
887+
888+
def get_restricted_params_by_container(container_type_name: str) -> set:
889+
"""The utility function accepts the deployment container type name and returns a set of restricted params
890+
for that container.
891+
Parameters
892+
----------
893+
container_type_name: str
894+
type of deployment container, like odsc-vllm-serving or odsc-tgi-serving.
895+
896+
Returns
897+
-------
898+
Set of restricted params based on container type
899+
900+
"""
901+
# check substring instead of direct match in case container_type_name changes in the future
902+
if InferenceContainerType.CONTAINER_TYPE_VLLM in container_type_name.lower():
903+
return VLLM_INFERENCE_RESTRICTED_PARAMS
904+
elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_name.lower():
905+
return TGI_INFERENCE_RESTRICTED_PARAMS
906+
else:
907+
return set()

ads/aqua/constants.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54
"""This module defines constants used in ads.aqua module."""
@@ -45,19 +44,33 @@
4544
SUPPORTED_FILE_FORMATS = ["jsonl"]
4645
MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location"
4746

48-
CONSOLE_LINK_RESOURCE_TYPE_MAPPING = dict(
49-
datasciencemodel="models",
50-
datasciencemodeldeployment="model-deployments",
51-
datasciencemodeldeploymentdev="model-deployments",
52-
datasciencemodeldeploymentint="model-deployments",
53-
datasciencemodeldeploymentpre="model-deployments",
54-
datasciencejob="jobs",
55-
datasciencejobrun="job-runs",
56-
datasciencejobrundev="job-runs",
57-
datasciencejobrunint="job-runs",
58-
datasciencejobrunpre="job-runs",
59-
datasciencemodelversionset="model-version-sets",
60-
datasciencemodelversionsetpre="model-version-sets",
61-
datasciencemodelversionsetint="model-version-sets",
62-
datasciencemodelversionsetdev="model-version-sets",
63-
)
47+
CONSOLE_LINK_RESOURCE_TYPE_MAPPING = {
48+
"datasciencemodel": "models",
49+
"datasciencemodeldeployment": "model-deployments",
50+
"datasciencemodeldeploymentdev": "model-deployments",
51+
"datasciencemodeldeploymentint": "model-deployments",
52+
"datasciencemodeldeploymentpre": "model-deployments",
53+
"datasciencejob": "jobs",
54+
"datasciencejobrun": "job-runs",
55+
"datasciencejobrundev": "job-runs",
56+
"datasciencejobrunint": "job-runs",
57+
"datasciencejobrunpre": "job-runs",
58+
"datasciencemodelversionset": "model-version-sets",
59+
"datasciencemodelversionsetpre": "model-version-sets",
60+
"datasciencemodelversionsetint": "model-version-sets",
61+
"datasciencemodelversionsetdev": "model-version-sets",
62+
}
63+
64+
VLLM_INFERENCE_RESTRICTED_PARAMS = {
65+
"--port",
66+
"--host",
67+
"--served-model-name",
68+
"--seed",
69+
}
70+
TGI_INFERENCE_RESTRICTED_PARAMS = {
71+
"--port",
72+
"--hostname",
73+
"--num-shard",
74+
"--sharded",
75+
"--trust-remote-code",
76+
}

0 commit comments

Comments
 (0)