Skip to content

Commit fa5f53d

Browse files
authored
Merge branch 'main' into feature/odsc-65115
2 parents 8f21350 + 221eb14 commit fa5f53d

File tree

11 files changed

+163
-51
lines changed

11 files changed

+163
-51
lines changed

ads/aqua/app.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import json
66
import os
7+
import traceback
78
from dataclasses import fields
89
from typing import Dict, Union
910

@@ -23,7 +24,7 @@
2324
from ads.aqua.constants import UNKNOWN
2425
from ads.common import oci_client as oc
2526
from ads.common.auth import default_signer
26-
from ads.common.utils import extract_region
27+
from ads.common.utils import extract_region, is_path_exists
2728
from ads.config import (
2829
AQUA_TELEMETRY_BUCKET,
2930
AQUA_TELEMETRY_BUCKET_NS,
@@ -296,33 +297,44 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
296297
raise AquaRuntimeError(f"Target model {oci_model.id} is not Aqua model.")
297298

298299
config = {}
299-
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
300+
# if the current model has a service model tag, then
301+
if Tags.AQUA_SERVICE_MODEL_TAG in oci_model.freeform_tags:
302+
base_model_ocid = oci_model.freeform_tags[Tags.AQUA_SERVICE_MODEL_TAG]
303+
logger.info(
304+
f"Base model found for the model: {oci_model.id}. "
305+
f"Loading {config_file_name} for base model {base_model_ocid}."
306+
)
307+
base_model = self.ds_client.get_model(base_model_ocid).data
308+
artifact_path = get_artifact_path(base_model.custom_metadata_list)
309+
config_path = f"{os.path.dirname(artifact_path)}/config/"
310+
else:
311+
logger.info(f"Loading {config_file_name} for model {oci_model.id}...")
312+
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
313+
config_path = f"{artifact_path.rstrip('/')}/config/"
314+
300315
if not artifact_path:
301316
logger.debug(
302317
f"Failed to get artifact path from custom metadata for the model: {model_id}"
303318
)
304319
return config
305320

306-
try:
307-
config_path = f"{os.path.dirname(artifact_path)}/config/"
308-
config = load_config(
309-
config_path,
310-
config_file_name=config_file_name,
311-
)
312-
except Exception:
313-
# todo: temp fix for issue related to config load for byom models, update logic to choose the right path
321+
config_file_path = f"{config_path}{config_file_name}"
322+
if is_path_exists(config_file_path):
314323
try:
315-
config_path = f"{artifact_path.rstrip('/')}/config/"
316324
config = load_config(
317325
config_path,
318326
config_file_name=config_file_name,
319327
)
320328
except Exception:
321-
pass
329+
logger.debug(
330+
f"Error loading the {config_file_name} at path {config_path}.\n"
331+
f"{traceback.format_exc()}"
332+
)
322333

323334
if not config:
324-
logger.error(
325-
f"{config_file_name} is not available for the model: {model_id}. Check if the custom metadata has the artifact path set."
335+
logger.debug(
336+
f"{config_file_name} is not available for the model: {model_id}. "
337+
f"Check if the custom metadata has the artifact path set."
326338
)
327339
return config
328340

ads/aqua/extension/model_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
133133
)
134134
local_dir = input_data.get("local_dir")
135135
cleanup_model_cache = (
136-
str(input_data.get("cleanup_model_cache", "true")).lower() == "true"
136+
str(input_data.get("cleanup_model_cache", "false")).lower() == "true"
137137
)
138138
inference_container_uri = input_data.get("inference_container_uri")
139139
allow_patterns = input_data.get("allow_patterns")

ads/aqua/finetuning/entities.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ class CreateFineTuningDetails(Serializable):
122122
The log group id for fine tuning job infrastructure.
123123
log_id: (str, optional). Defaults to `None`.
124124
The log id for fine tuning job infrastructure.
125+
watch_logs: (bool, optional). Defaults to `False`.
126+
The flag to watch the job run logs when a fine-tuning job is created.
125127
force_overwrite: (bool, optional). Defaults to `False`.
126128
Whether to force overwrite the existing file in object storage.
127129
freeform_tags: (dict, optional)
@@ -148,6 +150,7 @@ class CreateFineTuningDetails(Serializable):
148150
subnet_id: Optional[str] = None
149151
log_id: Optional[str] = None
150152
log_group_id: Optional[str] = None
153+
watch_logs: Optional[bool] = False
151154
force_overwrite: Optional[bool] = False
152155
freeform_tags: Optional[dict] = None
153156
defined_tags: Optional[dict] = None

ads/aqua/finetuning/finetuning.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import json
66
import os
7+
import time
8+
import traceback
79
from typing import Dict
810

911
from oci.data_science.models import (
@@ -149,6 +151,15 @@ def create(
149151
f"Logging is required for fine tuning if replica is larger than {DEFAULT_FT_REPLICA}."
150152
)
151153

154+
if create_fine_tuning_details.watch_logs and not (
155+
create_fine_tuning_details.log_id
156+
and create_fine_tuning_details.log_group_id
157+
):
158+
raise AquaValueError(
159+
"Logging is required for fine tuning if watch_logs is set to True. "
160+
"Please provide log_id and log_group_id with the request parameters."
161+
)
162+
152163
ft_parameters = self._get_finetuning_params(
153164
create_fine_tuning_details.ft_parameters
154165
)
@@ -422,6 +433,20 @@ def create(
422433
value=source.display_name,
423434
)
424435

436+
if create_fine_tuning_details.watch_logs:
437+
logger.info(
438+
f"Watching fine-tuning job run logs for {ft_job_run.id}. Press Ctrl+C to stop watching logs.\n"
439+
)
440+
try:
441+
ft_job_run.watch()
442+
except KeyboardInterrupt:
443+
logger.info(f"\nStopped watching logs for {ft_job_run.id}.\n")
444+
time.sleep(1)
445+
except Exception:
446+
logger.debug(
447+
f"Something unexpected occurred while watching logs.\n{traceback.format_exc()}"
448+
)
449+
425450
return AquaFineTuningSummary(
426451
id=ft_model.id,
427452
name=ft_model.display_name,

ads/aqua/model/entities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ class ImportModelDetails(CLIBuilderMixin):
283283
os_path: str
284284
download_from_hf: Optional[bool] = True
285285
local_dir: Optional[str] = None
286-
cleanup_model_cache: Optional[bool] = True
286+
cleanup_model_cache: Optional[bool] = False
287287
inference_container: Optional[str] = None
288288
finetuning_container: Optional[str] = None
289289
compartment_id: Optional[str] = None

ads/aqua/model/model.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
LifecycleStatus,
3030
_build_resource_identifier,
3131
cleanup_local_hf_model_artifact,
32-
copy_model_config,
3332
create_word_icon,
3433
generate_tei_cmd_var,
3534
get_artifact_path,
@@ -969,24 +968,6 @@ def _create_model_catalog_entry(
969968
)
970969
tags[Tags.LICENSE] = validation_result.tags.get(Tags.LICENSE, UNKNOWN)
971970

972-
try:
973-
# If verified model already has a artifact json, use that.
974-
artifact_path = metadata.get(MODEL_BY_REFERENCE_OSS_PATH_KEY).value
975-
logger.info(
976-
f"Found model artifact in the service bucket. "
977-
f"Using artifact from service bucket instead of {os_path}."
978-
)
979-
980-
# todo: implement generic copy_folder method
981-
# copy model config from artifact path to user bucket
982-
copy_model_config(
983-
artifact_path=artifact_path, os_path=os_path, auth=default_signer()
984-
)
985-
except Exception:
986-
logger.debug(
987-
f"Proceeding with model registration without copying model config files at {os_path}. "
988-
f"Default configuration will be used for deployment and fine-tuning."
989-
)
990971
# Set artifact location to user bucket, and replace existing key if present.
991972
metadata.add(
992973
key=MODEL_BY_REFERENCE_OSS_PATH_KEY,

ads/cli.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
#!/usr/bin/env python
2-
32
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

5+
import json
66
import logging
77
import sys
88
import traceback
99
import uuid
1010

1111
import fire
12+
from pydantic import BaseModel
1213

1314
from ads.common import logger
1415

@@ -84,7 +85,13 @@ def serialize(data):
8485
The string representation of each dataclass object.
8586
"""
8687
if isinstance(data, list):
87-
[print(str(item)) for item in data]
88+
for item in data:
89+
if isinstance(item, BaseModel):
90+
print(json.dumps(item.dict(), indent=4))
91+
else:
92+
print(str(item))
93+
elif isinstance(data, BaseModel):
94+
print(json.dumps(data.dict(), indent=4))
8895
else:
8996
print(str(data))
9097

tests/unitary/with_extras/aqua/test_config.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
import json
66
import os
7-
from unittest.mock import patch
7+
import pytest
8+
from unittest.mock import patch, MagicMock
9+
10+
import oci.data_science.models
811

912
from ads.aqua.common.entities import ContainerSpec
1013
from ads.aqua.config.config import get_evaluation_service_config
14+
from ads.aqua.app import AquaApp
1115

1216

1317
class TestConfig:
@@ -37,3 +41,63 @@ def test_evaluation_service_config(self, mock_get_container_config):
3741
test_result.to_dict()
3842
== expected_result[ContainerSpec.CONTAINER_SPEC]["test_container"]
3943
)
44+
45+
@pytest.mark.parametrize(
46+
"custom_metadata",
47+
[
48+
{
49+
"category": "Other",
50+
"description": "test_desc",
51+
"key": "artifact_location",
52+
"value": "artifact_location",
53+
},
54+
{},
55+
],
56+
)
57+
@pytest.mark.parametrize("verified_model", [True, False])
58+
@pytest.mark.parametrize("path_exists", [True, False])
59+
@patch("ads.aqua.app.load_config")
60+
def test_load_config(
61+
self, mock_load_config, custom_metadata, verified_model, path_exists
62+
):
63+
mock_load_config.return_value = {"config_key": "config_value"}
64+
service_model_tag = (
65+
{"aqua_service_model": "aqua_service_model_id"} if verified_model else {}
66+
)
67+
68+
self.app = AquaApp()
69+
70+
model = {
71+
"id": "mock_id",
72+
"lifecycle_details": "mock_lifecycle_details",
73+
"lifecycle_state": "mock_lifecycle_state",
74+
"project_id": "mock_project_id",
75+
"freeform_tags": {
76+
**{
77+
"OCI_AQUA": "",
78+
},
79+
**service_model_tag,
80+
},
81+
"custom_metadata_list": [
82+
oci.data_science.models.Metadata(**custom_metadata)
83+
],
84+
}
85+
86+
self.app.ds_client.get_model = MagicMock(
87+
return_value=oci.response.Response(
88+
status=200,
89+
request=MagicMock(),
90+
headers=MagicMock(),
91+
data=oci.data_science.models.Model(**model),
92+
)
93+
)
94+
with patch("ads.aqua.app.is_path_exists", return_value=path_exists):
95+
result = self.app.get_config(
96+
model_id="test_model_id", config_file_name="test_config_file_name"
97+
)
98+
if not path_exists:
99+
assert result == {}
100+
if not custom_metadata:
101+
assert result == {}
102+
if path_exists and custom_metadata:
103+
assert result == {"config_key": "config_value"}

tests/unitary/with_extras/aqua/test_finetuning.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ads.aqua.finetuning.constants import FineTuneCustomMetadata
2323
from ads.aqua.finetuning.entities import AquaFineTuningParams
2424
from ads.jobs.ads_job import Job
25+
from ads.jobs.builders.infrastructure.dsc_job import DataScienceJobRun
2526
from ads.model.datascience_model import DataScienceModel
2627
from ads.model.model_metadata import ModelCustomMetadata
2728
from ads.aqua.common.errors import AquaValueError
@@ -49,6 +50,12 @@ def tearDownClass(cls):
4950
reload(ads.aqua)
5051
reload(ads.aqua.finetuning.finetuning)
5152

53+
@parameterized.expand(
54+
[
55+
("watch_logs", True),
56+
("watch_logs", False),
57+
]
58+
)
5259
@patch.object(Job, "run")
5360
@patch("ads.jobs.ads_job.Job.name", new_callable=PropertyMock)
5461
@patch("ads.jobs.ads_job.Job.id", new_callable=PropertyMock)
@@ -60,6 +67,8 @@ def tearDownClass(cls):
6067
@patch.object(AquaApp, "get_source")
6168
def test_create_fine_tuning(
6269
self,
70+
mock_watch_logs,
71+
mock_watch_logs_called,
6372
mock_get_source,
6473
mock_mvs_create,
6574
mock_ds_model_create,
@@ -117,6 +126,7 @@ def test_create_fine_tuning(
117126
ft_job_run.id = "test_ft_job_run_id"
118127
ft_job_run.lifecycle_details = "Job run artifact execution in progress."
119128
ft_job_run.lifecycle_state = "IN_PROGRESS"
129+
ft_job_run.watch = MagicMock()
120130
mock_job_run.return_value = ft_job_run
121131

122132
self.app.ds_client.update_model = MagicMock()
@@ -144,7 +154,20 @@ def test_create_fine_tuning(
144154
defined_tags=ft_model_defined_tags,
145155
)
146156

147-
aqua_ft_summary = self.app.create(**create_aqua_ft_details)
157+
inputs = {
158+
**create_aqua_ft_details,
159+
**{
160+
mock_watch_logs: mock_watch_logs_called,
161+
"log_id": "test_log_id",
162+
"log_group_id": "test_log_group_id",
163+
},
164+
}
165+
aqua_ft_summary = self.app.create(**inputs)
166+
167+
if mock_watch_logs_called:
168+
ft_job_run.watch.assert_called()
169+
else:
170+
ft_job_run.watch.assert_not_called()
148171

149172
assert aqua_ft_summary.to_dict() == {
150173
"console_url": f"https://cloud.oracle.com/data-science/models/{ft_model.id}?region={self.app.region}",

0 commit comments

Comments
 (0)