Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit b970eb5

Browse files
authoredFeb 3, 2025··
Add ignore validation flag while registering model & improve logging (#1023)
2 parents 6a9fe42 + 85a825b commit b970eb5

File tree

14 files changed

+319
-139
lines changed

14 files changed

+319
-139
lines changed
 

‎ads/aqua/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
import json
@@ -298,7 +298,7 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
298298
config = {}
299299
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
300300
if not artifact_path:
301-
logger.error(
301+
logger.debug(
302302
f"Failed to get artifact path from custom metadata for the model: {model_id}"
303303
)
304304
return config

‎ads/aqua/evaluation/evaluation.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
import base64
55
import json
@@ -199,11 +199,11 @@ def create(
199199
eval_inference_configuration = (
200200
container.spec.evaluation_configuration
201201
)
202-
except Exception:
202+
except Exception as ex:
203203
logger.debug(
204204
f"Could not load inference config details for the evaluation source id: "
205205
f"{create_aqua_evaluation_details.evaluation_source_id}. Please check if the container"
206-
f" runtime has the correct SMC image information."
206+
f" runtime has the correct SMC image information.\nError: {str(ex)}"
207207
)
208208
elif (
209209
DataScienceResource.MODEL
@@ -289,7 +289,7 @@ def create(
289289
f"Invalid experiment name. Please provide an experiment with `{Tags.AQUA_EVALUATION}` in tags."
290290
)
291291
except Exception:
292-
logger.debug(
292+
logger.info(
293293
f"Model version set {experiment_model_version_set_name} doesn't exist. "
294294
"Creating new model version set."
295295
)
@@ -711,21 +711,27 @@ def get(self, eval_id) -> AquaEvaluationDetail:
711711
try:
712712
log = utils.query_resource(log_id, return_all=False)
713713
log_name = log.display_name if log else ""
714-
except Exception:
714+
except Exception as ex:
715+
logger.debug(f"Failed to get associated log name. Error: {ex}")
715716
pass
716717

717718
if loggroup_id:
718719
try:
719720
loggroup = utils.query_resource(loggroup_id, return_all=False)
720721
loggroup_name = loggroup.display_name if loggroup else ""
721-
except Exception:
722+
except Exception as ex:
723+
logger.debug(f"Failed to get associated loggroup name. Error: {ex}")
722724
pass
723725

724726
try:
725727
introspection = json.loads(
726728
self._get_attribute_from_model_metadata(resource, "ArtifactTestResults")
727729
)
728-
except Exception:
730+
except Exception as ex:
731+
logger.debug(
732+
f"There was an issue loading the model attribute as json object for evaluation {eval_id}. "
733+
f"Setting introspection to empty.\n Error:{ex}"
734+
)
729735
introspection = {}
730736

731737
summary = AquaEvaluationDetail(
@@ -878,13 +884,13 @@ def get_status(self, eval_id: str) -> dict:
878884
try:
879885
log_id = job_run_details.log_details.log_id
880886
except Exception as e:
881-
logger.debug(f"Failed to get associated log. {str(e)}")
887+
logger.debug(f"Failed to get associated log.\nError: {str(e)}")
882888
log_id = ""
883889

884890
try:
885891
loggroup_id = job_run_details.log_details.log_group_id
886892
except Exception as e:
887-
logger.debug(f"Failed to get associated log. {str(e)}")
893+
logger.debug(f"Failed to get associated log.\nError: {str(e)}")
888894
loggroup_id = ""
889895

890896
loggroup_url = get_log_links(region=self.region, log_group_id=loggroup_id)
@@ -958,7 +964,7 @@ def load_metrics(self, eval_id: str) -> AquaEvalMetrics:
958964
)
959965
except Exception as e:
960966
logger.debug(
961-
"Failed to load `report.json` from evaluation artifact" f"{str(e)}"
967+
f"Failed to load `report.json` from evaluation artifact.\nError: {str(e)}"
962968
)
963969
json_report = {}
964970

@@ -1047,6 +1053,7 @@ def download_report(self, eval_id) -> AquaEvalReport:
10471053
return report
10481054

10491055
with tempfile.TemporaryDirectory() as temp_dir:
1056+
logger.info(f"Downloading evaluation artifact for {eval_id}.")
10501057
DataScienceModel.from_id(eval_id).download_artifact(
10511058
temp_dir,
10521059
auth=self._auth,
@@ -1200,6 +1207,7 @@ def _delete_job_and_model(job, model):
12001207
def load_evaluation_config(self, container: Optional[str] = None) -> Dict:
12011208
"""Loads evaluation config."""
12021209

1210+
logger.info("Loading evaluation container config.")
12031211
# retrieve the evaluation config by container family name
12041212
evaluation_config = get_evaluation_service_config(container)
12051213

@@ -1279,9 +1287,9 @@ def _get_source(
12791287
raise AquaRuntimeError(
12801288
f"Not supported source type: {resource_type}"
12811289
)
1282-
except Exception:
1290+
except Exception as ex:
12831291
logger.debug(
1284-
f"Failed to retrieve source information for evaluation {evaluation.identifier}."
1292+
f"Failed to retrieve source information for evaluation {evaluation.identifier}.\nError: {str(ex)}"
12851293
)
12861294
source_name = ""
12871295

‎ads/aqua/extension/aqua_ws_msg_handler.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
import traceback
7+
import uuid
88
from abc import abstractmethod
99
from http.client import responses
1010
from typing import List
@@ -34,7 +34,7 @@ def __init__(self, message: str):
3434
self.telemetry = TelemetryClient(
3535
bucket=AQUA_TELEMETRY_BUCKET, namespace=AQUA_TELEMETRY_BUCKET_NS
3636
)
37-
except:
37+
except Exception:
3838
pass
3939

4040
@staticmethod
@@ -66,24 +66,31 @@ def write_error(self, status_code, **kwargs):
6666
"message": message,
6767
"service_payload": service_payload,
6868
"reason": reason,
69+
"request_id": str(uuid.uuid4()),
6970
}
7071
exc_info = kwargs.get("exc_info")
7172
if exc_info:
72-
logger.error("".join(traceback.format_exception(*exc_info)))
73+
logger.error(
74+
f"Error Request ID: {reply['request_id']}\n"
75+
f"Error: {''.join(traceback.format_exception(*exc_info))}"
76+
)
7377
e = exc_info[1]
7478
if isinstance(e, HTTPError):
7579
reply["message"] = e.log_message or message
7680
reply["reason"] = e.reason
77-
else:
78-
logger.warning(reply["message"])
81+
82+
logger.error(
83+
f"Error Request ID: {reply['request_id']}\n"
84+
f"Error: {reply['message']} {reply['reason']}"
85+
)
7986
# telemetry may not be present if there is an error while initializing
8087
if hasattr(self, "telemetry"):
8188
aqua_api_details = kwargs.get("aqua_api_details", {})
8289
self.telemetry.record_event_async(
8390
category="aqua/error",
8491
action=str(status_code),
8592
value=reason,
86-
**aqua_api_details
93+
**aqua_api_details,
8794
)
8895
response = AquaWsError(
8996
status=status_code,

‎ads/aqua/extension/base_handler.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
3-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

65

@@ -35,7 +34,7 @@ def __init__(
3534
self.telemetry = TelemetryClient(
3635
bucket=AQUA_TELEMETRY_BUCKET, namespace=AQUA_TELEMETRY_BUCKET_NS
3736
)
38-
except:
37+
except Exception:
3938
pass
4039

4140
@staticmethod
@@ -82,19 +81,23 @@ def write_error(self, status_code, **kwargs):
8281
"message": message,
8382
"service_payload": service_payload,
8483
"reason": reason,
84+
"request_id": str(uuid.uuid4()),
8585
}
8686
exc_info = kwargs.get("exc_info")
8787
if exc_info:
88-
logger.error("".join(traceback.format_exception(*exc_info)))
88+
logger.error(
89+
f"Error Request ID: {reply['request_id']}\n"
90+
f"Error: {''.join(traceback.format_exception(*exc_info))}"
91+
)
8992
e = exc_info[1]
9093
if isinstance(e, HTTPError):
9194
reply["message"] = e.log_message or message
9295
reply["reason"] = e.reason if e.reason else reply["reason"]
93-
reply["request_id"] = str(uuid.uuid4())
94-
else:
95-
reply["request_id"] = str(uuid.uuid4())
9696

97-
logger.warning(reply["message"])
97+
logger.error(
98+
f"Error Request ID: {reply['request_id']}\n"
99+
f"Error: {reply['message']} {reply['reason']}"
100+
)
98101

99102
# telemetry may not be present if there is an error while initializing
100103
if hasattr(self, "telemetry"):
@@ -103,7 +106,7 @@ def write_error(self, status_code, **kwargs):
103106
category="aqua/error",
104107
action=str(status_code),
105108
value=reason,
106-
**aqua_api_details
109+
**aqua_api_details,
107110
)
108111

109112
self.finish(json.dumps(reply))

‎ads/aqua/extension/model_handler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,10 @@ def post(self, *args, **kwargs): # noqa: ARG002
140140
ignore_patterns = input_data.get("ignore_patterns")
141141
freeform_tags = input_data.get("freeform_tags")
142142
defined_tags = input_data.get("defined_tags")
143+
ignore_model_artifact_check = (
144+
str(input_data.get("ignore_model_artifact_check", "false")).lower()
145+
== "true"
146+
)
143147

144148
return self.finish(
145149
AquaModelApp().register(
@@ -158,6 +162,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
158162
ignore_patterns=ignore_patterns,
159163
freeform_tags=freeform_tags,
160164
defined_tags=defined_tags,
165+
ignore_model_artifact_check=ignore_model_artifact_check,
161166
)
162167
)
163168

‎ads/aqua/finetuning/finetuning.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,13 +382,19 @@ def create(
382382
defined_tags=model_defined_tags,
383383
),
384384
)
385+
logger.debug(
386+
f"Successfully updated model custom metadata list and freeform tags for the model {ft_model.id}."
387+
)
385388

386389
self.update_model_provenance(
387390
model_id=ft_model.id,
388391
update_model_provenance_details=UpdateModelProvenanceDetails(
389392
training_id=ft_job_run.id
390393
),
391394
)
395+
logger.debug(
396+
f"Successfully updated model provenance for the model {ft_model.id}."
397+
)
392398

393399
# tracks the shape and replica used for fine-tuning the service models
394400
telemetry_kwargs = (
@@ -564,7 +570,7 @@ def get_finetuning_config(self, model_id: str) -> Dict:
564570
config = self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG)
565571
if not config:
566572
logger.debug(
567-
f"Fine-tuning config for custom model: {model_id} is not available."
573+
f"Fine-tuning config for custom model: {model_id} is not available. Use defaults."
568574
)
569575
return config
570576

‎ads/aqua/model/entities.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ class ImportModelDetails(CLIBuilderMixin):
294294
ignore_patterns: Optional[List[str]] = None
295295
freeform_tags: Optional[dict] = None
296296
defined_tags: Optional[dict] = None
297+
ignore_model_artifact_check: Optional[bool] = None
297298

298299
def __post_init__(self):
299300
self._command = "model register"

‎ads/aqua/model/model.py

Lines changed: 137 additions & 58 deletions
Large diffs are not rendered by default.

‎ads/aqua/modeldeployment/deployment.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5-
import logging
65
import shlex
76
from typing import Dict, List, Optional, Union
87

@@ -271,7 +270,7 @@ def create(
271270
f"field. Either re-register the model with custom container URI, or set container_image_uri "
272271
f"parameter when creating this deployment."
273272
) from err
274-
logging.info(
273+
logger.info(
275274
f"Aqua Image used for deploying {aqua_model.id} : {container_image_uri}"
276275
)
277276

@@ -282,14 +281,14 @@ def create(
282281
default_cmd_var = shlex.split(cmd_var_string)
283282
if default_cmd_var:
284283
cmd_var = validate_cmd_var(default_cmd_var, cmd_var)
285-
logging.info(f"CMD used for deploying {aqua_model.id} :{cmd_var}")
284+
logger.info(f"CMD used for deploying {aqua_model.id} :{cmd_var}")
286285
except ValueError:
287-
logging.debug(
286+
logger.debug(
288287
f"CMD will be ignored for this deployment as {AQUA_DEPLOYMENT_CONTAINER_CMD_VAR_METADATA_NAME} "
289288
f"key is not available in the custom metadata field for this model."
290289
)
291290
except Exception as e:
292-
logging.error(
291+
logger.error(
293292
f"There was an issue processing CMD arguments. Error: {str(e)}"
294293
)
295294

@@ -385,7 +384,7 @@ def create(
385384
if key not in env_var:
386385
env_var.update(env)
387386

388-
logging.info(f"Env vars used for deploying {aqua_model.id} :{env_var}")
387+
logger.info(f"Env vars used for deploying {aqua_model.id} :{env_var}")
389388

390389
# Start model deployment
391390
# configure model deployment infrastructure
@@ -440,10 +439,14 @@ def create(
440439
.with_runtime(container_runtime)
441440
).deploy(wait_for_completion=False)
442441

442+
deployment_id = deployment.dsc_model_deployment.id
443+
logger.info(
444+
f"Aqua model deployment {deployment_id} created for model {aqua_model.id}."
445+
)
443446
model_type = (
444447
AQUA_MODEL_TYPE_CUSTOM if is_fine_tuned_model else AQUA_MODEL_TYPE_SERVICE
445448
)
446-
deployment_id = deployment.dsc_model_deployment.id
449+
447450
# we arbitrarily choose last 8 characters of OCID to identify MD in telemetry
448451
telemetry_kwargs = {"ocid": get_ocid_substring(deployment_id, key_len=8)}
449452

@@ -539,25 +542,31 @@ def list(self, **kwargs) -> List["AquaDeployment"]:
539542
value=state,
540543
)
541544

545+
logger.info(
546+
f"Fetched {len(results)} model deployments from compartment_id={compartment_id}."
547+
)
542548
# tracks number of times deployment listing was called
543549
self.telemetry.record_event_async(category="aqua/deployment", action="list")
544550

545551
return results
546552

547553
@telemetry(entry_point="plugin=deployment&action=delete", name="aqua")
548554
def delete(self, model_deployment_id: str):
555+
logger.info(f"Deleting model deployment {model_deployment_id}.")
549556
return self.ds_client.delete_model_deployment(
550557
model_deployment_id=model_deployment_id
551558
).data
552559

553560
@telemetry(entry_point="plugin=deployment&action=deactivate", name="aqua")
554561
def deactivate(self, model_deployment_id: str):
562+
logger.info(f"Deactivating model deployment {model_deployment_id}.")
555563
return self.ds_client.deactivate_model_deployment(
556564
model_deployment_id=model_deployment_id
557565
).data
558566

559567
@telemetry(entry_point="plugin=deployment&action=activate", name="aqua")
560568
def activate(self, model_deployment_id: str):
569+
logger.info(f"Activating model deployment {model_deployment_id}.")
561570
return self.ds_client.activate_model_deployment(
562571
model_deployment_id=model_deployment_id
563572
).data
@@ -579,6 +588,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
579588
AquaDeploymentDetail:
580589
The instance of the Aqua model deployment details.
581590
"""
591+
logger.info(f"Fetching model deployment details for {model_deployment_id}.")
592+
582593
model_deployment = self.ds_client.get_model_deployment(
583594
model_deployment_id=model_deployment_id, **kwargs
584595
).data
@@ -594,7 +605,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
594605

595606
if not oci_aqua:
596607
raise AquaRuntimeError(
597-
f"Target deployment {model_deployment_id} is not Aqua deployment."
608+
f"Target deployment {model_deployment_id} is not Aqua deployment as it does not contain "
609+
f"{Tags.AQUA_TAG} tag."
598610
)
599611

600612
log_id = ""
@@ -652,7 +664,7 @@ def get_deployment_config(self, model_id: str) -> Dict:
652664
config = self.get_config(model_id, AQUA_MODEL_DEPLOYMENT_CONFIG)
653665
if not config:
654666
logger.debug(
655-
f"Deployment config for custom model: {model_id} is not available."
667+
f"Deployment config for custom model: {model_id} is not available. Use defaults."
656668
)
657669
return config
658670

‎ads/cli.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

6+
import logging
77
import sys
88
import traceback
9-
from dataclasses import is_dataclass
9+
import uuid
1010

1111
import fire
1212

@@ -27,7 +27,7 @@
2727
)
2828
logger.debug(ex)
2929
logger.debug(traceback.format_exc())
30-
exit()
30+
sys.exit()
3131

3232
# https://packaging.python.org/en/latest/guides/single-sourcing-package-version/#single-sourcing-the-package-version
3333
if sys.version_info >= (3, 8):
@@ -122,8 +122,9 @@ def exit_program(ex: Exception, logger: "logging.Logger") -> None:
122122
... exit_program(e, logger)
123123
"""
124124

125-
logger.debug(traceback.format_exc())
126-
logger.error(str(ex))
125+
request_id = str(uuid.uuid4())
126+
logger.debug(f"Error Request ID: {request_id}\nError: {traceback.format_exc()}")
127+
logger.error(f"Error Request ID: {request_id}\n" f"Error: {str(ex)}")
127128

128129
exit_code = getattr(ex, "exit_code", 1)
129130
logger.error(f"Exit code: {exit_code}")

‎tests/unitary/with_extras/aqua/test_cli.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*--
33

4-
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
import logging
88
import os
99
import subprocess
10+
import uuid
1011
from importlib import reload
1112
from unittest import TestCase
1213
from unittest.mock import call, patch
@@ -148,6 +149,7 @@ def test_aqua_cli(self, mock_logger, mock_aqua_command, mock_fire, mock_serializ
148149
]
149150
)
150151
@patch("sys.argv", ["ads", "aqua", "--error-option"])
152+
@patch("uuid.uuid4")
151153
@patch("fire.Fire")
152154
@patch("ads.aqua.cli.AquaCommand")
153155
@patch("ads.aqua.logger.error")
@@ -162,11 +164,17 @@ def test_aqua_cli_with_error(
162164
mock_logger_error,
163165
mock_aqua_command,
164166
mock_fire,
167+
mock_uuid,
165168
):
166169
"""Tests when Aqua Cli gracefully exit when error raised."""
167170
mock_fire.side_effect = mock_side_effect
168171
from ads.cli import cli
169172

173+
uuid_value = "12345678-1234-5678-1234-567812345678"
174+
mock_uuid.return_value = uuid.UUID(uuid_value)
175+
expected_logging_message = type(expected_logging_message)(
176+
f"Error Request ID: {uuid_value}\nError: {str(expected_logging_message)}"
177+
)
170178
cli()
171179
calls = [
172180
call(expected_logging_message),

‎tests/unitary/with_extras/aqua/test_handlers.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*--
3-
# Copyright (c) 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
import json
@@ -131,9 +131,13 @@ def test_finish(self, name, payload, expected_call, mock_super_finish):
131131
),
132132
aqua_api_details=dict(
133133
aqua_api_name="TestDataset.create",
134-
oci_api_name=TestDataset.mock_service_payload_create["operation_name"],
135-
service_endpoint=TestDataset.mock_service_payload_create["request_endpoint"]
136-
)
134+
oci_api_name=TestDataset.mock_service_payload_create[
135+
"operation_name"
136+
],
137+
service_endpoint=TestDataset.mock_service_payload_create[
138+
"request_endpoint"
139+
],
140+
),
137141
),
138142
"Authorization Failed: The resource you're looking for isn't accessible. Operation Name: get_job_run.",
139143
],
@@ -171,10 +175,13 @@ def test_write_error(self, name, input, expected_msg, mock_uuid, mock_logger):
171175
input.get("status_code"),
172176
),
173177
value=input.get("reason"),
174-
**aqua_api_details
178+
**aqua_api_details,
175179
)
176-
177-
mock_logger.warning.assert_called_with(expected_msg)
180+
error_message = (
181+
f"Error Request ID: {expected_reply['request_id']}\n"
182+
f"Error: {expected_reply['message']} {expected_reply['reason']}"
183+
)
184+
mock_logger.error.assert_called_with(error_message)
178185

179186

180187
class TestHandlers(unittest.TestCase):

‎tests/unitary/with_extras/aqua/test_model.py

Lines changed: 61 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -937,10 +937,18 @@ def test_import_model_with_project_compartment_override(
937937
assert model.project_id == project_override
938938

939939
@pytest.mark.parametrize(
940-
"download_from_hf",
941-
[True, False],
940+
("ignore_artifact_check", "download_from_hf"),
941+
[
942+
(True, True),
943+
(True, False),
944+
(False, True),
945+
(False, False),
946+
(None, False),
947+
(None, True),
948+
],
942949
)
943950
@patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create")
951+
@patch("ads.model.datascience_model.DataScienceModel.sync")
944952
@patch("ads.model.datascience_model.DataScienceModel.upload_artifact")
945953
@patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects")
946954
@patch("ads.aqua.common.utils.load_config", side_effect=AquaFileNotFoundError)
@@ -953,45 +961,65 @@ def test_import_model_with_missing_config(
953961
mock_load_config,
954962
mock_list_objects,
955963
mock_upload_artifact,
964+
mock_sync,
956965
mock_ocidsc_create,
957-
mock_get_container_config,
966+
ignore_artifact_check,
958967
download_from_hf,
959968
mock_get_hf_model_info,
960969
mock_init_client,
961970
):
962-
"""Test for validating if error is returned when model artifacts are incomplete or not available."""
963-
964-
os_path = "oci://aqua-bkt@aqua-ns/prefix/path"
965-
model_name = "oracle/aqua-1t-mega-model"
971+
my_model = "oracle/aqua-1t-mega-model"
966972
ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True)
967-
mock_list_objects.return_value = MagicMock(objects=[])
968-
reload(ads.aqua.model.model)
969-
app = AquaModelApp()
970-
app.list = MagicMock(return_value=[])
973+
# set object list from OSS without config.json
974+
os_path = "oci://aqua-bkt@aqua-ns/prefix/path"
971975

976+
# set object list from HF without config.json
972977
if download_from_hf:
973-
with pytest.raises(AquaValueError):
974-
mock_get_hf_model_info.return_value.siblings = []
975-
with tempfile.TemporaryDirectory() as tmpdir:
976-
model: AquaModel = app.register(
977-
model=model_name,
978-
os_path=os_path,
979-
local_dir=str(tmpdir),
980-
download_from_hf=True,
981-
)
978+
mock_get_hf_model_info.return_value.siblings = [
979+
MagicMock(rfilename="model.safetensors")
980+
]
982981
else:
983-
with pytest.raises(AquaRuntimeError):
982+
obj1 = MagicMock(etag="12345-1234-1234-1234-123456789", size=150)
983+
obj1.name = f"prefix/path/model.safetensors"
984+
objects = [obj1]
985+
mock_list_objects.return_value = MagicMock(objects=objects)
986+
987+
reload(ads.aqua.model.model)
988+
app = AquaModelApp()
989+
with patch.object(AquaModelApp, "list") as aqua_model_mock_list:
990+
aqua_model_mock_list.return_value = [
991+
AquaModelSummary(
992+
id="test_id1",
993+
name="organization1/name1",
994+
organization="organization1",
995+
)
996+
]
997+
998+
if ignore_artifact_check:
984999
model: AquaModel = app.register(
985-
model=model_name,
1000+
model=my_model,
9861001
os_path=os_path,
987-
download_from_hf=False,
1002+
inference_container="odsc-vllm-or-tgi-container",
1003+
finetuning_container="odsc-llm-fine-tuning",
1004+
download_from_hf=download_from_hf,
1005+
ignore_model_artifact_check=ignore_artifact_check,
9881006
)
1007+
assert model.ready_to_deploy is True
1008+
else:
1009+
with pytest.raises(AquaRuntimeError):
1010+
model: AquaModel = app.register(
1011+
model=my_model,
1012+
os_path=os_path,
1013+
inference_container="odsc-vllm-or-tgi-container",
1014+
finetuning_container="odsc-llm-fine-tuning",
1015+
download_from_hf=download_from_hf,
1016+
ignore_model_artifact_check=ignore_artifact_check,
1017+
)
9891018

9901019
@patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create")
9911020
@patch("ads.model.datascience_model.DataScienceModel.sync")
9921021
@patch("ads.model.datascience_model.DataScienceModel.upload_artifact")
9931022
@patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects")
994-
@patch.object(HfApi, "model_info")
9951023
@patch("ads.aqua.common.utils.load_config", return_value={})
9961024
def test_import_any_model_smc_container(
9971025
self,
@@ -1247,6 +1275,15 @@ def test_import_model_with_input_tags(
12471275
"--download_from_hf True --cleanup_model_cache True --inference_container odsc-vllm-serving --freeform_tags "
12481276
'{"ftag1": "fvalue1", "ftag2": "fvalue2"} --defined_tags {"dtag1": "dvalue1", "dtag2": "dvalue2"}',
12491277
),
1278+
(
1279+
{
1280+
"os_path": "oci://aqua-bkt@aqua-ns/path",
1281+
"model": "oracle/oracle-1it",
1282+
"inference_container": "odsc-vllm-serving",
1283+
"ignore_model_artifact_check": True,
1284+
},
1285+
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --cleanup_model_cache True --inference_container odsc-vllm-serving --ignore_model_artifact_check True",
1286+
),
12501287
],
12511288
)
12521289
def test_import_cli(self, data, expected_output):

‎tests/unitary/with_extras/aqua/test_model_handler.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def test_list(self, mock_list):
132132

133133
@parameterized.expand(
134134
[
135-
(None, None, False, None, None, None, None, None),
135+
(None, None, False, None, None, None, None, None, True),
136136
(
137137
"odsc-llm-fine-tuning",
138138
None,
@@ -142,8 +142,9 @@ def test_list(self, mock_list):
142142
["test.json"],
143143
None,
144144
None,
145+
False,
145146
),
146-
(None, "test.gguf", True, None, ["*.json"], None, None, None),
147+
(None, "test.gguf", True, None, ["*.json"], None, None, None, False),
147148
(
148149
None,
149150
None,
@@ -153,6 +154,7 @@ def test_list(self, mock_list):
153154
["test.json"],
154155
None,
155156
None,
157+
False,
156158
),
157159
(
158160
None,
@@ -163,6 +165,7 @@ def test_list(self, mock_list):
163165
None,
164166
{"ftag1": "fvalue1"},
165167
{"dtag1": "dvalue1"},
168+
False,
166169
),
167170
],
168171
)
@@ -178,6 +181,7 @@ def test_register(
178181
ignore_patterns,
179182
freeform_tags,
180183
defined_tags,
184+
ignore_model_artifact_check,
181185
mock_register,
182186
mock_finish,
183187
):
@@ -201,6 +205,7 @@ def test_register(
201205
ignore_patterns=ignore_patterns,
202206
freeform_tags=freeform_tags,
203207
defined_tags=defined_tags,
208+
ignore_model_artifact_check=ignore_model_artifact_check,
204209
)
205210
)
206211
result = self.model_handler.post()
@@ -220,6 +225,7 @@ def test_register(
220225
ignore_patterns=ignore_patterns,
221226
freeform_tags=freeform_tags,
222227
defined_tags=defined_tags,
228+
ignore_model_artifact_check=ignore_model_artifact_check,
223229
)
224230
assert result["id"] == "test_id"
225231
assert result["inference_container"] == "odsc-tgi-serving"

0 commit comments

Comments
 (0)
Please sign in to comment.