Skip to content

Accept strings for checkpoint type on download #308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 13, 2025
Merged
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"

[tool.poetry]
name = "together"
version = "1.5.8"
version = "1.5.9"
authors = ["Together AI <[email protected]>"]
description = "Python client for Together's Cloud Platform!"
readme = "README.md"
Expand Down
25 changes: 18 additions & 7 deletions src/together/resources/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import re
from pathlib import Path
from typing import Dict, List, Literal
from typing import List, Dict, Literal

from rich import print as rprint

Expand Down Expand Up @@ -545,7 +545,7 @@ def download(
*,
output: Path | str | None = None,
checkpoint_step: int | None = None,
checkpoint_type: DownloadCheckpointType = DownloadCheckpointType.DEFAULT,
checkpoint_type: DownloadCheckpointType | str = DownloadCheckpointType.DEFAULT,
) -> FinetuneDownloadResult:
"""
Downloads compressed fine-tuned model or checkpoint to local disk.
Expand All @@ -558,7 +558,7 @@ def download(
Defaults to None.
checkpoint_step (int, optional): Specifies step number for checkpoint to download.
Defaults to -1 (download the final model)
checkpoint_type (CheckpointType, optional): Specifies which checkpoint to download.
checkpoint_type (CheckpointType | str, optional): Specifies which checkpoint to download.
Defaults to CheckpointType.DEFAULT.

Returns:
Expand All @@ -582,6 +582,16 @@ def download(

ft_job = self.retrieve(id)

# convert str to DownloadCheckpointType
if isinstance(checkpoint_type, str):
try:
checkpoint_type = DownloadCheckpointType(checkpoint_type.lower())
except ValueError:
enum_strs = ", ".join(e.value for e in DownloadCheckpointType)
raise ValueError(
f"Invalid checkpoint type: {checkpoint_type}. Choose one of {{{enum_strs}}}."
)

if isinstance(ft_job.training_type, FullTrainingType):
if checkpoint_type != DownloadCheckpointType.DEFAULT:
raise ValueError(
Expand All @@ -592,10 +602,11 @@ def download(
if checkpoint_type == DownloadCheckpointType.DEFAULT:
checkpoint_type = DownloadCheckpointType.MERGED

if checkpoint_type == DownloadCheckpointType.MERGED:
url += f"&checkpoint={DownloadCheckpointType.MERGED.value}"
elif checkpoint_type == DownloadCheckpointType.ADAPTER:
url += f"&checkpoint={DownloadCheckpointType.ADAPTER.value}"
if checkpoint_type in {
DownloadCheckpointType.MERGED,
DownloadCheckpointType.ADAPTER,
}:
url += f"&checkpoint={checkpoint_type.value}"
else:
raise ValueError(
f"Invalid checkpoint type for LoRATrainingType: {checkpoint_type}"
Expand Down