
Commit f624b90

supervised/train.py: shift evals so they occur after 0, k, 2k, ... steps instead of 1, k+1, 2k+1, etc where k=config.eval_every (#90)
1 parent 6c9f7a4 commit f624b90
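The effect of the scheduling change can be sketched in isolation. The snippet below is illustrative only and is not code from this commit: `eval_every` stands in for `config.eval_every`, and the helper simply counts how many optimizer updates have been applied by the time an evaluation runs.

eval_every = 4


def updates_seen_by_eval(step: int, eval_before_step: bool) -> int:
    # Number of optimizer updates already applied when the eval at `step` runs:
    # issued before the step is submitted (new behaviour) it sees `step` updates;
    # issued after the step completes (old behaviour) it sees `step + 1`.
    return step if eval_before_step else step + 1


old = [updates_seen_by_eval(s, eval_before_step=False) for s in range(12) if s % eval_every == 0]
new = [updates_seen_by_eval(s, eval_before_step=True) for s in range(12) if s % eval_every == 0]

print("old eval points (updates applied):", old)  # [1, 5, 9]
print("new eval points (updates applied):", new)  # [0, 4, 8]

With k = 4, the old placement evaluated after 1, 5, 9, ... updates, while the new placement evaluates after 0, 4, 8, ..., matching the commit title.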

2 files changed: +56, -16 lines changed

tinker_cookbook/supervised/train.py

Lines changed: 55 additions & 16 deletions
@@ -83,14 +83,25 @@ class SubmittedBatch:
     epoch_idx: int
     batch_idx: int
     batch_start_time: float
+    eval_metrics: dict[str, float] | None = None
+    infrequent_eval_metrics: dict[str, float] | None = None
 
 
 async def run_evals(
     evaluators: list[Evaluator],
     training_client: tinker.TrainingClient,
     step: int,
 ) -> dict[str, float]:
-    """Run all evaluators and return metrics with test/ prefix."""
+    """Evaluate the current model weights and prefix results with ``test/``.
+
+    The helper is called immediately before optimizer step `step` is submitted, so it
+    measures the weights produced after step `step-1` (or the initial weights for step 0).
+    Training-client evaluators run against the mutable training client, while sampling
+    evaluators request a fresh `SamplingClient` snapshot via
+    `save_weights_and_get_sampling_client_async` to ensure their work uses a fixed
+    checkpoint. Returned metrics are prefixed with ``test/`` so they can be logged next
+    to the same-step training metrics.
+    """
     metrics = {}
     sampling_client = None
 
@@ -100,6 +111,7 @@ async def run_evals(
         elif isinstance(evaluator, SamplingClientEvaluator):
             # Create sampling client lazily, only when needed
             if sampling_client is None:
+                # Snapshot the current pre-step weights and create a new sampling client.
                 sampling_client = await training_client.save_weights_and_get_sampling_client_async(
                     f"evals_step_{step}"
                 )
@@ -114,14 +126,28 @@ async def run_evals(
 
 
 async def main(config: Config):
-    """Main training function that runs the complete training process."""
+    """Run the standard supervised learning loop used by the supervised recipes.
+
+    Responsibilities:
+    1. Initialize logging, build the dataset/evaluator objects, construct (or resume) the
+       training client, and determine the ``epoch``/``batch`` indices to start from.
+    2. Iterate over batches: fetch data, optionally run evaluations before submitting the
+       optimizer step (so they observe pre-step weights), issue `forward_backward` and
+       `optim_step` requests, and log metrics once the futures resolve.
+    3. Save checkpoints at the configured cadence so runs can resume or export weights,
+       then emit a final checkpoint when training completes.
+
+    Training and evaluation metrics share the same ``step`` index to keep dashboards easy
+    to read.
+    """
     resume_info = checkpoint_utils.get_last_checkpoint(config.log_path)
     if resume_info:
         start_epoch = resume_info["epoch"]
         start_batch = resume_info["batch"]
     else:
         start_epoch = 0
         start_batch = 0
+    # (start_epoch, start_batch) now represent the next batch to execute if resuming.
 
     ml_logger = ml_log.setup_logging(
         log_dir=config.log_path,
@@ -191,6 +217,23 @@ async def submit_batch(epoch_idx: int, batch_idx: int) -> SubmittedBatch:
         if data:
             logger.info(colorize_example(data[0], tokenizer))
 
+        # Trigger evaluations BEFORE submitting training operations so they snapshot pre-step weights
+        eval_metrics = None
+        if evaluators and config.eval_every > 0 and step % config.eval_every == 0:
+            with timed("evals", metrics):
+                eval_metrics = await run_evals(evaluators, training_client, step)
+
+        infrequent_eval_metrics = None
+        if (
+            infrequent_evaluators
+            and config.infrequent_eval_every > 0
+            and step % config.infrequent_eval_every == 0
+        ):
+            with timed("infrequent_evals", metrics):
+                infrequent_eval_metrics = await run_evals(
+                    infrequent_evaluators, training_client, step
+                )
+
         fwd_bwd_future = await training_client.forward_backward_async(data, loss_fn="cross_entropy")
         optim_step_future = await training_client.optim_step_async(adam_params)
 
@@ -203,6 +246,8 @@ async def submit_batch(epoch_idx: int, batch_idx: int) -> SubmittedBatch:
             epoch_idx=epoch_idx,
             batch_idx=batch_idx,
             batch_start_time=batch_start_time,
+            eval_metrics=eval_metrics,
+            infrequent_eval_metrics=infrequent_eval_metrics,
         )
 
     async def finish_batch(submitted: SubmittedBatch):
@@ -211,6 +256,8 @@ async def finish_batch(submitted: SubmittedBatch):
 
         if submitted.step % config.save_every == 0 and submitted.step > 0:
             with timed("save_checkpoint", metrics):
+                # Enqueue a checkpoint save after the forward/backward and optimizer
+                # requests for this step; the snapshot will reflect post-step weights.
                 await checkpoint_utils.save_checkpoint_async(
                     training_client=training_client,
                     name=f"{submitted.step:06d}",
@@ -237,22 +284,14 @@ async def finish_batch(submitted: SubmittedBatch):
         )
         metrics["time/total"] = time.time() - submitted.batch_start_time
 
-        if evaluators and config.eval_every > 0 and submitted.step % config.eval_every == 0:
-            with timed("evals", metrics):
-                eval_metrics = await run_evals(evaluators, training_client, submitted.step)
-            metrics.update(eval_metrics)
+        # Merge evaluation metrics gathered before the training step was submitted
+        if submitted.eval_metrics is not None:
+            metrics.update(submitted.eval_metrics)
 
-        if (
-            infrequent_evaluators
-            and config.infrequent_eval_every > 0
-            and submitted.step % config.infrequent_eval_every == 0
-        ):
-            with timed("infrequent_evals", metrics):
-                eval_metrics = await run_evals(
-                    infrequent_evaluators, training_client, submitted.step
-                )
-            metrics.update(eval_metrics)
+        if submitted.infrequent_eval_metrics is not None:
+            metrics.update(submitted.infrequent_eval_metrics)
 
+        # Emit all metrics for this step (train and eval) on the `submitted.step` row.
         ml_logger.log_metrics(metrics=metrics, step=submitted.step)
 
     pending_batch: SubmittedBatch | None = None
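Abstracted away from the Tinker-specific calls, the new control flow can be summarized with a small runnable sketch. Everything below is simplified stand-in code (the training client, run_evals, and the logger are replaced with stubs, and the class is renamed SubmittedBatchSketch); only the way eval metrics ride from submit_batch to finish_batch is meant to mirror the diff above.

import asyncio
from dataclasses import dataclass


@dataclass
class SubmittedBatchSketch:
    step: int
    eval_metrics: dict[str, float] | None = None


async def run_evals_stub(step: int) -> dict[str, float]:
    # Stand-in for run_evals(); pretends to evaluate the pre-step weights.
    return {"test/accuracy": 0.5 + 0.01 * step}


async def submit_batch_stub(step: int, eval_every: int) -> SubmittedBatchSketch:
    eval_metrics = None
    if eval_every > 0 and step % eval_every == 0:
        # Evaluate BEFORE the training ops for this step are submitted, so the
        # result describes the weights produced by steps 0..step-1.
        eval_metrics = await run_evals_stub(step)
    # ... forward_backward and optim_step would be submitted here ...
    return SubmittedBatchSketch(step=step, eval_metrics=eval_metrics)


async def finish_batch_stub(submitted: SubmittedBatchSketch) -> None:
    metrics: dict[str, float] = {"train/loss": 1.0 / (submitted.step + 1)}
    if submitted.eval_metrics is not None:
        # Merge the evals gathered before the step so train and eval metrics
        # land on the same step row.
        metrics.update(submitted.eval_metrics)
    print(f"step={submitted.step} metrics={metrics}")


async def demo() -> None:
    for step in range(6):
        submitted = await submit_batch_stub(step, eval_every=3)
        await finish_batch_stub(submitted)


asyncio.run(demo())

Because the evaluation is triggered before forward_backward and optim_step are submitted, the metrics logged at step s describe the weights produced by steps 0..s-1, which is what the commit title promises.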

tinker_cookbook/tests/test_utils.py

Lines changed: 1 addition & 0 deletions
@@ -42,5 +42,6 @@ def log_metrics(metrics: dict[str, Any], step: int):
 
     mock_logger.log_metrics = log_metrics
     mock_logger.close = MagicMock()
+    mock_logger.get_logger_url = MagicMock(return_value=None)
 
     return mock_logger
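The test change gives the mocked logger an explicit get_logger_url that returns None. Below is a minimal stand-alone illustration of what that line does, assuming only unittest.mock (the surrounding mock-logger helper in test_utils.py is not reproduced here).

# On a bare MagicMock, get_logger_url() would return another auto-created
# MagicMock; pinning the return value to None lets callers treat the mocked
# logger as one that has no web UI to link to.
from unittest.mock import MagicMock

mock_logger = MagicMock()
mock_logger.close = MagicMock()
mock_logger.get_logger_url = MagicMock(return_value=None)

assert mock_logger.get_logger_url() is None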
