@@ -57,22 +57,45 @@ def log_results(
5757 print (out )
5858
5959
60- def evaluate (
60+ async def evaluate (
6161 vf_env_id : str ,
6262 vf_env_args : dict ,
63- model_name : str ,
63+ model_name : str | None ,
6464 num_examples : int ,
6565 rollouts_per_example : int ,
6666 max_concurrent : int ,
6767 max_tokens : int ,
6868 temperature : float ,
69+ model_path : str | None = None ,
6970):
71+ service = tinker .ServiceClient ()
72+
73+ # If model_path is provided, get the base model from the training run
74+ if model_path is not None :
75+ rest_client = service .create_rest_client ()
76+ training_run = await rest_client .get_training_run_by_tinker_path_async (model_path )
77+ if model_name :
78+ if model_name != training_run .base_model :
79+ raise ValueError (
80+ f"Model name { model_name } does not match training run base model { training_run .base_model } "
81+ )
82+ else :
83+ model_name = training_run .base_model
84+
85+ if model_name is None :
86+ raise ValueError ("model_name or model_path must be provided" )
87+
7088 env = vf .load_environment (vf_env_id , ** vf_env_args )
7189 tokenizer = get_tokenizer (model_name )
7290 renderer_name = model_info .get_recommended_renderer_name (model_name )
7391 renderer = renderers .get_renderer (renderer_name , tokenizer )
74- service = tinker .ServiceClient ()
75- sampling = service .create_sampling_client (base_model = model_name )
92+
93+ # Create sampling client from checkpoint path or base model
94+ if model_path :
95+ sampling = service .create_sampling_client (model_path = model_path , base_model = model_name )
96+ else :
97+ sampling = service .create_sampling_client (base_model = model_name )
98+
7699 client = TinkerAsyncOpenAIClient (sampling , renderer , tokenizer )
77100 start_time = time .time ()
78101 results = env .evaluate_sync (
@@ -95,11 +118,13 @@ def evaluate(
95118 rollouts_per_example ,
96119 end_time - start_time ,
97120 )
121+ return results
98122
99123
100124@chz .chz
101125class CLIConfig :
102- model_name : str = "Qwen/Qwen3-4B-Instruct-2507"
126+ model_name : str | None = None # Base model name (auto-detected from checkpoint if not provided)
127+ model_path : str | None = None # Path to checkpoint (e.g., from checkpoints.jsonl sampler_path)
103128 vf_env_id : str = "reverse-text"
104129 vf_env_args : str | None = None # JSON string
105130 num_examples : int = 5
@@ -111,7 +136,7 @@ class CLIConfig:
111136
112137async def cli_main (cfg : CLIConfig ):
113138 env_args = json .loads (cfg .vf_env_args ) if cfg .vf_env_args else {}
114- return evaluate (
139+ return await evaluate (
115140 vf_env_id = cfg .vf_env_id ,
116141 vf_env_args = env_args ,
117142 model_name = cfg .model_name ,
@@ -120,6 +145,7 @@ async def cli_main(cfg: CLIConfig):
120145 max_concurrent = cfg .max_concurrent ,
121146 max_tokens = cfg .max_tokens ,
122147 temperature = cfg .temperature ,
148+ model_path = cfg .model_path ,
123149 )
124150
125151
0 commit comments