22
33import re
44from pathlib import Path
5- from typing import List , Literal
5+ from typing import Dict , List , Literal
66
77from rich import print as rprint
88
3030 TrainingMethodSFT ,
3131 TrainingType ,
3232)
33- from together .types .finetune import (
34- DownloadCheckpointType ,
35- FinetuneEvent ,
36- FinetuneEventType ,
37- )
38- from together .utils import (
39- get_event_step ,
40- log_warn_once ,
41- normalize_key ,
42- )
33+ from together .types .finetune import DownloadCheckpointType
34+ from together .utils import log_warn_once , normalize_key
4335
4436
4537_FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$"
@@ -222,68 +214,38 @@ def create_finetune_request(
222214 return finetune_request
223215
224216
225- def _process_checkpoints_from_events (
226- events : List [FinetuneEvent ], id : str
217+ def _parse_raw_checkpoints (
218+ checkpoints : List [Dict [ str , str ] ], id : str
227219) -> List [FinetuneCheckpoint ]:
228220 """
229- Helper function to process events and create checkpoint list.
221+ Helper function to process raw checkpoints and create checkpoint list.
230222
231223 Args:
232- events (List[FinetuneEvent] ): List of fine-tune events to process
224+ checkpoints (List[Dict[str, str]] ): List of raw checkpoints metadata
233225 id (str): Fine-tune job ID
234226
235227 Returns:
236228 List[FinetuneCheckpoint]: List of available checkpoints
237229 """
238- checkpoints : List [FinetuneCheckpoint ] = []
239-
240- for event in events :
241- event_type = event .type
242-
243- if event_type == FinetuneEventType .CHECKPOINT_SAVE :
244- step = get_event_step (event )
245- checkpoint_name = f"{ id } :{ step } " if step is not None else id
246-
247- checkpoints .append (
248- FinetuneCheckpoint (
249- type = (
250- f"Intermediate (step { step } )"
251- if step is not None
252- else "Intermediate"
253- ),
254- timestamp = event .created_at ,
255- name = checkpoint_name ,
256- )
257- )
258- elif event_type == FinetuneEventType .JOB_COMPLETE :
259- if hasattr (event , "model_path" ):
260- checkpoints .append (
261- FinetuneCheckpoint (
262- type = (
263- "Final Merged"
264- if hasattr (event , "adapter_path" )
265- else "Final"
266- ),
267- timestamp = event .created_at ,
268- name = id ,
269- )
270- )
271230
272- if hasattr (event , "adapter_path" ):
273- checkpoints .append (
274- FinetuneCheckpoint (
275- type = (
276- "Final Adapter" if hasattr (event , "model_path" ) else "Final"
277- ),
278- timestamp = event .created_at ,
279- name = id ,
280- )
281- )
231+ parsed_checkpoints = []
232+ for checkpoint in checkpoints :
233+ step = checkpoint ["step" ]
234+ checkpoint_type = checkpoint ["checkpoint_type" ]
235+ checkpoint_name = (
236+ f"{ id } :{ step } " if "intermediate" in checkpoint_type .lower () else id
237+ )
282238
283- # Sort by timestamp (newest first)
284- checkpoints .sort (key = lambda x : x .timestamp , reverse = True )
239+ parsed_checkpoints .append (
240+ FinetuneCheckpoint (
241+ type = checkpoint_type ,
242+ timestamp = checkpoint ["created_at" ],
243+ name = checkpoint_name ,
244+ )
245+ )
285246
286- return checkpoints
247+ parsed_checkpoints .sort (key = lambda x : x .timestamp , reverse = True )
248+ return parsed_checkpoints
287249
288250
289251class FineTuning :
@@ -561,8 +523,21 @@ def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
561523 Returns:
562524 List[FinetuneCheckpoint]: List of available checkpoints
563525 """
564- events = self .list_events (id ).data or []
565- return _process_checkpoints_from_events (events , id )
526+ requestor = api_requestor .APIRequestor (
527+ client = self ._client ,
528+ )
529+
530+ response , _ , _ = requestor .request (
531+ options = TogetherRequest (
532+ method = "GET" ,
533+ url = f"fine-tunes/{ id } /checkpoints" ,
534+ ),
535+ stream = False ,
536+ )
537+ assert isinstance (response , TogetherResponse )
538+
539+ raw_checkpoints = response .data ["data" ]
540+ return _parse_raw_checkpoints (raw_checkpoints , id )
566541
567542 def download (
568543 self ,
@@ -936,11 +911,9 @@ async def list_events(self, id: str) -> FinetuneListEvents:
936911 ),
937912 stream = False ,
938913 )
914+ assert isinstance (events_response , TogetherResponse )
939915
940- # FIXME: API returns "data" field with no object type (should be "list")
941- events_list = FinetuneListEvents (object = "list" , ** events_response .data )
942-
943- return events_list
916+ return FinetuneListEvents (** events_response .data )
944917
945918 async def list_checkpoints (self , id : str ) -> List [FinetuneCheckpoint ]:
946919 """
@@ -950,11 +923,23 @@ async def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
950923 id (str): Unique identifier of the fine-tune job to list checkpoints for
951924
952925 Returns:
953- List[FinetuneCheckpoint]: Object containing list of available checkpoints
926+ List[FinetuneCheckpoint]: List of available checkpoints
954927 """
955- events_list = await self .list_events (id )
956- events = events_list .data or []
957- return _process_checkpoints_from_events (events , id )
928+ requestor = api_requestor .APIRequestor (
929+ client = self ._client ,
930+ )
931+
932+ response , _ , _ = await requestor .arequest (
933+ options = TogetherRequest (
934+ method = "GET" ,
935+ url = f"fine-tunes/{ id } /checkpoints" ,
936+ ),
937+ stream = False ,
938+ )
939+ assert isinstance (response , TogetherResponse )
940+
941+ raw_checkpoints = response .data ["data" ]
942+ return _parse_raw_checkpoints (raw_checkpoints , id )
958943
959944 async def download (
960945 self , id : str , * , output : str | None = None , checkpoint_step : int = - 1
0 commit comments