Merged
8 changes: 3 additions & 5 deletions tinker_cookbook/checkpoint_utils.py
@@ -7,7 +7,7 @@
import tinker

from tinker_cookbook.utils.file_utils import read_jsonl
from tinker_cookbook.utils.trace import get_scope_context, scope
from tinker_cookbook.utils.trace import scope, update_scope_context

CHECKPOINTS_BASE_NAME = "checkpoints.jsonl"

@@ -22,8 +22,7 @@ def load_checkpoints_file(log_dir: str) -> list[dict[str, Any]]:
return []

logger.info(f"Reading checkpoints from {checkpoint_path}")
context = get_scope_context()
context.attributes["checkpoint_path"] = checkpoint_path
update_scope_context({"checkpoint_path": checkpoint_path})
return read_jsonl(checkpoint_path)


@@ -78,8 +77,7 @@ async def save_checkpoint_async(

results = {k: await v.result_async() for k, v in futures.items()}
paths = {k + "_path": v.path for k, v in results.items()}
context = get_scope_context()
context.attributes.update(paths)
update_scope_context(paths)
logger.info(f"Saved checkpoints: {paths}")
full_dict = {"name": name, **loop_state, **paths}
with open(os.path.join(log_path, "checkpoints.jsonl"), "a") as f:
32 changes: 14 additions & 18 deletions tinker_cookbook/distillation/train_on_policy.py
@@ -7,42 +7,39 @@
import logging
import os
import time
from typing import Any, List, Literal, Sequence, Dict, cast
from typing import Any, Dict, List, Literal, Sequence, cast

import chz
import tinker
import torch

from tinker_cookbook import checkpoint_utils
from tinker_cookbook.display import colorize_example
from tinker_cookbook.distillation.datasets import (
CompositeDataset,
DistillationDatasetConfig,
)
from tinker_cookbook.eval.evaluators import SamplingClientEvaluator, SamplingClientEvaluatorBuilder
from tinker_cookbook.rl.data_processing import (
assemble_training_data,
compute_advantages,
)
from tinker_cookbook.rl.metric_util import RLTestSetEvaluator, compute_trajectory_metrics
from tinker_cookbook.rl.metrics import discounted_future_sum_vectorized
from tinker_cookbook.rl.train import (
compute_full_batch_metrics_and_get_sampling_client,
do_group_rollout_and_filter_constant_reward,
save_checkpoint_and_get_sampling_client,
train_step,
)
from tinker_cookbook.rl.types import (
EnvGroupBuilder,
TrajectoryGroup,
)
from tinker_cookbook.tokenizer_utils import Tokenizer
from tinker_cookbook.utils import ml_log
from tinker_cookbook.utils.misc_utils import safezip, timed
from tinker_cookbook.utils.trace import scope, get_scope_context, trace_init

# Dataset configuration classes
from tinker_cookbook.distillation.datasets import (
CompositeDataset,
DistillationDatasetConfig,
)

# We re-use these methods from the RL training recipe
from tinker_cookbook.rl.train import (
save_checkpoint_and_get_sampling_client,
train_step,
compute_full_batch_metrics_and_get_sampling_client,
do_group_rollout_and_filter_constant_reward,
)
from tinker_cookbook.utils.trace import scope, update_scope_context, trace_init

logger = logging.getLogger(__name__)

@@ -228,8 +225,7 @@ async def do_train_step_and_get_sampling_client(
dataset_indices_P: List[int],
teacher_clients: List[tinker.SamplingClient],
) -> tuple[tinker.SamplingClient, dict[str, Any]]:
context = get_scope_context()
context.attributes["step"] = i_batch
update_scope_context({"step": i_batch})

metrics = {}
data_D, prepare_minibatch_metrics = await prepare_minibatch(
14 changes: 6 additions & 8 deletions tinker_cookbook/rl/train.py
@@ -7,12 +7,14 @@
import logging
import os
import time
from typing import Any, Callable, List, Literal, Sequence, Iterator
from contextlib import contextmanager
from typing import Any, Callable, Iterator, List, Literal, Sequence

import chz
import numpy as np
import tinker
import torch

from tinker_cookbook import checkpoint_utils
from tinker_cookbook.completers import TinkerTokenCompleter
from tinker_cookbook.display import colorize_example
@@ -39,9 +41,7 @@
from tinker_cookbook.tokenizer_utils import Tokenizer
from tinker_cookbook.utils import logtree, ml_log
from tinker_cookbook.utils.misc_utils import safezip, split_list, timed
from tinker_cookbook.utils.trace import scope, trace_init, get_scope_context
from contextlib import contextmanager

from tinker_cookbook.utils.trace import scope, update_scope_context, trace_init

logger = logging.getLogger(__name__)

@@ -809,8 +809,7 @@ async def do_train_step_streaming_and_get_sampling_client(
# Number of groups per minibatch in each optimizer substep
groups_per_minibatch = groups_per_substep // cfg.stream_minibatch_config.num_minibatches

context = get_scope_context()
context.attributes["step"] = i_batch
update_scope_context({"step": i_batch})

metrics = {}

@@ -904,8 +903,7 @@ async def do_train_step_and_get_sampling_client(
env_group_builders_P: Sequence[EnvGroupBuilder],
trajectory_groups_P: list[TrajectoryGroup],
) -> tuple[tinker.SamplingClient, dict[str, Any]]:
context = get_scope_context()
context.attributes["step"] = i_batch
update_scope_context({"step": i_batch})

metrics = {}
data_D, prepare_minibatch_metrics = await prepare_minibatch(
24 changes: 12 additions & 12 deletions tinker_cookbook/supervised/train.py
@@ -32,7 +32,7 @@
from tinker_cookbook.utils import ml_log
from tinker_cookbook.utils.lr_scheduling import compute_schedule_lr_multiplier
from tinker_cookbook.utils.misc_utils import timed
from tinker_cookbook.utils.trace import get_scope_context, scope, trace_init
from tinker_cookbook.utils.trace import scope, update_scope_context, trace_init

logger = logging.getLogger(__name__)

@@ -107,22 +107,24 @@ async def run_evals(
checkpoint. Returned metrics are prefixed with ``test/`` so they can be logged next
to the same-step training metrics.
"""
context = get_scope_context()
context.attributes["step"] = step
update_scope_context({"step": step})

metrics = {}
sampling_client = None

@scope
async def run_evaluator(evaluator: Evaluator) -> dict[str, float]:
context = get_scope_context()
context.attributes["step"] = step
context.attributes["evaluator_name"] = type(evaluator).__name__
update_scope_context(
{
"step": step,
"evaluator_name": type(evaluator).__name__,
}
)
if isinstance(evaluator, TrainingClientEvaluator):
context.attributes["evaluator_type"] = "TrainingClientEvaluator"
update_scope_context({"evaluator_type": "TrainingClientEvaluator"})
return await evaluator(training_client)
elif isinstance(evaluator, SamplingClientEvaluator):
context.attributes["evaluator_type"] = "SamplingClientEvaluator"
update_scope_context({"evaluator_type": "SamplingClientEvaluator"})
# Create sampling client lazily, only when needed
nonlocal sampling_client
if sampling_client is None:
@@ -225,8 +227,7 @@ async def main(config: Config):
@scope
async def submit_batch(epoch_idx: int, batch_idx: int) -> SubmittedBatch:
step = epoch_idx * n_batches + batch_idx
context = get_scope_context()
context.attributes["step"] = step
update_scope_context({"step": step})

batch_start_time = time.time()
metrics: dict[str, int | float | str] = {"epoch": epoch_idx}
@@ -286,8 +287,7 @@ async def submit_batch(epoch_idx: int, batch_idx: int) -> SubmittedBatch:

@scope
async def finish_batch(submitted: SubmittedBatch):
context = get_scope_context()
context.attributes["step"] = submitted.step
update_scope_context({"step": submitted.step})

metrics = submitted.metrics
metrics["progress"] = min((submitted.step + 1) / progress_denominator, 1.0)
16 changes: 11 additions & 5 deletions tinker_cookbook/tests/test_trace.py
@@ -1,8 +1,15 @@
import json
from tinker_cookbook.utils.trace import scope, trace_init, trace_shutdown, get_scope_context
import asyncio
import threading
import json
import tempfile
import threading

from tinker_cookbook.utils.trace import (
get_scope_context,
scope,
update_scope_context,
trace_init,
trace_shutdown,
)


@scope
@@ -30,8 +37,7 @@ def ced():
@scope
async def baz():
await asyncio.sleep(0.02)
context = get_scope_context()
context.attributes["baz"] = "baz"
update_scope_context({"baz": "baz"})
ced()


20 changes: 17 additions & 3 deletions tinker_cookbook/utils/trace.py
@@ -1,17 +1,17 @@
import argparse
import asyncio
import atexit
import functools
import inspect
import json
import queue
import threading
import time
from contextvars import ContextVar
from typing import Any, Callable
from dataclasses import dataclass, field
from enum import Enum
import argparse
from io import TextIOWrapper
import atexit
from typing import Any, Callable


class EventType(str, Enum):
@@ -407,6 +407,20 @@ async def foo():
return result


def update_scope_context(values: dict[str, Any]) -> None:
"""Update the current scope's context. Example usage:

@scope
async def foo(step: int):
update_scope_context({"step": step})
await bar()

"""
result = trace_context.get(ScopeContext())
assert result is not None, "Trace context is not set"
result.attributes.update(values)


def convert_jsonl_to_json_main():
"""Helper script to convert the trace events format into a visualizable format"""
parser = argparse.ArgumentParser(
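Taken together, the PR collapses the two-step get_scope_context() / attributes-mutation pattern into a single update_scope_context() call at every call site. Below is a minimal sketch of the new pattern inside a @scope-decorated coroutine, modeled on the docstring and the updated test above; the function names train_step and main are illustrative, and the trace_init()/trace_shutdown() setup used in the training scripts is elided because its signature is not part of this diff.

import asyncio

from tinker_cookbook.utils.trace import scope, update_scope_context


@scope
async def train_step(i_batch: int) -> None:
    # Attach attributes to the current scope's trace event in a single call,
    # replacing the old get_scope_context().attributes[...] mutation pattern.
    update_scope_context({"step": i_batch})
    await asyncio.sleep(0.01)  # stand-in for real work


async def main() -> None:
    for i_batch in range(3):
        await train_step(i_batch)


if __name__ == "__main__":
    # Tracing setup (trace_init / trace_shutdown) is omitted here; see the
    # training scripts above for how it is wired up in practice.
    asyncio.run(main())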