Skip to content

[AQUA] Integrate aqua to use model group #1214

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: feature/model_group
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions ads/aqua/model/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,18 @@ class ModelFileDescription(Serializable):
class Config:
alias_generator = to_camel
extra = "allow"


class MemberModel(Serializable):
    """Describes a single member model of a model group.

    Attributes:
        model_id (str): The id of the member model. No default is supplied
            (`...`), so it must be provided.
        inference_key (str): The inference key of the member model. Defaults
            to None when not provided.
    """

    model_id: str = Field(..., description="The id of member model.")
    # NOTE(review): the default is None while the annotation is plain `str`;
    # `Optional[str]` would describe this field more accurately — confirm that
    # `Optional` is imported in this module before changing it.
    inference_key: str = Field(None, description="The inference key of member model.")

    class Config:
        # Presumably the pydantic setting that keeps undeclared fields on the
        # instance instead of rejecting them — verify against `Serializable`.
        extra = "allow"
161 changes: 17 additions & 144 deletions ads/aqua/model/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#!/usr/bin/env python
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
import json
import os
import pathlib
import re
Expand Down Expand Up @@ -39,12 +38,11 @@
generate_tei_cmd_var,
get_artifact_path,
get_hf_model_info,
get_preferred_compatible_family,
list_os_files_with_extension,
load_config,
upload_folder,
)
from ads.aqua.config.container_config import AquaContainerConfig, Usage
from ads.aqua.config.container_config import AquaContainerConfig
from ads.aqua.constants import (
AQUA_MODEL_ARTIFACT_CONFIG,
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME,
Expand Down Expand Up @@ -79,7 +77,7 @@
AquaModelReadme,
AquaModelSummary,
ImportModelDetails,
ModelFileDescription,
MemberModel,
ModelValidationResult,
)
from ads.aqua.model.enums import MultiModelSupportedTaskType
Expand All @@ -102,6 +100,7 @@
)
from ads.model import DataScienceModel
from ads.model.common.utils import MetadataArtifactPathType
from ads.model.datascience_model_group import DataScienceModelGroup
from ads.model.model_metadata import (
MetadataCustomCategory,
ModelCustomMetadata,
Expand Down Expand Up @@ -235,20 +234,23 @@ def create(
def create_multi(
self,
models: List[AquaMultiModelRef],
model_custom_metadata: ModelCustomMetadata,
project_id: Optional[str] = None,
compartment_id: Optional[str] = None,
freeform_tags: Optional[Dict] = None,
defined_tags: Optional[Dict] = None,
source_models: Optional[Dict[str, DataScienceModel]] = None,
**kwargs, # noqa: ARG002
) -> DataScienceModel:
) -> DataScienceModelGroup:
"""
Creates a multi-model grouping using the provided model list.

Parameters
----------
models : List[AquaMultiModelRef]
List of AquaMultiModelRef instances for creating a multi-model group.
model_custom_metadata : ModelCustomMetadata
Custom metadata for creating model group.
project_id : Optional[str]
The project ID for the multi-model group.
compartment_id : Optional[str]
Expand All @@ -264,50 +266,10 @@ def create_multi(

Returns
-------
DataScienceModel
Instance of DataScienceModel object.
DataScienceModelGroup
Instance of DataScienceModelGroup object.
"""

if not models:
raise AquaValueError(
"Model list cannot be empty. Please provide at least one model for deployment."
)

display_name_list = []
model_file_description_list: List[ModelFileDescription] = []
model_custom_metadata = ModelCustomMetadata()

service_inference_containers = (
self.get_container_config().to_dict().get("inference")
)

supported_container_families = [
container_config_item.family
for container_config_item in service_inference_containers
if any(
usage.upper() in container_config_item.usages
for usage in [Usage.MULTI_MODEL, Usage.OTHER]
)
]

if not supported_container_families:
raise AquaValueError(
"Currently, there are no containers that support multi-model deployment."
)

selected_models_deployment_containers = set()

if not source_models:
# Collect all unique model IDs (including fine-tuned models)
source_model_ids = list(
{model_id for model in models for model_id in model.all_model_ids()}
)
logger.debug(
"Fetching source model metadata for model IDs: %s", source_model_ids
)

# Fetch source model metadata
source_models = self.get_multi_source(source_model_ids) or {}

# Process each model in the input list
for model in models:
Expand Down Expand Up @@ -337,11 +299,6 @@ def create_multi(
"Please register the model with a file description."
)

# Track model file description in a validated structure
model_file_description_list.append(
ModelFileDescription(**model_file_description)
)

# Ensure base model has a valid artifact
if not source_model.artifact:
logger.error(
Expand Down Expand Up @@ -396,11 +353,6 @@ def create_multi(
"Please register the model with a file description."
)

# Track model file description in a validated structure
model_file_description_list.append(
ModelFileDescription(**ft_model_file_description)
)

# Extract fine-tuned model path
_, fine_tune_path = extract_fine_tune_artifacts_path(
fine_tune_source_model
Expand All @@ -419,69 +371,12 @@ def create_multi(

display_name_list.append(ft_model.model_name)

# Validate deployment container consistency
deployment_container = source_model.custom_metadata_list.get(
ModelCustomMetadataFields.DEPLOYMENT_CONTAINER,
ModelCustomMetadataItem(
key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER
),
).value

if deployment_container not in supported_container_families:
logger.error(
"Unsupported deployment container '%s' for model '%s'. Supported: %s",
deployment_container,
source_model.id,
supported_container_families,
)
raise AquaValueError(
f"Unsupported deployment container '{deployment_container}' for model '{source_model.id}'. "
f"Only {supported_container_families} are supported for multi-model deployments."
)

selected_models_deployment_containers.add(deployment_container)

if not selected_models_deployment_containers:
raise AquaValueError(
"None of the selected models are associated with a recognized container family. "
"Please review the selected models, or select a different group of models."
)

# Check if the all models in the group shares same container family
if len(selected_models_deployment_containers) > 1:
deployment_container = get_preferred_compatible_family(
selected_families=selected_models_deployment_containers
)
if not deployment_container:
raise AquaValueError(
"The selected models are associated with different container families: "
f"{list(selected_models_deployment_containers)}."
"For multi-model deployment, all models in the group must belong to the same container "
"family or to compatible container families."
)
else:
deployment_container = selected_models_deployment_containers.pop()

# Generate model group details
timestamp = datetime.now().strftime("%Y%m%d")
model_group_display_name = f"model_group_{timestamp}"
combined_models = ", ".join(display_name_list)
model_group_description = f"Multi-model grouping using {combined_models}."

# Add global metadata
model_custom_metadata.add(
key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER,
value=deployment_container,
description=f"Inference container mapping for {model_group_display_name}",
category="Other",
)
model_custom_metadata.add(
key=ModelCustomMetadataFields.MULTIMODEL_GROUP_COUNT,
value=str(len(models)),
description="Number of models in the group.",
category="Other",
)

# Combine tags. The `Tags.AQUA_TAG` has been excluded, because we don't want to show
# the models created for multi-model purpose in the AQUA models list.
tags = {
Expand All @@ -491,46 +386,24 @@ def create_multi(
}

# Create multi-model group
custom_model = (
DataScienceModel()
custom_model_group = (
DataScienceModelGroup()
.with_compartment_id(compartment_id)
.with_project_id(project_id)
.with_display_name(model_group_display_name)
.with_description(model_group_description)
.with_freeform_tags(**tags)
.with_defined_tags(**(defined_tags or {}))
.with_custom_metadata_list(model_custom_metadata)
.with_member_models(
[MemberModel(model_id=model.model_id).model_dump() for model in models]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is `MemberModel(inference_key=...)` assigned? Is the inference key optional? (Would this be the model name?)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should default to using the model name if the user doesn't provide one explicitly. Even though it will not be used in the multi-model case, it's still good practice to include the name.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The inference_key is supposed to be the model name, but the model group has a length restriction on it, making it impossible to set for all models. I'm leaving it empty here for now and will address it later.

)
)

# Update multi model file description to attach artifacts
custom_model.with_model_file_description(
json_dict=ModelFileDescription(
models=[
models
for model_file_description in model_file_description_list
for models in model_file_description.models
]
).model_dump(by_alias=True)
)

# Finalize creation
custom_model.create(model_by_reference=True)
custom_model_group.create()

logger.info(
f"Aqua Model '{custom_model.id}' created with models: {', '.join(display_name_list)}."
)

# Create custom metadata for multi model metadata
custom_model.create_custom_metadata_artifact(
metadata_key_name=ModelCustomMetadataFields.MULTIMODEL_METADATA,
artifact_path_or_content=json.dumps(
[model.model_dump() for model in models]
).encode(),
path_type=MetadataArtifactPathType.CONTENT,
)

logger.debug(
f"Multi model metadata uploaded for Aqua model: {custom_model.id}."
f"Aqua Model Group'{custom_model_group.id}' created with models: {', '.join(display_name_list)}."
)

# Track telemetry event
Expand All @@ -540,7 +413,7 @@ def create_multi(
detail=combined_models,
)

return custom_model
return custom_model_group

@telemetry(entry_point="plugin=model&action=get", name="aqua")
def get(self, model_id: str) -> "AquaModel":
Expand Down
Loading