Skip to content

Commit 40cf35f

Browse files
tmabraham and joschu authored
add support for model_path to verifiers evaluate script (#64)
Co-authored-by: John Schulman <[email protected]>
1 parent 41f93a8 commit 40cf35f

File tree

1 file changed

+32
-6
lines changed

1 file changed

+32
-6
lines changed

tinker_cookbook/recipes/verifiers_rl/evaluate.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,22 +57,45 @@ def log_results(
5757
print(out)
5858

5959

60-
def evaluate(
60+
async def evaluate(
6161
vf_env_id: str,
6262
vf_env_args: dict,
63-
model_name: str,
63+
model_name: str | None,
6464
num_examples: int,
6565
rollouts_per_example: int,
6666
max_concurrent: int,
6767
max_tokens: int,
6868
temperature: float,
69+
model_path: str | None = None,
6970
):
71+
service = tinker.ServiceClient()
72+
73+
# If model_path is provided, get the base model from the training run
74+
if model_path is not None:
75+
rest_client = service.create_rest_client()
76+
training_run = await rest_client.get_training_run_by_tinker_path_async(model_path)
77+
if model_name:
78+
if model_name != training_run.base_model:
79+
raise ValueError(
80+
f"Model name {model_name} does not match training run base model {training_run.base_model}"
81+
)
82+
else:
83+
model_name = training_run.base_model
84+
85+
if model_name is None:
86+
raise ValueError("model_name or model_path must be provided")
87+
7088
env = vf.load_environment(vf_env_id, **vf_env_args)
7189
tokenizer = get_tokenizer(model_name)
7290
renderer_name = model_info.get_recommended_renderer_name(model_name)
7391
renderer = renderers.get_renderer(renderer_name, tokenizer)
74-
service = tinker.ServiceClient()
75-
sampling = service.create_sampling_client(base_model=model_name)
92+
93+
# Create sampling client from checkpoint path or base model
94+
if model_path:
95+
sampling = service.create_sampling_client(model_path=model_path, base_model=model_name)
96+
else:
97+
sampling = service.create_sampling_client(base_model=model_name)
98+
7699
client = TinkerAsyncOpenAIClient(sampling, renderer, tokenizer)
77100
start_time = time.time()
78101
results = env.evaluate_sync(
@@ -95,11 +118,13 @@ def evaluate(
95118
rollouts_per_example,
96119
end_time - start_time,
97120
)
121+
return results
98122

99123

100124
@chz.chz
101125
class CLIConfig:
102-
model_name: str = "Qwen/Qwen3-4B-Instruct-2507"
126+
model_name: str | None = None # Base model name (auto-detected from checkpoint if not provided)
127+
model_path: str | None = None # Path to checkpoint (e.g., from checkpoints.jsonl sampler_path)
103128
vf_env_id: str = "reverse-text"
104129
vf_env_args: str | None = None # JSON string
105130
num_examples: int = 5
@@ -111,7 +136,7 @@ class CLIConfig:
111136

112137
async def cli_main(cfg: CLIConfig):
113138
env_args = json.loads(cfg.vf_env_args) if cfg.vf_env_args else {}
114-
return evaluate(
139+
return await evaluate(
115140
vf_env_id=cfg.vf_env_id,
116141
vf_env_args=env_args,
117142
model_name=cfg.model_name,
@@ -120,6 +145,7 @@ async def cli_main(cfg: CLIConfig):
120145
max_concurrent=cfg.max_concurrent,
121146
max_tokens=cfg.max_tokens,
122147
temperature=cfg.temperature,
148+
model_path=cfg.model_path,
123149
)
124150

125151

0 commit comments

Comments
 (0)