7
7
import json
8
8
import os
9
9
import shutil
10
- from unittest .mock import Mock
11
10
import uuid
11
+ from typing import Optional
12
12
from zipfile import ZipFile
13
13
14
14
import pandas as pd
15
15
import yaml
16
16
from ads .catalog .summary import SummaryList
17
- from ads .common import auth , oci_client , utils , logger
18
- from ads .common .model_artifact import (
19
- ConflictStrategy ,
20
- ModelArtifact ,
21
- OUTPUT_SCHEMA_FILE_NAME ,
22
- )
17
+ from ads .common import auth , logger , oci_client , utils
18
+ from ads .common .decorator .deprecate import deprecated
19
+ from ads .common .model_artifact import ConflictStrategy , ModelArtifact
23
20
from ads .common .model_metadata import (
24
- ModelCustomMetadata ,
25
- ModelTaxonomyMetadata ,
26
21
METADATA_SIZE_LIMIT ,
27
22
MetadataSizeTooLarge ,
23
+ ModelCustomMetadata ,
24
+ ModelTaxonomyMetadata ,
28
25
)
29
- from ads .common .oci_resource import OCIResource , SEARCH_TYPE
26
+ from ads .common .oci_resource import SEARCH_TYPE , OCIResource
30
27
from ads .config import (
31
- OCI_IDENTITY_SERVICE_ENDPOINT ,
32
28
NB_SESSION_COMPARTMENT_OCID ,
33
29
OCI_ODSC_SERVICE_ENDPOINT ,
34
30
PROJECT_OCID ,
44
40
from oci .exceptions import ServiceError
45
41
from oci .identity import IdentityClient
46
42
47
-
48
43
_UPDATE_MODEL_DETAILS_ATTRIBUTES = [
49
44
"display_name" ,
50
45
"description" ,
@@ -566,29 +561,27 @@ class ModelCatalog:
566
561
567
562
def __init__ (
568
563
self ,
569
- compartment_id = None ,
570
- ds_client_auth = None ,
571
- identity_client_auth = None ,
572
- timeout : int = None ,
564
+ compartment_id : Optional [ str ] = None ,
565
+ ds_client_auth : Optional [ dict ] = None ,
566
+ identity_client_auth : Optional [ dict ] = None ,
567
+ timeout : Optional [ int ] = None ,
573
568
):
574
569
"""Initializes model catalog instance.
575
570
576
571
Parameters
577
572
----------
578
- compartment_id : str, optional
579
- OCID of model's compartment
580
- If None, the default compartment ID `config.NB_SESSION_COMPARTMENT_OCID` would be used
581
- ds_client_auth : dict
582
- Default is None. The default authetication is set using `ads.set_auth` API. If you need to override the
573
+ compartment_id : (str, optional). Defaults to None.
574
+ Model compartment OCID. If `None`, the `config.NB_SESSION_COMPARTMENT_OCID` would be used.
575
+ ds_client_auth : (dict, optional). Defaults to None.
576
+ The default authetication is set using `ads.set_auth` API. If you need to override the
583
577
default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
584
578
authentication signer and kwargs required to instantiate DataScienceClient object.
585
- identity_client_auth : dict
586
- Default is None. The default authetication is set using `ads.set_auth` API. If you need to override the
579
+ identity_client_auth : ( dict, optional). Defaults to None.
580
+ The default authetication is set using `ads.set_auth` API. If you need to override the
587
581
default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
588
582
authentication signer and kwargs required to instantiate IdentityClient object.
589
- timeout: int, optional
583
+ timeout: ( int, optional). Defaults to 10 seconds.
590
584
The connection timeout in seconds for the client.
591
- The default value for connection timeout is 10 seconds.
592
585
593
586
Raises
594
587
------
@@ -864,16 +857,11 @@ def delete_model(self, model, **kwargs):
864
857
logger .error ("Failed to delete the Model." )
865
858
return False
866
859
867
- def download_model (
868
- self ,
869
- model_id : str ,
870
- target_dir : str ,
871
- force_overwrite : bool = False ,
872
- install_libs : bool = False ,
873
- conflict_strategy = ConflictStrategy .IGNORE ,
874
- ):
860
+ def _download_artifacts (
861
+ self , model_id : str , target_dir : str , force_overwrite : Optional [bool ] = False
862
+ ) -> None :
875
863
"""
876
- Downloads the model from model_dir to target_dir based on model_id.
864
+ Downloads the model artifacts from model catalog to target_dir based on model_id.
877
865
878
866
Parameters
879
867
----------
@@ -883,46 +871,89 @@ def download_model(
883
871
The target location of model after download.
884
872
force_overwrite: bool
885
873
Overwrite target_dir if exists.
886
- install_libs: bool, default: False
887
- Install the libraries specified in ds-requirements.txt which are missing in the current environment.
888
- conflict_strategy: ConflictStrategy, default: IGNORE
889
- Determines how to handle version conflicts between the current environment and requirements of
890
- model artifact.
891
- Valid values: "IGNORE", "UPDATE" or ConflictStrategy.
892
- IGNORE: Use the installed version in case of conflict
893
- UPDATE: Force update dependency to the version required by model artifact in case of conflict
874
+
875
+ Raises
876
+ ------
877
+ ValueError
878
+ If targed dir not exists.
879
+ KeyError
880
+ If model id not found.
894
881
895
882
Returns
896
883
-------
897
- ModelArtifact
898
- A ModelArtifact instance.
884
+ None
885
+ Nothing
899
886
"""
900
887
if os .path .exists (target_dir ) and os .listdir (target_dir ):
901
888
if not force_overwrite :
902
889
raise ValueError (
903
- "Target directory already exists. Set 'force_overwrite' to overwrite."
890
+ "Target directory already exists. "
891
+ "Set `force_overwrite` to overwrite."
904
892
)
905
893
shutil .rmtree (target_dir )
906
894
907
895
try :
908
896
zip_contents = self .ds_client .get_model_artifact_content (
909
897
model_id
910
898
).data .content
911
- except ServiceError as se :
912
- if se .status == 404 :
913
- raise KeyError (se .message ) from se
899
+ except ServiceError as ex :
900
+ if ex .status == 404 :
901
+ raise KeyError (ex .message ) from ex
914
902
else :
915
903
raise
916
904
zip_file_path = os .path .join (
917
905
"/tmp" , "saved_model_" + str (uuid .uuid4 ()) + ".zip"
918
906
)
907
+
919
908
# write contents to zip file
920
909
with open (zip_file_path , "wb" ) as zip_file :
921
910
zip_file .write (zip_contents )
911
+
922
912
# Extract all the contents of zip file in target directory
923
913
with ZipFile (zip_file_path ) as zip_file :
924
914
zip_file .extractall (target_dir )
915
+
925
916
os .remove (zip_file_path )
917
+
918
+ @deprecated (
919
+ "2.5.9" ,
920
+ details = "Instead use `ads.common.model_artifact.ModelArtifact.from_model_catalog()`." ,
921
+ )
922
+ def download_model (
923
+ self ,
924
+ model_id : str ,
925
+ target_dir : str ,
926
+ force_overwrite : bool = False ,
927
+ install_libs : bool = False ,
928
+ conflict_strategy = ConflictStrategy .IGNORE ,
929
+ ):
930
+ """
931
+ Downloads the model from model_dir to target_dir based on model_id.
932
+
933
+ Parameters
934
+ ----------
935
+ model_id: str
936
+ The OCID of the model to download.
937
+ target_dir: str
938
+ The target location of model after download.
939
+ force_overwrite: bool
940
+ Overwrite target_dir if exists.
941
+ install_libs: bool, default: False
942
+ Install the libraries specified in ds-requirements.txt which are missing in the current environment.
943
+ conflict_strategy: ConflictStrategy, default: IGNORE
944
+ Determines how to handle version conflicts between the current environment and requirements of
945
+ model artifact.
946
+ Valid values: "IGNORE", "UPDATE" or ConflictStrategy.
947
+ IGNORE: Use the installed version in case of conflict
948
+ UPDATE: Force update dependency to the version required by model artifact in case of conflict
949
+
950
+ Returns
951
+ -------
952
+ ModelArtifact
953
+ A ModelArtifact instance.
954
+ """
955
+ self ._download_artifacts (model_id , target_dir , force_overwrite )
956
+
926
957
result = ModelArtifact (
927
958
target_dir ,
928
959
conflict_strategy = conflict_strategy ,
0 commit comments