Skip to content

Accept strings for checkpoint type on download #308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 13, 2025
Merged
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"

[tool.poetry]
name = "together"
version = "1.5.8"
version = "1.5.9"
authors = ["Together AI <[email protected]>"]
description = "Python client for Together's Cloud Platform!"
readme = "README.md"
25 changes: 18 additions & 7 deletions src/together/resources/finetune.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@

import re
from pathlib import Path
from typing import Dict, List, Literal
from typing import List, Dict, Literal

from rich import print as rprint

@@ -545,7 +545,7 @@ def download(
*,
output: Path | str | None = None,
checkpoint_step: int | None = None,
checkpoint_type: DownloadCheckpointType = DownloadCheckpointType.DEFAULT,
checkpoint_type: DownloadCheckpointType | str = DownloadCheckpointType.DEFAULT,
) -> FinetuneDownloadResult:
"""
Downloads compressed fine-tuned model or checkpoint to local disk.
@@ -558,7 +558,7 @@ def download(
Defaults to None.
checkpoint_step (int, optional): Specifies step number for checkpoint to download.
Defaults to -1 (download the final model)
checkpoint_type (CheckpointType, optional): Specifies which checkpoint to download.
checkpoint_type (CheckpointType | str, optional): Specifies which checkpoint to download.
Defaults to CheckpointType.DEFAULT.

Returns:
@@ -582,6 +582,16 @@ def download(

ft_job = self.retrieve(id)

# convert str to DownloadCheckpointType
if isinstance(checkpoint_type, str):
try:
checkpoint_type = DownloadCheckpointType(checkpoint_type.lower())
except ValueError:
enum_strs = ", ".join(e.value for e in DownloadCheckpointType)
raise ValueError(
f"Invalid checkpoint type: {checkpoint_type}. Choose one of {{{enum_strs}}}."
)

if isinstance(ft_job.training_type, FullTrainingType):
if checkpoint_type != DownloadCheckpointType.DEFAULT:
raise ValueError(
@@ -592,10 +602,11 @@ def download(
if checkpoint_type == DownloadCheckpointType.DEFAULT:
checkpoint_type = DownloadCheckpointType.MERGED

if checkpoint_type == DownloadCheckpointType.MERGED:
url += f"&checkpoint={DownloadCheckpointType.MERGED.value}"
elif checkpoint_type == DownloadCheckpointType.ADAPTER:
url += f"&checkpoint={DownloadCheckpointType.ADAPTER.value}"
if checkpoint_type in {
DownloadCheckpointType.MERGED,
DownloadCheckpointType.ADAPTER,
}:
url += f"&checkpoint={checkpoint_type.value}"
else:
raise ValueError(
f"Invalid checkpoint type for LoRATrainingType: {checkpoint_type}"