Commit 461bf9b

Authored by maitchison (Matthew Aitchison) and Tiiiger
Support new LossFnType (#132)
Co-authored-by: Matthew Aitchison <[email protected]>
Co-authored-by: Tianyi <[email protected]>
1 parent 6e6dbfe commit 461bf9b
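
This commit replaces the inline Literal["importance_sampling", "ppo"] annotations in the cookbook with a single LossFnType imported from tinker.types, so the set of accepted loss-function names is defined once in the tinker package instead of being repeated in each config and training helper. The definition of LossFnType is not part of this diff; the sketch below only illustrates what such an alias typically looks like and is an assumption, not the actual tinker source.

# Hypothetical sketch of the alias exported by tinker.types; the real
# definition may accept more loss functions than the two named in this diff.
from typing import Literal

LossFnType = Literal["importance_sampling", "ppo"]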

3 files changed: +11 -9 lines changed

tinker_cookbook/distillation/train_on_policy.py

Lines changed: 4 additions & 2 deletions
@@ -7,12 +7,14 @@
 import logging
 import os
 import time
-from typing import Any, Dict, List, Literal, Sequence, cast
+from typing import Any, Dict, List, Sequence, cast

 import chz
 import tinker
 import torch

+from tinker.types import LossFnType
+
 from tinker_cookbook import checkpoint_utils
 from tinker_cookbook.display import colorize_example
 from tinker_cookbook.distillation.datasets import (
@@ -141,7 +143,7 @@ class Config:
     kl_discount_factor: float = 0.0

     # Loss function to use for training: "importance_sampling" or "ppo"
-    loss_fn: Literal["importance_sampling", "ppo"] = "importance_sampling"
+    loss_fn: LossFnType = "importance_sampling"

     # Number of optimizer steps per training iteration.
     # Useful for very large batch sizes.
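
With the shared alias in place, the distillation Config's loss_fn field is checked against whatever names tinker.types.LossFnType allows, just as the old inline Literal was. A minimal sketch, assuming tinker is installed and that LossFnType is (or behaves like) a Literal of accepted names:

from tinker.types import LossFnType

# Accepted by the annotation now used for Config.loss_fn.
loss_fn: LossFnType = "importance_sampling"

# An unsupported name (e.g. "not_a_loss") would be flagged by a static type
# checker such as mypy before a training job is ever launched.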

tinker_cookbook/recipes/math_rl/train.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,6 @@
 import asyncio
 import logging
 from datetime import datetime
-from typing import Literal

 import chz
 from tinker_cookbook import cli_utils, model_info
@@ -11,6 +10,7 @@
 )
 from tinker_cookbook.rl.train import AsyncConfig, Config, main
 from tinker_cookbook.rl.types import RLDatasetBuilder
+from tinker.types import LossFnType

 logger = logging.getLogger(__name__)

@@ -59,7 +59,7 @@ class CLIConfig:
     behavior_if_log_dir_exists: cli_utils.LogdirBehavior = "ask"

     max_steps_off_policy: int | None = None
-    loss_fn: Literal["importance_sampling", "ppo"] = "importance_sampling"
+    loss_fn: LossFnType = "importance_sampling"


 def get_dataset_builder(
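
On the recipe side, loss_fn is a CLIConfig field, so the loss function can be picked when launching the math-RL recipe. A hypothetical invocation, assuming the module is run directly with chz's key=value argument style and that the other required options (model, dataset, log dir, and so on) are supplied alongside it:

python -m tinker_cookbook.recipes.math_rl.train loss_fn=ppo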

tinker_cookbook/rl/train.py

Lines changed: 5 additions & 5 deletions
@@ -8,13 +8,13 @@
 import os
 import time
 from contextlib import contextmanager
-from typing import Any, Callable, Iterator, List, Literal, Sequence
+from typing import Any, Callable, Iterator, List, Sequence

 import chz
 import numpy as np
 import tinker
 import torch
-
+from tinker.types import LossFnType
 from tinker_cookbook import checkpoint_utils
 from tinker_cookbook.completers import TinkerTokenCompleter
 from tinker_cookbook.display import colorize_example
@@ -157,7 +157,7 @@ def remove_mask(datum: tinker.Datum) -> tinker.Datum:
 async def forward_backward(
     training_client: tinker.TrainingClient,
     batch_d: List[tinker.Datum],
-    loss_fn: Literal["importance_sampling", "ppo"],
+    loss_fn: LossFnType,
 ) -> List[torch.Tensor]:
     """Accumulate gradients on a minibatch of data"""
     fwd_bwd_future = await training_client.forward_backward_async(
@@ -181,7 +181,7 @@ async def train_step(
     training_client: tinker.TrainingClient,
     learning_rate: float,
     num_substeps: int,
-    loss_fn: Literal["importance_sampling", "ppo"],
+    loss_fn: LossFnType,
 ) -> List[torch.Tensor]:
     """Train the model on collected trajectories."""
     batches_md = split_list(data_D, min(num_substeps, len(data_D)))
@@ -238,7 +238,7 @@ class Config:
     kl_discount_factor: float = 0.0

     # Loss function to use for training: "importance_sampling" or "ppo"
-    loss_fn: Literal["importance_sampling", "ppo"] = "importance_sampling"
+    loss_fn: LossFnType = "importance_sampling"

     # Number of optimizer steps per training iteration.
     # Useful for very large batch sizes.
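
Because forward_backward, train_step, and Config in rl/train.py all annotate loss_fn as LossFnType, any loss function tinker adds to the alias flows through the training loop without further edits here. Where a runtime check is still useful (for example on user-supplied strings), a small hypothetical helper can recover the allowed names, assuming LossFnType is a typing.Literal; if it is not, get_args returns an empty tuple and the check is skipped:

from typing import get_args

from tinker.types import LossFnType


def check_loss_fn(name: str) -> str:
    """Hypothetical helper: validate a loss-function name against LossFnType."""
    allowed = get_args(LossFnType)  # () if LossFnType is not a Literal
    if allowed and name not in allowed:
        raise ValueError(f"loss_fn must be one of {allowed}, got {name!r}")
    return name


check_loss_fn("importance_sampling")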
