Skip to content

Commit eeaf453

Browse files
committed
Add batch-mean advantage baseline for single-trajectory groups and log advantage statistics
1 parent f624b90 commit eeaf453

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

tinker_cookbook/rl/data_processing.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,25 @@
1818

1919

2020
def compute_advantages(trajectory_groups_P: List[TrajectoryGroup]) -> List[torch.Tensor]:
    """Compute advantages for each trajectory, centered within groups.

    A group with more than one trajectory is centered on its own mean total
    reward. A group holding a single trajectory has no useful within-group
    baseline (subtracting its own mean always gives zero), so it is centered
    on the mean total reward of the entire batch instead.

    Args:
        trajectory_groups_P: Trajectory groups in the batch. Each group's
            ``get_total_rewards()`` is called once and its result is turned
            into a tensor (presumably one scalar reward per trajectory --
            confirm against ``TrajectoryGroup``).

    Returns:
        One tensor of advantages per group, in the same order as the input.
        An empty batch yields an empty list.
    """
    # Build each group's reward tensor once; the same tensors feed both the
    # batch-wide mean and the per-group centering below.
    rewards_P = [torch.tensor(traj_group.get_total_rewards()) for traj_group in trajectory_groups_P]

    if not rewards_P:
        # torch.cat rejects an empty list of tensors, so short-circuit here
        # to keep the "no groups -> no advantages" behavior.
        return []

    # Batch-wide baseline, computed once rather than once per singleton group.
    batch_mean = torch.cat(rewards_P).mean()

    advantages_P: list[torch.Tensor] = []
    for rewards_G in rewards_P:
        # Use group mean if > 1 trajectory, else use batch mean
        baseline = rewards_G.mean() if len(rewards_G) > 1 else batch_mean
        advantages_P.append(rewards_G - baseline)

    return advantages_P
3142

tinker_cookbook/rl/train.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,17 @@ async def prepare_minibatch(
725725
advantages_P = compute_advantages(trajectory_groups_P)
726726
data_D, _metadata_D = assemble_training_data(trajectory_groups_P, advantages_P)
727727

728+
# Log advantage statistics
729+
all_advantages = torch.cat(advantages_P)
730+
metrics.update(
731+
{
732+
"advantages/mean": all_advantages.mean().item(),
733+
"advantages/std": all_advantages.std().item(),
734+
"advantages/min": all_advantages.min().item(),
735+
"advantages/max": all_advantages.max().item(),
736+
}
737+
)
738+
728739
# Incorporate KL penalty if configured
729740
if kl_penalty_coef > 0:
730741
with timed("kl_vs_base", metrics):

0 commit comments

Comments
 (0)