Skip to content

Commit a36fc96

Browse files
committed
add adv
1 parent f624b90 commit a36fc96

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

tinker_cookbook/rl/data_processing.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,26 @@
1818

1919

2020
def compute_advantages(trajectory_groups_P: List[TrajectoryGroup]) -> List[torch.Tensor]:
21-
"""Compute advantages for each trajectory, centered within groups."""
21+
"""Compute advantages for each trajectory, centered within groups.
22+
23+
For single-trajectory groups, centers across the entire batch.
24+
"""
25+
# Flatten all rewards
26+
all_rewards = torch.cat([
27+
torch.tensor(traj_group.get_total_rewards())
28+
for traj_group in trajectory_groups_P
29+
])
30+
31+
# Compute baseline per group (or global if group size is 1)
2232
advantages_P: list[torch.Tensor] = []
2333

2434
for traj_group in trajectory_groups_P:
2535
rewards_G = torch.tensor(traj_group.get_total_rewards())
26-
# Center advantages within the group
27-
advantages_G = rewards_G - rewards_G.mean()
28-
advantages_P.append(advantages_G)
36+
group_size = len(rewards_G)
37+
38+
# Use group mean if > 1 trajectory, else use batch mean
39+
baseline = rewards_G.mean() if group_size > 1 else all_rewards.mean()
40+
advantages_P.append(rewards_G - baseline)
2941

3042
return advantages_P
3143

tinker_cookbook/rl/train.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,15 @@ async def prepare_minibatch(
725725
advantages_P = compute_advantages(trajectory_groups_P)
726726
data_D, _metadata_D = assemble_training_data(trajectory_groups_P, advantages_P)
727727

728+
# Log advantage statistics
729+
all_advantages = torch.cat(advantages_P)
730+
metrics.update({
731+
"advantages/mean": all_advantages.mean().item(),
732+
"advantages/std": all_advantages.std().item(),
733+
"advantages/min": all_advantages.min().item(),
734+
"advantages/max": all_advantages.max().item(),
735+
})
736+
728737
# Incorporate KL penalty if configured
729738
if kl_penalty_coef > 0:
730739
with timed("kl_vs_base", metrics):

0 commit comments

Comments
 (0)