128 changes: 90 additions & 38 deletions tinker_cookbook/rl/train.py
@@ -7,12 +7,15 @@
import logging
import os
import time
from typing import Any, Callable, List, Literal, Sequence, Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Iterator, List, Literal, Sequence

import chz
import numpy as np
import tinker
import torch

from tinker_cookbook import checkpoint_utils
from tinker_cookbook.completers import TinkerTokenCompleter
from tinker_cookbook.display import colorize_example
@@ -39,9 +42,7 @@
from tinker_cookbook.tokenizer_utils import Tokenizer
from tinker_cookbook.utils import logtree, ml_log
from tinker_cookbook.utils.misc_utils import safezip, split_list, timed
from tinker_cookbook.utils.trace import scope, trace_init, get_scope_context
from contextlib import contextmanager

from tinker_cookbook.utils.trace import get_scope_context, scope, trace_init

logger = logging.getLogger(__name__)

@@ -354,7 +355,7 @@ async def do_sync_training_with_stream_minibatch(
):
# Samplers will produce trajectory groups asynchronously,
# and the trainer will consume them as soon as they are ready
trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | None]()
trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None]()
env_group_builders_P = dataset.get_batch(i_batch)

@scope
Expand Down Expand Up @@ -393,17 +394,18 @@ async def trajectory_group_worker_task(
)

# Run multiple optimizer substeps per training iteration
(
sampling_client,
full_batch_metrics,
) = await do_train_step_streaming_and_get_sampling_client(
streaming_result = await do_train_step_streaming_and_get_sampling_client(
cfg,
i_batch,
trajectory_groups_queue,
training_client,
service_client,
tokenizer,
)
if streaming_result is None:
logger.info("[do_sync_training_with_stream_minibatch] Received shutdown signal")
return
sampling_client, full_batch_metrics = streaming_result

# Log metrics
metrics.update(full_batch_metrics)
@@ -428,6 +430,22 @@ class WrappedTrajectoryGroup:
metrics: dict[str, Any] = chz.field(default_factory=dict)


@dataclass
class Shutdown:
pass
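# Shutdown is a dedicated sentinel type rather than a reuse of None: the
# queues below can also carry None, which consumers simply skip
# ("case None: continue"), so matching on Shutdown() cleanly separates
# "nothing to process" from "stop working".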


class AsyncCounter:
def __init__(self, start: int = 0):
self.value = start
self.lock = asyncio.Lock()

async def decrement_and_get(self) -> int:
async with self.lock:
self.value -= 1
return self.value
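# A minimal usage sketch (illustrative names, not part of this diff),
# assuming each of N workers calls decrement_and_get() exactly once when it
# drains its queue; only the worker that reaches zero enqueues the final
# sentinel, so downstream consumers see it exactly once:
#
#     counter = AsyncCounter(start=num_workers)
#     ...
#     if await counter.decrement_and_get() == 0:
#         output_queue.put_nowait(Shutdown())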


@scope
async def do_async_training(
start_batch: int,
@@ -444,13 +462,12 @@ async def do_async_training(
"""Implements async off-policy training, capped at K steps off policy."""
assert cfg.async_config is not None

shutdown_event = asyncio.Event()
# We will have groups_per_batch workers generating rollouts, so cap the
# queue size at groups_per_batch.
env_group_builders_queue = asyncio.Queue[EnvGroupBuilder | None](
env_group_builders_queue = asyncio.Queue[EnvGroupBuilder | Shutdown](
maxsize=cfg.async_config.groups_per_batch
)
trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | None]()
trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None]()

# Initial sampling client to use
path_dict = await checkpoint_utils.save_checkpoint_async(
@@ -461,38 +478,46 @@
kind="both",
)

# When the dataloader is out of data, we want to make sure all remaining samples
# are processed before terminating.
evaluation_loop_should_shutdown_event = asyncio.Event()
trajectory_group_worker_alive_counter = AsyncCounter(cfg.async_config.groups_per_batch)

# This will be updated by the training loop
sampling_client = training_client.create_sampling_client(path_dict["sampler_path"])
sampling_client_step = start_batch
sampling_client_updated_event = asyncio.Event()
sampling_client_updated_event.set()
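# Note: setting the event before any training step lets evaluation_loop run
# once against the initial checkpoint's sampling client.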

@scope
def shutdown_loops():
"""Trigger all loops to shutdown"""
shutdown_event.set()
assert cfg.async_config is not None
for _ in range(cfg.async_config.groups_per_batch):
env_group_builders_queue.put_nowait(None)
sampling_client_updated_event.set()

@scope
async def dataloader_loop():
"""Gets the next set of env builders to run"""
i_batch = start_batch
while not shutdown_event.is_set() and i_batch < end_batch:
while i_batch < end_batch:
env_group_builders_P = dataset.get_batch(i_batch)
for env_group_builder in env_group_builders_P:
await env_group_builders_queue.put(env_group_builder)
i_batch += 1

# We are done with the data loader loop; enqueue sentinel values
# so the trajectory group worker loops can terminate.
logger.info("[dataloader_loop] No more data, shutting down trajectory group worker loops")
assert cfg.async_config is not None
for _ in range(cfg.async_config.groups_per_batch):
await env_group_builders_queue.put(Shutdown())
logger.info("[dataloader_loop] Data loader loop terminated")

@scope
async def trajectory_group_worker_loop():
"""Generates trajectories for a single env builder"""
while not shutdown_event.is_set():
while True:
env_group_builder = await env_group_builders_queue.get()
if env_group_builder is None:
break
match env_group_builder:
case EnvGroupBuilder():
pass
case Shutdown():
logger.info("[trajectory_group_worker_loop] Received shutdown signal")
break

metrics = {}
t_start = time.time()
@@ -518,6 +543,14 @@ async def trajectory_group_worker_loop():
metrics=metrics,
)
)
num_alive_workers = await trajectory_group_worker_alive_counter.decrement_and_get()
if num_alive_workers == 0:
# All workers are done; enqueue a sentinel to terminate the training loop
logger.info(
"[trajectory_group_worker_loop] Last worker terminated, shutting down training loop"
)
trajectory_groups_queue.put_nowait(Shutdown())
logger.info("[trajectory_group_worker_loop] Trajectory group worker loop terminated")

@scope
async def training_loop():
@@ -530,9 +563,6 @@ async def training_loop():
i_batch = start_batch
wrapped_trajectory_groups = []
while i_batch < end_batch:
wrapped_trajectory_group = await trajectory_groups_queue.get()
if wrapped_trajectory_group is None:
continue

@scope
def filter_stale_trajectory_group(
@@ -567,10 +597,7 @@ def filter_stale_trajectory_group(
nonlocal sampling_client
nonlocal sampling_client_step
if cfg.stream_minibatch_config is not None:
(
sampling_client,
train_step_metrics,
) = await do_train_step_streaming_and_get_sampling_client(
streaming_result = await do_train_step_streaming_and_get_sampling_client(
cfg,
i_batch,
trajectory_groups_queue,
@@ -579,7 +606,21 @@
tokenizer,
filter_stale_trajectory_group,
)
if streaming_result is None:
logger.info("[training_loop] Received shutdown signal")
break
sampling_client, train_step_metrics = streaming_result
else:
wrapped_trajectory_group = await trajectory_groups_queue.get()
match wrapped_trajectory_group:
case WrappedTrajectoryGroup():
pass
case Shutdown():
logger.info("[training_loop] Received shutdown signal")
break
case None:
continue

if not filter_stale_trajectory_group(wrapped_trajectory_group):
continue

@@ -618,15 +659,17 @@ def filter_stale_trajectory_group(
i_batch += 1
wrapped_trajectory_groups = []

shutdown_loops()
evaluation_loop_should_shutdown_event.set()
sampling_client_updated_event.set()
logger.info("[training_loop] Training loop terminated")

@scope
async def evaluation_loop():
"""Runs evals periodically"""
if len(evaluators) == 0 or cfg.eval_every == 0:
return

while not shutdown_event.is_set():
while not evaluation_loop_should_shutdown_event.is_set():
await sampling_client_updated_event.wait()
sampling_client_updated_event.clear()

@@ -643,6 +686,7 @@ async def evaluation_loop():
metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
metrics["time/evaluation_loop/total"] = time.time() - t_start
ml_logger.log_metrics(metrics, step=sampling_client_eval_step)
logger.info("[evaluation_loop] Evaluation loop terminated")

await asyncio.gather(
asyncio.create_task(dataloader_loop(), name="dataloader_loop"),
@@ -787,12 +831,12 @@ async def compute_full_batch_metrics_and_get_sampling_client(
async def do_train_step_streaming_and_get_sampling_client(
cfg: Config,
i_batch: int,
trajectory_groups_queue: asyncio.Queue[WrappedTrajectoryGroup | None],
trajectory_groups_queue: asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None],
training_client: tinker.TrainingClient,
service_client: tinker.ServiceClient,
tokenizer: Tokenizer,
trajectory_group_filter: Callable[[WrappedTrajectoryGroup | None], bool] = lambda _: True,
) -> tuple[tinker.SamplingClient, dict[str, Any]]:
) -> tuple[tinker.SamplingClient, dict[str, Any]] | None:
"""
As soon as we have enough trajectories for a minibatch, we will train on them.
This allows us to overlap sampling and training.
@@ -825,8 +869,16 @@ async def do_train_step_streaming_and_get_sampling_client(
i_minibatch = 0
while i_minibatch < cfg.stream_minibatch_config.num_minibatches:
wrapped_trajectory_group = await trajectory_groups_queue.get()
if not trajectory_group_filter(wrapped_trajectory_group):
continue
match wrapped_trajectory_group:
case WrappedTrajectoryGroup():
pass
case Shutdown():
logger.info(
"[do_train_step_streaming_and_get_sampling_client] Received shutdown signal"
)
return None
case None:
continue
wrapped_trajectory_groups.append(wrapped_trajectory_group)

if len(wrapped_trajectory_groups) < groups_per_minibatch: