Merged
tinker_cookbook/distillation/train_on_policy.py (6 changes: 4 additions & 2 deletions)

@@ -7,12 +7,14 @@
 import logging
 import os
 import time
-from typing import Any, Dict, List, Literal, Sequence, cast
+from typing import Any, Dict, List, Sequence, cast
 
 import chz
 import tinker
 import torch
 
+from tinker.types import LossFnType
+
 from tinker_cookbook import checkpoint_utils
 from tinker_cookbook.display import colorize_example
 from tinker_cookbook.distillation.datasets import (
@@ -141,7 +143,7 @@ class Config:
     kl_discount_factor: float = 0.0
 
     # Loss function to use for training: "importance_sampling" or "ppo"
-    loss_fn: Literal["importance_sampling", "ppo"] = "importance_sampling"
+    loss_fn: LossFnType = "importance_sampling"
 
     # Number of optimizer steps per training iteration.
     # Useful for very large batch sizes.
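This change is a drop-in replacement only if `LossFnType` covers at least the two names the old inline `Literal` allowed. As a rough sketch (the actual alias in `tinker.types` is not shown in this diff and may list additional loss names), the refactor amounts to centralizing the literal in one shared place:

```python
from typing import Literal

# Hypothetical stand-in for tinker.types.LossFnType; the real alias may
# accept more loss names than the two the old annotation allowed.
LossFnType = Literal["importance_sampling", "ppo"]

loss_fn: LossFnType = "importance_sampling"  # same default as before
```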
tinker_cookbook/recipes/math_rl/train.py (4 changes: 2 additions & 2 deletions)

@@ -1,7 +1,6 @@
 import asyncio
 import logging
 from datetime import datetime
-from typing import Literal
 
 import chz
 from tinker_cookbook import cli_utils, model_info
@@ -11,6 +10,7 @@
 )
 from tinker_cookbook.rl.train import AsyncConfig, Config, main
 from tinker_cookbook.rl.types import RLDatasetBuilder
+from tinker.types import LossFnType
 
 logger = logging.getLogger(__name__)
 
@@ -59,7 +59,7 @@ class CLIConfig:
     behavior_if_log_dir_exists: cli_utils.LogdirBehavior = "ask"
 
     max_steps_off_policy: int | None = None
-    loss_fn: Literal["importance_sampling", "ppo"] = "importance_sampling"
+    loss_fn: LossFnType = "importance_sampling"
 
 
 def get_dataset_builder(
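One practical effect of the shared alias: a type checker now validates the `loss_fn` field of `CLIConfig` against a single source of truth instead of per-file copies of the literal. A minimal sketch of the equivalent runtime check, assuming `LossFnType` is a `Literal` alias (its exact members are an assumption here):

```python
from typing import Literal, get_args

# Assumed stand-in for tinker.types.LossFnType (hypothetical members).
LossFnType = Literal["importance_sampling", "ppo"]

def validate_loss_fn(name: str) -> str:
    """Runtime mirror of the static check a type checker performs."""
    if name not in get_args(LossFnType):
        raise ValueError(f"unsupported loss_fn: {name!r}")
    return name

validate_loss_fn("ppo")    # ok
# validate_loss_fn("mse")  # would raise ValueError
```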
tinker_cookbook/rl/train.py (10 changes: 5 additions & 5 deletions)

@@ -8,13 +8,13 @@
 import os
 import time
 from contextlib import contextmanager
-from typing import Any, Callable, Iterator, List, Literal, Sequence
+from typing import Any, Callable, Iterator, List, Sequence
 
 import chz
 import numpy as np
 import tinker
 import torch
-
+from tinker.types import LossFnType
 from tinker_cookbook import checkpoint_utils
 from tinker_cookbook.completers import TinkerTokenCompleter
 from tinker_cookbook.display import colorize_example
@@ -157,7 +157,7 @@ def remove_mask(datum: tinker.Datum) -> tinker.Datum:
 async def forward_backward(
     training_client: tinker.TrainingClient,
     batch_d: List[tinker.Datum],
-    loss_fn: Literal["importance_sampling", "ppo"],
+    loss_fn: LossFnType,
 ) -> List[torch.Tensor]:
     """Accumulate gradients on a minibatch of data"""
     fwd_bwd_future = await training_client.forward_backward_async(
@@ -181,7 +181,7 @@ async def train_step(
     training_client: tinker.TrainingClient,
     learning_rate: float,
     num_substeps: int,
-    loss_fn: Literal["importance_sampling", "ppo"],
+    loss_fn: LossFnType,
 ) -> List[torch.Tensor]:
     """Train the model on collected trajectories."""
     batches_md = split_list(data_D, min(num_substeps, len(data_D)))
@@ -238,7 +238,7 @@ class Config:
     kl_discount_factor: float = 0.0
 
     # Loss function to use for training: "importance_sampling" or "ppo"
-    loss_fn: Literal["importance_sampling", "ppo"] = "importance_sampling"
+    loss_fn: LossFnType = "importance_sampling"
 
     # Number of optimizer steps per training iteration.
     # Useful for very large batch sizes.
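For orientation, here is a stripped-down sketch of how the retyped `loss_fn` value flows from `Config` through `train_step` into `forward_backward` in this file. Everything except the names `loss_fn`, `train_step`, and `forward_backward` is a simplified stand-in, not the cookbook's actual code, which awaits `training_client.forward_backward_async(...)` with the same `loss_fn` string:

```python
import asyncio
from typing import List, Literal

# Hypothetical stand-in for tinker.types.LossFnType.
LossFnType = Literal["importance_sampling", "ppo"]

async def forward_backward(batch: List[dict], loss_fn: LossFnType) -> None:
    # Stand-in for the real gradient-accumulation call on the training client.
    print(f"accumulating gradients on {len(batch)} datums with loss_fn={loss_fn}")

async def train_step(data: List[dict], num_substeps: int, loss_fn: LossFnType) -> None:
    # Split the batch into minibatches and accumulate gradients per substep.
    chunk = max(1, len(data) // num_substeps)
    for i in range(0, len(data), chunk):
        await forward_backward(data[i : i + chunk], loss_fn)

asyncio.run(train_step([{}] * 4, num_substeps=2, loss_fn="importance_sampling"))
```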