4 changes: 2 additions & 2 deletions conf/base.yaml
@@ -23,9 +23,9 @@ preprocess:
input: actor
output: training_data
n_workers: 8
-chunk_n_groups: 2
+chunk_n_groups: 8
# queue for loaded raw groups
-raw_queue_size: 8
+raw_queue_size: 128
# queue for processed chunks of multiple groups
input_queue_size: 32
# queue for ready chunks for multiple groups
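Both raised values sit in the preprocessor's producer/consumer pipeline: `chunk_n_groups` appears to be the number of rollout groups packed into one processed chunk, and `raw_queue_size` bounds how many raw groups may be buffered before the loader blocks. A minimal sketch of the back-pressure these bounded queues imply (queue and worker names below are illustrative, not the repo's actual ones):

```python
import queue

raw_queue = queue.Queue(maxsize=128)   # raw_queue_size: a full queue blocks the loader
CHUNK_N_GROUPS = 8                     # chunk_n_groups: groups packed per processed chunk

def loader(groups):
    for g in groups:
        raw_queue.put(g)               # blocks once 128 raw groups are already buffered

def chunker(out_queue):
    while True:
        chunk = [raw_queue.get() for _ in range(CHUNK_N_GROUPS)]
        out_queue.put(chunk)           # downstream consumers see one chunk per 8 groups
```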
2 changes: 1 addition & 1 deletion conf/finetune/base.yaml
@@ -36,7 +36,7 @@ learning_rate: 1e-6
# How much to clip the gradient (no clipping if null)
gradient_clipping_threshold: 0.3
# Learning rate scheduler type (indexed by completed_steps).
-lr_scheduler_type: cosine # could be cosine, constant_with_warmup
+lr_scheduler_type: constant # could be cosine, constant_with_warmup
# Number of warmup (completed) steps in the learning rate schedule.
num_warmup_steps: 50
# Number of gradient accumulation steps.
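With `lr_scheduler_type: constant` the learning rate stays flat at `learning_rate` instead of decaying along a cosine. Assuming the option is ultimately passed to HuggingFace `transformers.get_scheduler` (an assumption about the plumbing; the repo may wrap it differently), the three named choices behave roughly as follows:

```python
import torch
from transformers import get_scheduler

model = torch.nn.Linear(4, 4)                              # stand-in model for illustration
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)

# "constant": lr stays at 1e-6 from step 0 (warmup arguments are ignored).
# "constant_with_warmup": lr ramps 0 -> 1e-6 over num_warmup_steps, then stays flat.
# "cosine": lr warms up, then decays toward 0 on a cosine curve over num_training_steps.
scheduler = get_scheduler(
    "constant",
    optimizer=optimizer,
    num_warmup_steps=50,        # num_warmup_steps from the config
    num_training_steps=1000,    # illustrative total step count
)
```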
1 change: 1 addition & 0 deletions pipelinerl/actor.py
@@ -498,6 +498,7 @@ def run(self, dataset: list[tuple[str, dict]]):
"finished_groups": finished_groups,
"trainer_model_version": trainer_version_to_publish,
"time_since_start": time.time() - loop_start_time,
"groups_in_progress": in_progress,
}
trainer_version_to_publish = None
else:
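The actor's periodic stats message gains a `groups_in_progress` gauge. The usual pattern for such a counter is to increment when a group is dispatched and decrement when all of its rollouts have returned, publishing the current value with every stats tick; a sketch of that bookkeeping (names below are illustrative, not taken from `actor.py`):

```python
in_progress = 0                 # gauge of rollout groups currently in flight

def on_group_submitted():       # called when a group is dispatched for rollouts
    global in_progress
    in_progress += 1

def on_group_finished():        # called when every rollout in the group has returned
    global in_progress
    in_progress -= 1

stats = {
    "finished_groups": 42,               # toy value
    "groups_in_progress": in_progress,   # the new field published with each stats message
}
```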
2 changes: 1 addition & 1 deletion pipelinerl/domains/math/rollouts.py
@@ -43,7 +43,7 @@ async def generate_math_rollout(
messages = []
if cfg.actor.system_prompt:
messages.append({"role": "system", "content": cfg.actor.system_prompt})
messages.append({"role": "user", "content": cfg.actor.task_template.format(task=problem["task"])})
messages.append({"role": "user", "content": f"{problem['task']} \n{cfg.actor.task_prompt}"})
prompt = Prompt(messages=messages)

time_start = time.time()
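The user message is no longer built from `cfg.actor.task_template.format(...)`; instead the raw problem text comes first and a fixed `cfg.actor.task_prompt` instruction is appended after it. A small illustration with toy values (the problem text and prompts below are made up, not from the repo's configs):

```python
# Toy inputs, for illustration only.
problem = {"task": "Compute 2 + 2."}
system_prompt = "You are a careful math assistant."
task_prompt = "Think step by step and put the final answer in \\boxed{}."

messages = []
if system_prompt:
    messages.append({"role": "system", "content": system_prompt})
# New scheme: problem first, then the task instruction appended on a new line.
messages.append({"role": "user", "content": f"{problem['task']} \n{task_prompt}"})

print(messages[-1]["content"])
# Compute 2 + 2.
# Think step by step and put the final answer in \boxed{}.
```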
63 changes: 39 additions & 24 deletions pipelinerl/finetune/rl/__init__.py
@@ -260,6 +260,7 @@ def rl_step(
)

approx_kl = torch.exp(log_ratio_ref_new_clamp) - log_ratio_ref_new_clamp - 1 # Schulman KL approx
+approx_kl_new_old = torch.exp(log_ratio_new_old) - log_ratio_new_old - 1 # Schulman KL approx

assert torch.isfinite(approx_kl).all(), f"approx_kl is not finite: {approx_kl}"
entropy_bonus_coef = linear_decay_coef(current_step, max_step, config.entropy_bonus, config.final_entropy_bonus)
@@ -337,6 +338,7 @@ def rl_step(
"max_advantage": advantages[masks_shifted].max().item(),
"min_advantage": advantages[masks_shifted].min().item(),
"kl": sum_sum(approx_kl / num_labels_in_seq, masks_shifted, segments).item(),
"kl_new_old": sum_sum(approx_kl_new_old / num_labels_in_seq, masks_shifted, segments).item(),
"max_kl": approx_kl[masks_shifted].max().item(),
"min_kl": approx_kl[masks_shifted].min().item(),
"policy_loss": sum_sum(policy_loss / num_labels_in_seq, masks_shifted, segments).item(),
@@ -381,22 +383,15 @@ def populate_rl_data(dataset: list[dict[str, Any]], eos_token_id: int, config: R
"""
Populates a dataset with reinforcement learning specific data columns including
rewards, advantages, and token weights.
-
-Args:
-dataset (Dataset): The input dataset to populate with RL data
-eos_token_id (int): End of sequence token ID
-config (RLConfig): Configuration object containing RL training parameters
-
-Returns:
-Dataset: The dataset populated with RL-specific columns
+Uses leave-one-out (LOO) reward mean: each rollout's baseline excludes its own reward.
"""
# Convert to pandas for processing
df_init = pd.DataFrame(dataset)
assert isinstance(df_init, pd.DataFrame)

# Step 1: calculate group-level statistics
df_stats = df_init[["group_id", "rollout_index", "step_index"]].copy()
df_stats["num_tokens"] = df_init["input_ids"].apply(lambda x: len(x))
df_stats["num_tokens"] = df_init["input_ids"].apply(len)
# We assume that rewards for all tokens are the same
df_stats["rollout_reward"] = df_init["rewards"].apply(lambda x: x[0])
# Check that the reward is the same for each step in the rollout
@@ -406,42 +401,60 @@ def populate_rl_data(dataset: list[dict[str, Any]], eos_token_id: int, config: R
df_grouped = (
df_stats.groupby("group_id")
.agg(
-rollout_reward_mean=("rollout_reward", "mean"),
+rollout_reward_sum=("rollout_reward", "sum"),
+rollout_reward_count=("rollout_reward", "count"),
rollout_reward_std=("rollout_reward", "std"),
-group_tokens=("num_tokens", "mean"),
+group_tokens=("num_tokens", "mean"),
)
.reset_index()
)
assert df_grouped.columns.tolist() == ["group_id", "rollout_reward_mean", "rollout_reward_std", "group_tokens"]

-# Step 2: calculate advantages for each sample
+assert df_grouped.columns.tolist() == [
+"group_id",
+"rollout_reward_sum",
+"rollout_reward_count",
+"rollout_reward_std",
+"group_tokens",
+]

+# Step 2: calculate advantages for each sample (with LOO mean)
df_advantages = pd.merge(
df_init[["group_id", "rollout_index", "step_index", "rewards"]],
df_grouped,
on="group_id",
how="left"
)
assert len(df_advantages) == len(df_init)

def calculate_advantages(row):
rewards = row["rewards"]
-mean = row["rollout_reward_mean"]
+group_sum = row["rollout_reward_sum"]
+group_count = row["rollout_reward_count"]
+current_reward = rewards[0] # same reward across tokens in rollout
+
+# Leave-one-out mean
+if group_count > 1:
+loo_mean = (group_sum - current_reward) / (group_count - 1)
+else:
+loo_mean = current_reward # degenerate case: only one rollout in group

std = row["rollout_reward_std"]
if config.divide_advantage_by_std:
-advantages = [(reward - mean) / (np.nan_to_num(std) + 1e-4) for reward in rewards]
+advantages = [(r - loo_mean) / (np.nan_to_num(std) + 1e-4) for r in rewards]
else:
-advantages = [(reward - mean) for reward in rewards]
+advantages = [(r - loo_mean) for r in rewards]
return advantages
df_advantages["advantages"] = df_advantages.apply(
calculate_advantages,
axis=1,

df_advantages["advantages"] = df_advantages.apply(calculate_advantages, axis=1)
df_advantages = df_advantages.drop(
columns=["rewards", "rollout_reward_sum", "rollout_reward_count", "rollout_reward_std"]
)
-df_advantages = df_advantages.drop(columns=["rewards", "rollout_reward_mean", "rollout_reward_std"])
-assert df_advantages.columns.tolist() == ["group_id", "rollout_index", "step_index", "group_tokens", "advantages"]
+assert df_advantages.columns.tolist() == [
+"group_id", "rollout_index", "step_index", "group_tokens", "advantages"
+]

# Step 3: bring advantages and group level stats back to the main df
df = df_init.drop(columns=["advantages", "group_tokens"])
df = pd.merge(df, df_advantages, on=["group_id", "rollout_index", "step_index"], how="left")
# Debug print lengths of all dataframes
assert len(df) == len(df_init)

# Step 4: make token-level overflow and mean group length information
@@ -450,7 +463,9 @@ def calculate_advantages(row):
axis=1,
)
df["group_tokens"] = df.apply(lambda row: [row["group_tokens"]] * len(row["input_ids"]), axis=1)
df["num_labels"] = df.apply(lambda row: [sum(1 for label in row["labels"] if label != -100)] * len(row["input_ids"]), axis=1)
df["num_labels"] = df.apply(
lambda row: [sum(1 for label in row["labels"] if label != -100)] * len(row["input_ids"]), axis=1
)

# Step 5: move the results back to the dataset
advantages_list = df["advantages"].tolist()
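The core change in `populate_rl_data` is the baseline used for advantages: instead of centring each rollout's reward on the full group mean, it now uses a leave-one-out (LOO) mean, that is, the mean reward of the other rollouts in the same group, so a rollout's own reward no longer shifts its own baseline. A self-contained numeric sketch of the arithmetic (toy rewards, not tied to the repo's data):

```python
import numpy as np

rewards = np.array([1.0, 0.0, 0.0, 1.0])     # one group of 4 rollouts (toy values)
group_sum, n = rewards.sum(), len(rewards)

# Leave-one-out baseline: each rollout is compared against the other n-1 rollouts.
loo_mean = (group_sum - rewards) / (n - 1)   # [0.333, 0.667, 0.667, 0.333]
advantages = rewards - loo_mean              # [ 0.667, -0.667, -0.667,  0.667]

# The previous full-group mean (0.5) would have given [0.5, -0.5, -0.5, 0.5];
# LOO removes the self-contribution, which matters most for small groups, and the
# single-rollout group degenerates to a zero advantage, as in the code's else branch.
```

With `config.divide_advantage_by_std` enabled, the same LOO-centred values are additionally divided by the group reward standard deviation plus a small epsilon, matching the `calculate_advantages` branch above.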
1 change: 1 addition & 0 deletions pipelinerl/finetune_loop.py
@@ -483,6 +483,7 @@ def run_finetuning_loop(
finally:
if actor_update_group:
dist.destroy_process_group(actor_update_group)
+raise RuntimeError("Finetuning loop finished, exiting worker thread")


def rl_finetuning_worker(
1 change: 1 addition & 0 deletions pipelinerl/preprocess.py
@@ -637,6 +637,7 @@ def run_preprocessing_loop(
"preprocessor/queue/output": output_queue.qsize(),
"preprocessor/filtered_out_samples": num_filtered_out,
"preprocessor/total_filtered_out_samples": total_filtered_out,
"preprocessor/dropped_after_preprocessing": processed_entries_queue_popped_data,
}
if stats_aggregator.has_enough_data():
stats.update({"preprocessor/" + k: v for k, v in stats_aggregator.get_stats().items()})