@@ -7,12 +7,15 @@
 import logging
 import os
 import time
-from typing import Any, Callable, List, Literal, Sequence, Iterator
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Any, Callable, Iterator, List, Literal, Sequence
 
 import chz
 import numpy as np
 import tinker
 import torch
+
 from tinker_cookbook import checkpoint_utils
 from tinker_cookbook.completers import TinkerTokenCompleter
 from tinker_cookbook.display import colorize_example
@@ -39,9 +42,7 @@
 from tinker_cookbook.tokenizer_utils import Tokenizer
 from tinker_cookbook.utils import logtree, ml_log
 from tinker_cookbook.utils.misc_utils import safezip, split_list, timed
-from tinker_cookbook.utils.trace import scope, trace_init, get_scope_context
-from contextlib import contextmanager
-
+from tinker_cookbook.utils.trace import get_scope_context, scope, trace_init
 
 logger = logging.getLogger(__name__)
 
@@ -354,7 +355,7 @@ async def do_sync_training_with_stream_minibatch(
     ):
         # Samplers will produce trajectory groups asynchronously,
         # and the trainer will consume them as soon as they are ready
-        trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | None]()
+        trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None]()
         env_group_builders_P = dataset.get_batch(i_batch)
 
         @scope
@@ -393,17 +394,18 @@ async def trajectory_group_worker_task(
         )
 
         # Run multiple optimizer substeps per training iteration
-        (
-            sampling_client,
-            full_batch_metrics,
-        ) = await do_train_step_streaming_and_get_sampling_client(
+        streaming_result = await do_train_step_streaming_and_get_sampling_client(
             cfg,
             i_batch,
             trajectory_groups_queue,
             training_client,
             service_client,
             tokenizer,
         )
+        if streaming_result is None:
+            logger.info("[do_sync_training_with_stream_minibatch] Received shutdown signal")
+            return
+        sampling_client, full_batch_metrics = streaming_result
 
         # Log metrics
         metrics.update(full_batch_metrics)
@@ -428,6 +430,22 @@ class WrappedTrajectoryGroup:
     metrics: dict[str, Any] = chz.field(default_factory=dict)
 
 
+@dataclass
+class Shutdown:
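+    """Sentinel enqueued on the work queues to tell their consumers to shut down."""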
+    pass
+
+
+class AsyncCounter:
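+    """Counter protected by an asyncio.Lock; used to track how many workers are still alive."""
+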
+    def __init__(self, start: int = 0):
+        self.value = start
+        self.lock = asyncio.Lock()
+
+    async def decrement_and_get(self) -> int:
+        async with self.lock:
+            self.value -= 1
+            return self.value
+
+
 @scope
 async def do_async_training(
     start_batch: int,
@@ -444,13 +462,12 @@ async def do_async_training(
444462 """Implements async off-policy training, capped at K steps off policy."""
445463 assert cfg .async_config is not None
446464
447- shutdown_event = asyncio .Event ()
448465 # We will have groups_per_batch worker generating rollouts, so cap the
449466 # queue size to be groups_per_batch.
450- env_group_builders_queue = asyncio .Queue [EnvGroupBuilder | None ](
467+ env_group_builders_queue = asyncio .Queue [EnvGroupBuilder | Shutdown ](
451468 maxsize = cfg .async_config .groups_per_batch
452469 )
453- trajectory_groups_queue = asyncio .Queue [WrappedTrajectoryGroup | None ]()
470+ trajectory_groups_queue = asyncio .Queue [WrappedTrajectoryGroup | Shutdown | None ]()
454471
455472 # Initial sampling client to use
456473 path_dict = await checkpoint_utils .save_checkpoint_async (
@@ -461,38 +478,46 @@ async def do_async_training(
         kind="both",
     )
 
+    # When the dataloader is out of data, we want to make sure all remaining samples
+    # are processed before terminating.
+    evaluation_loop_should_shutdown_event = asyncio.Event()
+    trajectory_group_worker_alive_counter = AsyncCounter(cfg.async_config.groups_per_batch)
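+    # Shutdown protocol: once the dataset is exhausted, the dataloader enqueues one
+    # Shutdown() per worker; each worker decrements the alive counter on exit, and the
+    # last worker forwards a single Shutdown() to the training loop.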
+
     # This will be updated by the training loop
     sampling_client = training_client.create_sampling_client(path_dict["sampler_path"])
     sampling_client_step = start_batch
     sampling_client_updated_event = asyncio.Event()
     sampling_client_updated_event.set()
 
-    @scope
-    def shutdown_loops():
-        """Trigger all loops to shutdown"""
-        shutdown_event.set()
-        assert cfg.async_config is not None
-        for _ in range(cfg.async_config.groups_per_batch):
-            env_group_builders_queue.put_nowait(None)
-        sampling_client_updated_event.set()
-
     @scope
     async def dataloader_loop():
         """Gets the next set of env builders to run"""
         i_batch = start_batch
-        while not shutdown_event.is_set() and i_batch < end_batch:
+        while i_batch < end_batch:
             env_group_builders_P = dataset.get_batch(i_batch)
             for env_group_builder in env_group_builders_P:
                 await env_group_builders_queue.put(env_group_builder)
             i_batch += 1
 
+        # We are done with the dataloader loop; enqueue sentinel values
+        # to allow the trajectory group worker loops to terminate.
+        logger.info("[dataloader_loop] No more data, shutting down trajectory group worker loops")
+        assert cfg.async_config is not None
+        for _ in range(cfg.async_config.groups_per_batch):
+            await env_group_builders_queue.put(Shutdown())
+        logger.info("[dataloader_loop] Data loader loop terminated")
+
     @scope
     async def trajectory_group_worker_loop():
         """Generates trajectories for a single env builder"""
-        while not shutdown_event.is_set():
+        while True:
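+            # Block until the dataloader provides an env builder or signals shutdown.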
             env_group_builder = await env_group_builders_queue.get()
-            if env_group_builder is None:
-                break
+            match env_group_builder:
+                case EnvGroupBuilder():
+                    pass
+                case Shutdown():
+                    logger.info("[trajectory_group_worker_loop] Received shutdown signal")
+                    break
 
             metrics = {}
             t_start = time.time()
@@ -518,6 +543,14 @@ async def trajectory_group_worker_loop():
                     metrics=metrics,
                 )
             )
+        num_alive_workers = await trajectory_group_worker_alive_counter.decrement_and_get()
+        if num_alive_workers == 0:
+            # All workers are done; enqueue a sentinel to terminate the training loop
+            logger.info(
+                "[trajectory_group_worker_loop] Last worker terminated, shutting down training loop"
+            )
+            trajectory_groups_queue.put_nowait(Shutdown())
+        logger.info("[trajectory_group_worker_loop] Trajectory group worker loop terminated")
 
     @scope
     async def training_loop():
@@ -530,9 +563,6 @@ async def training_loop():
         i_batch = start_batch
         wrapped_trajectory_groups = []
         while i_batch < end_batch:
-            wrapped_trajectory_group = await trajectory_groups_queue.get()
-            if wrapped_trajectory_group is None:
-                continue
 
             @scope
             def filter_stale_trajectory_group(
@@ -567,10 +597,7 @@ def filter_stale_trajectory_group(
             nonlocal sampling_client
             nonlocal sampling_client_step
             if cfg.stream_minibatch_config is not None:
-                (
-                    sampling_client,
-                    train_step_metrics,
-                ) = await do_train_step_streaming_and_get_sampling_client(
+                streaming_result = await do_train_step_streaming_and_get_sampling_client(
                     cfg,
                     i_batch,
                     trajectory_groups_queue,
@@ -579,7 +606,21 @@ def filter_stale_trajectory_group(
                     tokenizer,
                     filter_stale_trajectory_group,
                 )
+                if streaming_result is None:
+                    logger.info("[training_loop] Received shutdown signal")
+                    break
+                sampling_client, train_step_metrics = streaming_result
             else:
+                wrapped_trajectory_group = await trajectory_groups_queue.get()
+                match wrapped_trajectory_group:
+                    case WrappedTrajectoryGroup():
+                        pass
+                    case Shutdown():
+                        logger.info("[training_loop] Received shutdown signal")
+                        break
+                    case None:
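+                        # No trajectory group to train on; skip this entry.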
+                        continue
+
                 if not filter_stale_trajectory_group(wrapped_trajectory_group):
                     continue
 
@@ -618,15 +659,17 @@ def filter_stale_trajectory_group(
             i_batch += 1
             wrapped_trajectory_groups = []
 
-        shutdown_loops()
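+        # Wake the evaluation loop (it may be blocked waiting for a new sampling
+        # client) so it can observe the shutdown event and exit.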
+        evaluation_loop_should_shutdown_event.set()
+        sampling_client_updated_event.set()
+        logger.info("[training_loop] Training loop terminated")
 
     @scope
     async def evaluation_loop():
         """Runs evals periodically"""
         if len(evaluators) == 0 or cfg.eval_every == 0:
             return
 
-        while not shutdown_event.is_set():
+        while not evaluation_loop_should_shutdown_event.is_set():
             await sampling_client_updated_event.wait()
             sampling_client_updated_event.clear()
 
@@ -643,6 +686,7 @@ async def evaluation_loop():
             metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
             metrics["time/evaluation_loop/total"] = time.time() - t_start
             ml_logger.log_metrics(metrics, step=sampling_client_eval_step)
+        logger.info("[evaluation_loop] Evaluation loop terminated")
 
     await asyncio.gather(
         asyncio.create_task(dataloader_loop(), name="dataloader_loop"),
@@ -787,12 +831,12 @@ async def compute_full_batch_metrics_and_get_sampling_client(
 async def do_train_step_streaming_and_get_sampling_client(
     cfg: Config,
     i_batch: int,
-    trajectory_groups_queue: asyncio.Queue[WrappedTrajectoryGroup | None],
+    trajectory_groups_queue: asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None],
     training_client: tinker.TrainingClient,
     service_client: tinker.ServiceClient,
     tokenizer: Tokenizer,
     trajectory_group_filter: Callable[[WrappedTrajectoryGroup | None], bool] = lambda _: True,
-) -> tuple[tinker.SamplingClient, dict[str, Any]]:
+) -> tuple[tinker.SamplingClient, dict[str, Any]] | None:
     """
     As soon as we have enough trajectories for a minibatch, we will train on them.
     This allows us to overlap sampling and training.
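+    Returns None if a Shutdown sentinel is received before the step completes.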
@@ -825,8 +869,16 @@ async def do_train_step_streaming_and_get_sampling_client(
     i_minibatch = 0
     while i_minibatch < cfg.stream_minibatch_config.num_minibatches:
         wrapped_trajectory_group = await trajectory_groups_queue.get()
-        if not trajectory_group_filter(wrapped_trajectory_group):
-            continue
+        match wrapped_trajectory_group:
+            case WrappedTrajectoryGroup():
+                pass
+            case Shutdown():
+                logger.info(
+                    "[do_train_step_streaming_and_get_sampling_client] Received shutdown signal"
+                )
+                return None
+            case None:
+                continue
         wrapped_trajectory_groups.append(wrapped_trajectory_group)
 
         if len(wrapped_trajectory_groups) < groups_per_minibatch: