diff --git a/pyproject.toml b/pyproject.toml index 2b5ccaf..a3f5353 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api" [tool.poetry] name = "together" -version = "1.5.8" +version = "1.5.9" authors = ["Together AI "] description = "Python client for Together's Cloud Platform!" readme = "README.md" diff --git a/src/together/resources/finetune.py b/src/together/resources/finetune.py index 8d0bf97..1f494fc 100644 --- a/src/together/resources/finetune.py +++ b/src/together/resources/finetune.py @@ -2,7 +2,7 @@ import re from pathlib import Path -from typing import Dict, List, Literal +from typing import List, Dict, Literal from rich import print as rprint @@ -545,7 +545,7 @@ def download( *, output: Path | str | None = None, checkpoint_step: int | None = None, - checkpoint_type: DownloadCheckpointType = DownloadCheckpointType.DEFAULT, + checkpoint_type: DownloadCheckpointType | str = DownloadCheckpointType.DEFAULT, ) -> FinetuneDownloadResult: """ Downloads compressed fine-tuned model or checkpoint to local disk. @@ -558,7 +558,7 @@ def download( Defaults to None. checkpoint_step (int, optional): Specifies step number for checkpoint to download. Defaults to -1 (download the final model) - checkpoint_type (CheckpointType, optional): Specifies which checkpoint to download. + checkpoint_type (CheckpointType | str, optional): Specifies which checkpoint to download. Defaults to CheckpointType.DEFAULT. Returns: @@ -582,6 +582,16 @@ def download( ft_job = self.retrieve(id) + # convert str to DownloadCheckpointType + if isinstance(checkpoint_type, str): + try: + checkpoint_type = DownloadCheckpointType(checkpoint_type.lower()) + except ValueError: + enum_strs = ", ".join(e.value for e in DownloadCheckpointType) + raise ValueError( + f"Invalid checkpoint type: {checkpoint_type}. Choose one of {{{enum_strs}}}." + ) + if isinstance(ft_job.training_type, FullTrainingType): if checkpoint_type != DownloadCheckpointType.DEFAULT: raise ValueError( @@ -592,10 +602,11 @@ def download( if checkpoint_type == DownloadCheckpointType.DEFAULT: checkpoint_type = DownloadCheckpointType.MERGED - if checkpoint_type == DownloadCheckpointType.MERGED: - url += f"&checkpoint={DownloadCheckpointType.MERGED.value}" - elif checkpoint_type == DownloadCheckpointType.ADAPTER: - url += f"&checkpoint={DownloadCheckpointType.ADAPTER.value}" + if checkpoint_type in { + DownloadCheckpointType.MERGED, + DownloadCheckpointType.ADAPTER, + }: + url += f"&checkpoint={checkpoint_type.value}" else: raise ValueError( f"Invalid checkpoint type for LoRATrainingType: {checkpoint_type}"