
Commit 6172d61

[tinker-cookbook] rl: avoid hanging in async runs when we run out of data
Previously, async RL runs could hang at shutdown if we ran out of data. This fixes it by ensuring a proper shutdown order: the dataloader loop terminates first, and all data remaining in the queues is drained before the other loops exit.
1 parent 5f5ce26 commit 6172d61
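
The hang described above is the classic failure mode of a consumer awaiting a queue that no producer will ever fill again: once the dataset is exhausted and the rollout workers stop, a loop still blocked on queue.get() never wakes up, so the run never exits. A minimal, hypothetical repro of that failure mode (standalone asyncio with a timeout so it terminates; none of these names come from train.py):

import asyncio


async def repro_hang() -> None:
    queue: asyncio.Queue[int] = asyncio.Queue()

    async def producer() -> None:
        for i in range(2):
            await queue.put(i)
        # Producer exits without enqueuing any sentinel for the consumer.

    async def consumer() -> None:
        while True:
            item = await queue.get()  # blocks forever once the producer is done
            print("got", item)

    try:
        await asyncio.wait_for(asyncio.gather(producer(), consumer()), timeout=1.0)
    except asyncio.TimeoutError:
        print("hung: consumer still waiting on an empty queue")


asyncio.run(repro_hang())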


tinker_cookbook/rl/train.py

Lines changed: 90 additions & 38 deletions
@@ -7,12 +7,15 @@
 import logging
 import os
 import time
-from typing import Any, Callable, List, Literal, Sequence, Iterator
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Any, Callable, Iterator, List, Literal, Sequence

 import chz
 import numpy as np
 import tinker
 import torch
+
 from tinker_cookbook import checkpoint_utils
 from tinker_cookbook.completers import TinkerTokenCompleter
 from tinker_cookbook.display import colorize_example
@@ -39,9 +42,7 @@
 from tinker_cookbook.tokenizer_utils import Tokenizer
 from tinker_cookbook.utils import logtree, ml_log
 from tinker_cookbook.utils.misc_utils import safezip, split_list, timed
-from tinker_cookbook.utils.trace import scope, trace_init, get_scope_context
-from contextlib import contextmanager
-
+from tinker_cookbook.utils.trace import get_scope_context, scope, trace_init

 logger = logging.getLogger(__name__)

@@ -354,7 +355,7 @@ async def do_sync_training_with_stream_minibatch(
 ):
     # Samplers will produce trajectory groups asynchronously,
     # and the trainer will consume them as soon as they are ready
-    trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | None]()
+    trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None]()
     env_group_builders_P = dataset.get_batch(i_batch)

     @scope
@@ -393,17 +394,18 @@ async def trajectory_group_worker_task(
    )

    # Run multiple optimizer substeps per training iteration
-    (
-        sampling_client,
-        full_batch_metrics,
-    ) = await do_train_step_streaming_and_get_sampling_client(
+    streaming_result = await do_train_step_streaming_and_get_sampling_client(
        cfg,
        i_batch,
        trajectory_groups_queue,
        training_client,
        service_client,
        tokenizer,
    )
+    if streaming_result is None:
+        logger.info("[do_sync_training_with_stream_minibatch] Received shutdown signal")
+        return
+    sampling_client, full_batch_metrics = streaming_result

    # Log metrics
    metrics.update(full_batch_metrics)
@@ -428,6 +430,22 @@ class WrappedTrajectoryGroup:
     metrics: dict[str, Any] = chz.field(default_factory=dict)


+@dataclass
+class Shutdown:
+    pass
+
+
+class AsyncCounter:
+    def __init__(self, start: int = 0):
+        self.value = start
+        self.lock = asyncio.Lock()
+
+    async def decrement_and_get(self) -> int:
+        async with self.lock:
+            self.value -= 1
+            return self.value
+
+
 @scope
 async def do_async_training(
     start_batch: int,
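
The asyncio.Lock in AsyncCounter makes decrement-and-read atomic, so when several workers finish at roughly the same time exactly one of them observes the count reaching zero and performs the "last worker out" action. A small illustrative driver (hypothetical usage of the AsyncCounter above; not part of the diff):

import asyncio


async def demo_last_worker_out() -> None:
    counter = AsyncCounter(start=3)  # one slot per worker, mirroring groups_per_batch

    async def finish_worker(name: str) -> None:
        remaining = await counter.decrement_and_get()
        if remaining == 0:
            # In the patch, this is where the trainer's Shutdown sentinel gets enqueued.
            print(f"{name} is the last worker out")

    await asyncio.gather(*(finish_worker(f"worker-{i}") for i in range(3)))


asyncio.run(demo_last_worker_out())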
@@ -444,13 +462,12 @@ async def do_async_training(
     """Implements async off-policy training, capped at K steps off policy."""
     assert cfg.async_config is not None

-    shutdown_event = asyncio.Event()
     # We will have groups_per_batch worker generating rollouts, so cap the
     # queue size to be groups_per_batch.
-    env_group_builders_queue = asyncio.Queue[EnvGroupBuilder | None](
+    env_group_builders_queue = asyncio.Queue[EnvGroupBuilder | Shutdown](
         maxsize=cfg.async_config.groups_per_batch
     )
-    trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | None]()
+    trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None]()

     # Initial sampling client to use
     path_dict = await checkpoint_utils.save_checkpoint_async(
@@ -461,38 +478,46 @@
         kind="both",
     )

+    # When the dataloader is out of data, we want to make sure all remaining samples
+    # are processed before terminating.
+    evaluation_loop_should_shutdown_event = asyncio.Event()
+    trajectory_group_worker_alive_counter = AsyncCounter(cfg.async_config.groups_per_batch)
+
     # This will be updated by the training loop
     sampling_client = training_client.create_sampling_client(path_dict["sampler_path"])
     sampling_client_step = start_batch
     sampling_client_updated_event = asyncio.Event()
     sampling_client_updated_event.set()

-    @scope
-    def shutdown_loops():
-        """Trigger all loops to shutdown"""
-        shutdown_event.set()
-        assert cfg.async_config is not None
-        for _ in range(cfg.async_config.groups_per_batch):
-            env_group_builders_queue.put_nowait(None)
-        sampling_client_updated_event.set()
-
     @scope
     async def dataloader_loop():
         """Gets the next set of env builders to run"""
         i_batch = start_batch
-        while not shutdown_event.is_set() and i_batch < end_batch:
+        while i_batch < end_batch:
             env_group_builders_P = dataset.get_batch(i_batch)
             for env_group_builder in env_group_builders_P:
                 await env_group_builders_queue.put(env_group_builder)
             i_batch += 1

+        # We are done with the data loader loop, enqueue sentinel values
+        # to allow the trajectory group worker loops to terminate.
+        logger.info("[dataloader_loop] No more data, shutting down trajectory group worker loops")
+        assert cfg.async_config is not None
+        for _ in range(cfg.async_config.groups_per_batch):
+            await env_group_builders_queue.put(Shutdown())
+        logger.info("[dataloader_loop] Data loader loop terminated")
+
     @scope
     async def trajectory_group_worker_loop():
         """Generates trajectories for a single env builder"""
-        while not shutdown_event.is_set():
+        while True:
             env_group_builder = await env_group_builders_queue.get()
-            if env_group_builder is None:
-                break
+            match env_group_builder:
+                case EnvGroupBuilder():
+                    pass
+                case Shutdown():
+                    logger.info("[trajectory_group_worker_loop] Received shutdown signal")
+                    break

             metrics = {}
             t_start = time.time()
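
The match statements introduced here use class patterns: case EnvGroupBuilder(): matches any instance of EnvGroupBuilder (the empty parentheses constrain no attributes), case Shutdown(): matches the sentinel, and case None: is a literal pattern. A self-contained illustration of that dispatch with made-up types (not the cookbook's classes):

from dataclasses import dataclass


@dataclass
class Shutdown:
    pass


@dataclass
class Work:
    payload: str


def dispatch(item: Work | Shutdown | None) -> str:
    match item:
        case Work():      # class pattern: any Work instance, no attribute constraints
            return f"process {item.payload}"
        case Shutdown():  # class pattern: the shutdown sentinel
            return "stop"
        case None:        # literal pattern: the legacy "ignore this item" marker
            return "skip"


print(dispatch(Work("rollout")))  # process rollout
print(dispatch(Shutdown()))       # stop
print(dispatch(None))             # skip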
@@ -518,6 +543,14 @@ async def trajectory_group_worker_loop():
                     metrics=metrics,
                 )
             )
+        num_alive_workers = await trajectory_group_worker_alive_counter.decrement_and_get()
+        if num_alive_workers == 0:
+            # All workers are done, enqueue a sentinel to terminate the training loop
+            logger.info(
+                "[trajectory_group_worker_loop] Last worker terminated, shutting down training loop"
+            )
+            trajectory_groups_queue.put_nowait(Shutdown())
+        logger.info("[trajectory_group_worker_loop] Trajectory group worker loop terminated")

     @scope
     async def training_loop():
@@ -530,9 +563,6 @@
         i_batch = start_batch
         wrapped_trajectory_groups = []
         while i_batch < end_batch:
-            wrapped_trajectory_group = await trajectory_groups_queue.get()
-            if wrapped_trajectory_group is None:
-                continue

             @scope
             def filter_stale_trajectory_group(
@@ -567,10 +597,7 @@ def filter_stale_trajectory_group(
            nonlocal sampling_client
            nonlocal sampling_client_step
            if cfg.stream_minibatch_config is not None:
-                (
-                    sampling_client,
-                    train_step_metrics,
-                ) = await do_train_step_streaming_and_get_sampling_client(
+                streaming_result = await do_train_step_streaming_and_get_sampling_client(
                    cfg,
                    i_batch,
                    trajectory_groups_queue,
@@ -579,7 +606,21 @@
                    tokenizer,
                    filter_stale_trajectory_group,
                )
+                if streaming_result is None:
+                    logger.info("[training_loop] Received shutdown signal")
+                    break
+                sampling_client, train_step_metrics = streaming_result
            else:
+                wrapped_trajectory_group = await trajectory_groups_queue.get()
+                match wrapped_trajectory_group:
+                    case WrappedTrajectoryGroup():
+                        pass
+                    case Shutdown():
+                        logger.info("[training_loop] Received shutdown signal")
+                        break
+                    case None:
+                        continue
+
                if not filter_stale_trajectory_group(wrapped_trajectory_group):
                    continue

@@ -618,15 +659,17 @@ def filter_stale_trajectory_group(
            i_batch += 1
            wrapped_trajectory_groups = []

-        shutdown_loops()
+        evaluation_loop_should_shutdown_event.set()
+        sampling_client_updated_event.set()
+        logger.info("[training_loop] Training loop terminated")

     @scope
     async def evaluation_loop():
         """Runs evals periodically"""
         if len(evaluators) == 0 or cfg.eval_every == 0:
             return

-        while not shutdown_event.is_set():
+        while not evaluation_loop_should_shutdown_event.is_set():
             await sampling_client_updated_event.wait()
             sampling_client_updated_event.clear()

@@ -643,6 +686,7 @@ async def evaluation_loop():
            metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
            metrics["time/evaluation_loop/total"] = time.time() - t_start
            ml_logger.log_metrics(metrics, step=sampling_client_eval_step)
+        logger.info("[evaluation_loop] Evaluation loop terminated")

     await asyncio.gather(
         asyncio.create_task(dataloader_loop(), name="dataloader_loop"),
@@ -787,12 +831,12 @@ async def compute_full_batch_metrics_and_get_sampling_client(
 async def do_train_step_streaming_and_get_sampling_client(
     cfg: Config,
     i_batch: int,
-    trajectory_groups_queue: asyncio.Queue[WrappedTrajectoryGroup | None],
+    trajectory_groups_queue: asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None],
     training_client: tinker.TrainingClient,
     service_client: tinker.ServiceClient,
     tokenizer: Tokenizer,
     trajectory_group_filter: Callable[[WrappedTrajectoryGroup | None], bool] = lambda _: True,
-) -> tuple[tinker.SamplingClient, dict[str, Any]]:
+) -> tuple[tinker.SamplingClient, dict[str, Any]] | None:
     """
     As soon as we have enough trajectories for a minibatch, we will train on them.
     This allows us to overlap sampling and training.
@@ -825,8 +869,16 @@ async def do_train_step_streaming_and_get_sampling_client(
     i_minibatch = 0
     while i_minibatch < cfg.stream_minibatch_config.num_minibatches:
         wrapped_trajectory_group = await trajectory_groups_queue.get()
-        if not trajectory_group_filter(wrapped_trajectory_group):
-            continue
+        match wrapped_trajectory_group:
+            case WrappedTrajectoryGroup():
+                pass
+            case Shutdown():
+                logger.info(
+                    "[do_train_step_streaming_and_get_sampling_client] Received shutdown signal"
+                )
+                return None
+            case None:
+                continue
         wrapped_trajectory_groups.append(wrapped_trajectory_group)

         if len(wrapped_trajectory_groups) < groups_per_minibatch:
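
Taken together, shutdown now cascades in a fixed order: the dataloader runs out of data and enqueues one Shutdown per rollout worker, each worker drains its queue and exits, the last worker out enqueues a single Shutdown for the training loop, and the training loop in turn releases the evaluation loop. A condensed, standalone sketch of that drain ordering with plain asyncio (illustrative names only, not the actual train.py API):

import asyncio
from dataclasses import dataclass


@dataclass
class Shutdown:
    pass


async def main() -> None:
    n_workers = 3
    builders_q: asyncio.Queue = asyncio.Queue(maxsize=n_workers)
    trajectories_q: asyncio.Queue = asyncio.Queue()
    alive = n_workers

    async def dataloader() -> None:
        for i in range(8):                       # stand-in for dataset.get_batch(...)
            await builders_q.put(i)
        for _ in range(n_workers):               # one sentinel per worker
            await builders_q.put(Shutdown())

    async def worker() -> None:
        nonlocal alive
        while True:
            item = await builders_q.get()
            if isinstance(item, Shutdown):
                break
            await trajectories_q.put(item * 10)  # stand-in for a rollout
        alive -= 1
        if alive == 0:                           # last worker out wakes the trainer
            trajectories_q.put_nowait(Shutdown())

    async def trainer() -> list[int]:
        seen = []
        while True:
            item = await trajectories_q.get()
            if isinstance(item, Shutdown):
                return seen                      # every queued trajectory was consumed
            seen.append(item)

    results = await asyncio.gather(trainer(), dataloader(), *[worker() for _ in range(n_workers)])
    print(sorted(results[0]))


asyncio.run(main())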
