@@ -83,14 +83,25 @@ class SubmittedBatch:
8383 epoch_idx : int
8484 batch_idx : int
8585 batch_start_time : float
86+ eval_metrics : dict [str , float ] | None = None
87+ infrequent_eval_metrics : dict [str , float ] | None = None
8688
8789
8890async def run_evals (
8991 evaluators : list [Evaluator ],
9092 training_client : tinker .TrainingClient ,
9193 step : int ,
9294) -> dict [str , float ]:
93- """Run all evaluators and return metrics with test/ prefix."""
95+ """Evaluate the current model weights and prefix results with ``test/``.
96+
97+ The helper is called immediately before optimizer step `step` is submitted, so it
98+ measures the weights produced after step `step-1` (or the initial weights for step 0).
99+ Training-client evaluators run against the mutable training client, while sampling
100+ evaluators request a fresh `SamplingClient` snapshot via
101+ `save_weights_and_get_sampling_client_async` to ensure their work uses a fixed
102+ checkpoint. Returned metrics are prefixed with ``test/`` so they can be logged next
103+ to the same-step training metrics.
104+ """
94105 metrics = {}
95106 sampling_client = None
96107
@@ -100,6 +111,7 @@ async def run_evals(
100111 elif isinstance (evaluator , SamplingClientEvaluator ):
101112 # Create sampling client lazily, only when needed
102113 if sampling_client is None :
114+ # Snapshot the current pre-step weights and create a new sampling client.
103115 sampling_client = await training_client .save_weights_and_get_sampling_client_async (
104116 f"evals_step_{ step } "
105117 )
@@ -114,14 +126,28 @@ async def run_evals(
114126
115127
116128async def main (config : Config ):
117- """Main training function that runs the complete training process."""
129+ """Run the standard supervised learning loop used by the supervised recipes.
130+
131+ Responsibilities:
132+ 1. Initialize logging, build the dataset/evaluator objects, construct (or resume) the
133+ training client, and determine the ``epoch``/``batch`` indices to start from.
134+ 2. Iterate over batches: fetch data, optionally run evaluations before submitting the
135+ optimizer step (so they observe pre-step weights), issue `forward_backward` and
136+ `optim_step` requests, and log metrics once the futures resolve.
137+ 3. Save checkpoints at the configured cadence so runs can resume or export weights,
138+ then emit a final checkpoint when training completes.
139+
140+ Training and evaluation metrics share the same ``step`` index to keep dashboards easy
141+ to read.
142+ """
118143 resume_info = checkpoint_utils .get_last_checkpoint (config .log_path )
119144 if resume_info :
120145 start_epoch = resume_info ["epoch" ]
121146 start_batch = resume_info ["batch" ]
122147 else :
123148 start_epoch = 0
124149 start_batch = 0
150+ # (start_epoch, start_batch) now represent the next batch to execute if resuming.
125151
126152 ml_logger = ml_log .setup_logging (
127153 log_dir = config .log_path ,
@@ -191,6 +217,23 @@ async def submit_batch(epoch_idx: int, batch_idx: int) -> SubmittedBatch:
191217 if data :
192218 logger .info (colorize_example (data [0 ], tokenizer ))
193219
220+ # Trigger evaluations BEFORE submitting training operations so they snapshot pre-step weights
221+ eval_metrics = None
222+ if evaluators and config .eval_every > 0 and step % config .eval_every == 0 :
223+ with timed ("evals" , metrics ):
224+ eval_metrics = await run_evals (evaluators , training_client , step )
225+
226+ infrequent_eval_metrics = None
227+ if (
228+ infrequent_evaluators
229+ and config .infrequent_eval_every > 0
230+ and step % config .infrequent_eval_every == 0
231+ ):
232+ with timed ("infrequent_evals" , metrics ):
233+ infrequent_eval_metrics = await run_evals (
234+ infrequent_evaluators , training_client , step
235+ )
236+
194237 fwd_bwd_future = await training_client .forward_backward_async (data , loss_fn = "cross_entropy" )
195238 optim_step_future = await training_client .optim_step_async (adam_params )
196239
@@ -203,6 +246,8 @@ async def submit_batch(epoch_idx: int, batch_idx: int) -> SubmittedBatch:
203246 epoch_idx = epoch_idx ,
204247 batch_idx = batch_idx ,
205248 batch_start_time = batch_start_time ,
249+ eval_metrics = eval_metrics ,
250+ infrequent_eval_metrics = infrequent_eval_metrics ,
206251 )
207252
208253 async def finish_batch (submitted : SubmittedBatch ):
@@ -211,6 +256,8 @@ async def finish_batch(submitted: SubmittedBatch):
211256
212257 if submitted .step % config .save_every == 0 and submitted .step > 0 :
213258 with timed ("save_checkpoint" , metrics ):
259+ # Enqueue a checkpoint save after the forward/backward and optimizer
260+ # requests for this step; the snapshot will reflect post-step weights.
214261 await checkpoint_utils .save_checkpoint_async (
215262 training_client = training_client ,
216263 name = f"{ submitted .step :06d} " ,
@@ -237,22 +284,14 @@ async def finish_batch(submitted: SubmittedBatch):
237284 )
238285 metrics ["time/total" ] = time .time () - submitted .batch_start_time
239286
240- if evaluators and config .eval_every > 0 and submitted .step % config .eval_every == 0 :
241- with timed ("evals" , metrics ):
242- eval_metrics = await run_evals (evaluators , training_client , submitted .step )
243- metrics .update (eval_metrics )
287+ # Merge evaluation metrics gathered before the training step was submitted
288+ if submitted .eval_metrics is not None :
289+ metrics .update (submitted .eval_metrics )
244290
245- if (
246- infrequent_evaluators
247- and config .infrequent_eval_every > 0
248- and submitted .step % config .infrequent_eval_every == 0
249- ):
250- with timed ("infrequent_evals" , metrics ):
251- eval_metrics = await run_evals (
252- infrequent_evaluators , training_client , submitted .step
253- )
254- metrics .update (eval_metrics )
291+ if submitted .infrequent_eval_metrics is not None :
292+ metrics .update (submitted .infrequent_eval_metrics )
255293
294+ # Emit all metrics for this step (train and eval) on the `submitted.step` row.
256295 ml_logger .log_metrics (metrics = metrics , step = submitted .step )
257296
258297 pending_batch : SubmittedBatch | None = None