-
Notifications
You must be signed in to change notification settings - Fork 50
[AQUA] Integrate aqua to use model group #1214
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: feature/model_group
Are you sure you want to change the base?
Changes from all commits
f6cfcbe
00091a7
a0ef971
46d23f0
6c4da24
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,6 +52,7 @@ | |
AQUA_MODEL_ARTIFACT_FILE, | ||
AQUA_MODEL_TOKENIZER_CONFIG, | ||
AQUA_MODEL_TYPE_CUSTOM, | ||
AQUA_MULTI_MODEL_CONFIG, | ||
HF_METADATA_FOLDER, | ||
LICENSE, | ||
MODEL_BY_REFERENCE_OSS_PATH_KEY, | ||
|
@@ -79,7 +80,7 @@ | |
AquaModelReadme, | ||
AquaModelSummary, | ||
ImportModelDetails, | ||
ModelFileDescription, | ||
MemberModel, | ||
ModelValidationResult, | ||
) | ||
from ads.aqua.model.enums import MultiModelSupportedTaskType | ||
|
@@ -102,6 +103,7 @@ | |
) | ||
from ads.model import DataScienceModel | ||
from ads.model.common.utils import MetadataArtifactPathType | ||
from ads.model.datascience_model_group import DataScienceModelGroup | ||
from ads.model.model_metadata import ( | ||
MetadataCustomCategory, | ||
ModelCustomMetadata, | ||
|
@@ -235,20 +237,27 @@ def create( | |
def create_multi( | ||
self, | ||
models: List[AquaMultiModelRef], | ||
create_deployment_details, | ||
model_config_summary, | ||
project_id: Optional[str] = None, | ||
compartment_id: Optional[str] = None, | ||
freeform_tags: Optional[Dict] = None, | ||
defined_tags: Optional[Dict] = None, | ||
source_models: Optional[Dict[str, DataScienceModel]] = None, | ||
**kwargs, # noqa: ARG002 | ||
) -> DataScienceModel: | ||
) -> DataScienceModelGroup: | ||
""" | ||
Creates a multi-model grouping using the provided model list. | ||
Parameters | ||
---------- | ||
models : List[AquaMultiModelRef] | ||
List of AquaMultiModelRef instances for creating a multi-model group. | ||
create_deployment_details : CreateModelDeploymentDetails | ||
An instance of CreateModelDeploymentDetails containing all required and optional | ||
fields for creating a model deployment via Aqua. | ||
model_config_summary : ModelConfigSummary | ||
Summary Model Deployment configuration for the group of models. | ||
project_id : Optional[str] | ||
The project ID for the multi-model group. | ||
compartment_id : Optional[str] | ||
|
@@ -264,8 +273,8 @@ def create_multi( | |
Returns | ||
------- | ||
DataScienceModel | ||
Instance of DataScienceModel object. | ||
DataScienceModelGroup | ||
Instance of DataScienceModelGroup object. | ||
""" | ||
|
||
if not models: | ||
|
@@ -274,7 +283,6 @@ def create_multi( | |
) | ||
|
||
display_name_list = [] | ||
model_file_description_list: List[ModelFileDescription] = [] | ||
model_custom_metadata = ModelCustomMetadata() | ||
|
||
service_inference_containers = ( | ||
|
@@ -337,11 +345,6 @@ def create_multi( | |
"Please register the model with a file description." | ||
) | ||
|
||
# Track model file description in a validated structure | ||
model_file_description_list.append( | ||
ModelFileDescription(**model_file_description) | ||
) | ||
|
||
# Ensure base model has a valid artifact | ||
if not source_model.artifact: | ||
logger.error( | ||
|
@@ -396,11 +399,6 @@ def create_multi( | |
"Please register the model with a file description." | ||
) | ||
|
||
# Track model file description in a validated structure | ||
model_file_description_list.append( | ||
ModelFileDescription(**ft_model_file_description) | ||
) | ||
|
||
# Extract fine-tuned model path | ||
_, fine_tune_path = extract_fine_tune_artifacts_path( | ||
fine_tune_source_model | ||
|
@@ -481,6 +479,22 @@ def create_multi( | |
description="Number of models in the group.", | ||
category="Other", | ||
) | ||
model_custom_metadata.add( | ||
key=AQUA_MULTI_MODEL_CONFIG, | ||
value=self._build_model_group_config( | ||
create_deployment_details=create_deployment_details, | ||
model_config_summary=model_config_summary, | ||
deployment_container=deployment_container, | ||
), | ||
description="Configs required to deploy multi models.", | ||
category="Other", | ||
) | ||
model_custom_metadata.add( | ||
key=ModelCustomMetadataFields.MULTIMODEL_METADATA, | ||
value=json.dumps([model.model_dump() for model in models]), | ||
description="Metadata to store user's multi model input.", | ||
category="Other", | ||
) | ||
|
||
# Combine tags. The `Tags.AQUA_TAG` has been excluded, because we don't want to show | ||
# the models created for multi-model purpose in the AQUA models list. | ||
|
@@ -491,46 +505,24 @@ def create_multi( | |
} | ||
|
||
# Create multi-model group | ||
custom_model = ( | ||
DataScienceModel() | ||
custom_model_group = ( | ||
DataScienceModelGroup() | ||
.with_compartment_id(compartment_id) | ||
.with_project_id(project_id) | ||
.with_display_name(model_group_display_name) | ||
.with_description(model_group_description) | ||
.with_freeform_tags(**tags) | ||
.with_defined_tags(**(defined_tags or {})) | ||
.with_custom_metadata_list(model_custom_metadata) | ||
.with_member_models( | ||
[MemberModel(model_id=model.model_id).model_dump() for model in models] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Where is the MemberModel(inference_key=...) assigned? Is the inference key optional? (Would this be the model name?) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We should default to using the model name if the user doesn't provide one explicitly. Even though it will not be used in the multi-model case, it's still good practice to include a name. |
||
) | ||
) | ||
|
||
# Update multi model file description to attach artifacts | ||
custom_model.with_model_file_description( | ||
json_dict=ModelFileDescription( | ||
models=[ | ||
models | ||
for model_file_description in model_file_description_list | ||
for models in model_file_description.models | ||
] | ||
).model_dump(by_alias=True) | ||
) | ||
|
||
# Finalize creation | ||
custom_model.create(model_by_reference=True) | ||
custom_model_group.create() | ||
|
||
logger.info( | ||
f"Aqua Model '{custom_model.id}' created with models: {', '.join(display_name_list)}." | ||
) | ||
|
||
# Create custom metadata for multi model metadata | ||
custom_model.create_custom_metadata_artifact( | ||
metadata_key_name=ModelCustomMetadataFields.MULTIMODEL_METADATA, | ||
artifact_path_or_content=json.dumps( | ||
[model.model_dump() for model in models] | ||
).encode(), | ||
path_type=MetadataArtifactPathType.CONTENT, | ||
) | ||
|
||
logger.debug( | ||
f"Multi model metadata uploaded for Aqua model: {custom_model.id}." | ||
f"Aqua Model Group'{custom_model_group.id}' created with models: {', '.join(display_name_list)}." | ||
) | ||
|
||
# Track telemetry event | ||
|
@@ -540,7 +532,33 @@ def create_multi( | |
detail=combined_models, | ||
) | ||
|
||
return custom_model | ||
return custom_model_group | ||
|
||
def _build_model_group_config( | ||
self, | ||
create_deployment_details, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Types? |
||
model_config_summary, | ||
deployment_container: str, | ||
) -> str: | ||
"""Builds model group config required to deploy multi models.""" | ||
container_type_key = ( | ||
create_deployment_details.container_family or deployment_container | ||
) | ||
container_config = self.get_container_config_item(container_type_key) | ||
container_spec = container_config.spec if container_config else UNKNOWN | ||
|
||
container_params = container_spec.cli_param if container_spec else UNKNOWN | ||
|
||
from ads.aqua.modeldeployment.model_group_config import ModelGroupConfig | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Put this import at the top of this file. |
||
|
||
multi_model_config = ModelGroupConfig.from_create_model_deployment_details( | ||
create_deployment_details, | ||
model_config_summary, | ||
container_type_key, | ||
container_params, | ||
) | ||
|
||
return multi_model_config.model_dump_json() | ||
|
||
@telemetry(entry_point="plugin=model&action=get", name="aqua") | ||
def get(self, model_id: str) -> "AquaModel": | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Types?