diff --git a/vertexai/tuning/__init__.py b/vertexai/tuning/__init__.py index 9b8362969b..2826753903 100644 --- a/vertexai/tuning/__init__.py +++ b/vertexai/tuning/__init__.py @@ -16,8 +16,10 @@ # We just want to re-export certain classes # pylint: disable=g-multiple-import,g-importing-member +from vertexai.tuning._tuning import SourceModel from vertexai.tuning._tuning import TuningJob __all__ = [ + "SourceModel", "TuningJob", ] diff --git a/vertexai/tuning/_supervised_tuning.py b/vertexai/tuning/_supervised_tuning.py index a8b5b1ed17..7ab280b893 100644 --- a/vertexai/tuning/_supervised_tuning.py +++ b/vertexai/tuning/_supervised_tuning.py @@ -21,11 +21,15 @@ ) from vertexai import generative_models from vertexai.tuning import _tuning +from vertexai.tuning import SourceModel def train( *, - source_model: Union[str, generative_models.GenerativeModel], + source_model: Union[ + str, + generative_models.GenerativeModel, + SourceModel], train_dataset: str, validation_dataset: Optional[str] = None, tuned_model_display_name: Optional[str] = None, @@ -33,6 +37,7 @@ def train( learning_rate_multiplier: Optional[float] = None, adapter_size: Optional[Literal[1, 4, 8, 16]] = None, labels: Optional[Dict[str, str]] = None, + output_uri: Optional[str] = None, ) -> "SupervisedTuningJob": """Tunes a model using supervised training. @@ -49,7 +54,7 @@ def train( learning_rate_multiplier: Learning rate multiplier for tuning. adapter_size: Adapter size for tuning. labels: User-defined metadata to be associated with trained models - + output_uri: The Google Cloud Storage location to write the model artifacts. Returns: A `TuningJob` object. 
class SourceModel:
    """A model that is used in managed OSS supervised tuning.

    Bundles the name of a base model with an optional custom base model
    reference so that both can be passed to tuning as a single
    ``source_model`` argument.

    Usage (assumes ``time``, ``aiplatform`` and ``sft`` are already imported):
        ```
        model = SourceModel(
            base_model="meta/llama3.1-8b",
            custom_base_model="gs://user-bucket/custom-weights",
        )
        sft_tuning_job = sft.train(
            source_model=model,
            train_dataset="gs://my-bucket/train.jsonl",
            validation_dataset="gs://my-bucket/validation.jsonl",
            epochs=4,
            learning_rate_multiplier=0.5,
            tuned_model_display_name="my-tuned-model",
            output_uri="gs://user-bucket/tuned-model"
        )

        while not sft_tuning_job.has_ended:
            time.sleep(60)
            sft_tuning_job.refresh()

        tuned_model = aiplatform.Model(sft_tuning_job.tuned_model_name)
        ```

    Attributes:
        base_model: Name of the base model, e.g. "meta/llama3.1-8b".
        custom_base_model: Reference to custom weights for the base model,
            e.g. "gs://user-bucket/custom-weights". Empty string when no
            custom weights are supplied.
    """

    def __init__(
        self,
        base_model: str,
        custom_base_model: str = "",
    ) -> None:
        """Initializes SourceModel.

        Args:
            base_model: Name of the base model to tune.
            custom_base_model: Optional reference to custom weights for the
                base model. Defaults to an empty string.
        """
        # Values are stored as given; no validation is performed here.
        self.base_model = base_model
        self.custom_base_model = custom_base_model

    def __repr__(self) -> str:
        """Returns a constructor-style representation for logs and errors."""
        return (
            f"{type(self).__name__}("
            f"base_model={self.base_model!r}, "
            f"custom_base_model={self.custom_base_model!r})"
        )
Args: - base_model (str): - Model name for tuning, e.g., "gemini-1.0-pro" - or "gemini-1.0-pro-001". + base_model: Model for tuning: + Supported types: str, SourceModel. This field is a member of `oneof`_ ``source_model``. tuning_spec: Tuning Spec for Fine Tuning. @@ -178,6 +214,7 @@ def _create( Overrides location set in aiplatform.init. credentials: Custom credentials to use to call tuning job service. Overrides credentials set in aiplatform.init. + output_uri: The Google Cloud Storage location to write the artifacts. Returns: Submitted TuningJob. @@ -191,17 +228,26 @@ def _create( tuned_model_display_name = cls._generate_display_name() gca_tuning_job = gca_tuning_job_types.TuningJob( - base_model=base_model, tuned_model_display_name=tuned_model_display_name, description=description, labels=labels, - # The tuning_spec one_of is set later + output_uri=output_uri, ) if isinstance(tuning_spec, gca_tuning_job_types.SupervisedTuningSpec): gca_tuning_job.supervised_tuning_spec = tuning_spec + if isinstance(base_model, SourceModel): + gca_tuning_job.base_model = base_model.base_model + gca_tuning_job.custom_base_model = base_model.custom_base_model + else: + gca_tuning_job.base_model = base_model elif isinstance(tuning_spec, gca_tuning_job_types.DistillationSpec): gca_tuning_job.distillation_spec = tuning_spec + if isinstance(base_model, SourceModel): + raise RuntimeError( + "Distillation is not supported for custom models." + ) + gca_tuning_job.base_model = base_model else: raise RuntimeError(f"Unsupported tuning_spec kind: {tuning_spec}")