File tree Expand file tree Collapse file tree 2 files changed +26
-4
lines changed
Expand file tree Collapse file tree 2 files changed +26
-4
lines changed Original file line number Diff line number Diff line change 1818
1919
2020def compute_advantages (trajectory_groups_P : List [TrajectoryGroup ]) -> List [torch .Tensor ]:
21- """Compute advantages for each trajectory, centered within groups."""
21+ """Compute advantages for each trajectory, centered within groups.
22+
23+ For single-trajectory groups, centers across the entire batch.
24+ """
25+ # Flatten all rewards
26+ all_rewards = torch .cat (
27+ [torch .tensor (traj_group .get_total_rewards ()) for traj_group in trajectory_groups_P ]
28+ )
29+
30+ # Compute baseline per group (or global if group size is 1)
2231 advantages_P : list [torch .Tensor ] = []
2332
2433 for traj_group in trajectory_groups_P :
2534 rewards_G = torch .tensor (traj_group .get_total_rewards ())
26- # Center advantages within the group
27- advantages_G = rewards_G - rewards_G .mean ()
28- advantages_P .append (advantages_G )
35+ group_size = len (rewards_G )
36+
37+ # Use group mean if > 1 trajectory, else use batch mean
38+ baseline = rewards_G .mean () if group_size > 1 else all_rewards .mean ()
39+ advantages_P .append (rewards_G - baseline )
2940
3041 return advantages_P
3142
Original file line number Diff line number Diff line change @@ -725,6 +725,17 @@ async def prepare_minibatch(
725725 advantages_P = compute_advantages (trajectory_groups_P )
726726 data_D , _metadata_D = assemble_training_data (trajectory_groups_P , advantages_P )
727727
728+ # Log advantage statistics
729+ all_advantages = torch .cat (advantages_P )
730+ metrics .update (
731+ {
732+ "advantages/mean" : all_advantages .mean ().item (),
733+ "advantages/std" : all_advantages .std ().item (),
734+ "advantages/min" : all_advantages .min ().item (),
735+ "advantages/max" : all_advantages .max ().item (),
736+ }
737+ )
738+
728739 # Incorporate KL penalty if configured
729740 if kl_penalty_coef > 0 :
730741 with timed ("kl_vs_base" , metrics ):
You can’t perform that action at this time.
0 commit comments