Commit f721db2

start wiring up dense rewarding with implicit prm
lucidrains committed Jan 6, 2025
1 parent f3e20cf commit f721db2
Showing 3 changed files with 46 additions and 29 deletions.
12 changes: 3 additions & 9 deletions palm_rlhf_pytorch/implicit_process_reward.py
@@ -1,3 +1,6 @@
# Free Process Rewards without Process Labels
# Yuan et al. https://arxiv.org/abs/2412.01981 - paper that led to Prime

from __future__ import annotations
from copy import deepcopy

@@ -18,9 +21,6 @@ def get_logprob_at(logits, seq):
log_prob = log_probs.gather(-1, seq)
return rearrange(log_prob, '... 1 -> ...')

# Free Process Rewards without Process Labels
# Yuan et al. https://arxiv.org/abs/2412.01981 - paper that led to Prime

class ImplicitPRM(Module):
""" PRM stands for process reward model, an openai paper that shows that rewarding the steps a model takes to its outcome is better than only rewarding based on final answer or outcome. basically same as when a teacher gives you some credit for showing your steps on an exam """

@@ -51,12 +51,6 @@ def forward(
seq,
labels = None
):
"""
b - batch
n - sequence
l - logit dimension (num tokens)
"""

source_seq, target_seq = seq[:, :-1], seq[:, 1:]

mask = target_seq >= 0 # assume any token ids < 0 to be padding
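
For orientation, a minimal sketch of what the implicit PRM computes, assuming the log-ratio parameterization from the Yuan et al. paper referenced above: a dense per-token reward proportional to the log-probability ratio between the trained model and a frozen reference copy (hence the deepcopy import), with token ids < 0 treated as padding as in the snippet. This is not the module in this diff; the function name implicit_process_rewards and the beta coefficient are illustrative.

import torch

def implicit_process_rewards(model_logits, ref_logits, target_seq, beta = 0.1):
    # model_logits, ref_logits: (batch, seq, num_tokens) logits over the same next tokens
    # target_seq: (batch, seq) token ids, with ids < 0 treated as padding

    mask = target_seq >= 0
    target = target_seq.clamp(min = 0)

    log_probs = model_logits.log_softmax(dim = -1)
    ref_log_probs = ref_logits.log_softmax(dim = -1)

    # log prob of each realized next token under both models
    logprob = log_probs.gather(-1, target[..., None]).squeeze(-1)
    ref_logprob = ref_log_probs.gather(-1, target[..., None]).squeeze(-1)

    # dense per-token reward from the beta-scaled log ratio, zeroed on padding
    return beta * (logprob - ref_logprob) * mask.float()
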
61 changes: 42 additions & 19 deletions palm_rlhf_pytorch/ppo.py
@@ -27,8 +27,8 @@

from palm_rlhf_pytorch.palm import PaLM
from palm_rlhf_pytorch.reward import RewardModel
from palm_rlhf_pytorch.implicit_process_reward import ImplicitPRM
from palm_rlhf_pytorch.utils import masked_mean, eval_decorator

from accelerate import Accelerator

# actor critic - PaLM with lora
@@ -47,7 +47,7 @@ class ActorCritic(Module):
def __init__(
self,
palm: PaLM,
critic_palm: PaLM | None = None,
critic: PaLM | ImplicitPRM | None = None,
pooled_values = False,
actor_lora = True,
critic_lora = True,
@@ -61,13 +61,26 @@ def __init__(
super().__init__()
self.actor_palm = palm

self.critic_palm = critic_palm
# detect implicit prm and auto-set some hyperparameters

critic_is_prm = isinstance(critic, ImplicitPRM)

critic_lora &= not critic_is_prm
pooled_values |= critic_is_prm

self.critic_is_prm = critic_is_prm

# critic

self.critic = critic

if not exists(self.critic_palm):
self.critic_palm = copy.deepcopy(palm)
if not exists(self.critic):
self.critic = copy.deepcopy(palm)

self.actor_palm.set_dropout(actor_dropout)
self.critic_palm.set_dropout(critic_dropout)

if not critic_is_prm:
self.critic.set_dropout(critic_dropout)

self.actor_lora = actor_lora
self.critic_lora = critic_lora
@@ -79,16 +92,19 @@ def __init__(
self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r)

if self.critic_lora:
self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)
self.critic.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)

self.pooled_values = pooled_values
self.value_head = nn.Sequential(
nn.Linear(palm.dim, 1),
Rearrange('... 1 -> ...')
)
self.value_head = nn.Identity()

if not critic_is_prm:
self.value_head = nn.Sequential(
nn.Linear(palm.dim, 1),
Rearrange('... 1 -> ...')
)

nn.init.zeros_(self.value_head[0].bias)
nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))
nn.init.zeros_(self.value_head[0].bias)
nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))

def actor_parameters(self):
if not self.actor_lora:
@@ -99,11 +115,14 @@ def actor_parameters(self):
]

def critic_parameters(self):
if self.critic_is_prm:
return self.critic.parameters()

if not self.actor_lora:
return [*self.critic_palm.parameters(), *self.value_head.parameters()]
return [*self.critic.parameters(), *self.value_head.parameters()]

return [
*self.critic_palm.finetune_parameters(self.critic_lora_scope),
*self.critic.finetune_parameters(self.critic_lora_scope),
*self.value_head.parameters()
]

@@ -170,7 +189,11 @@ def forward(
if not return_values:
return action_logits, None

critic_embeds = self.critic_palm(
if self.critic_is_prm:
values = self.critic(x)
return action_logits, values

critic_embeds = self.critic(
x,
return_only_embedding = True,
finetune_scope = self.critic_lora_scope
@@ -287,8 +310,8 @@ def clipped_value_loss(values, rewards, old_values, clip):

# rlhf trainer

@beartype
class RLHFTrainer(Module):
@beartype
def __init__(
self,
*,
@@ -298,7 +321,7 @@ def __init__(
tokenizer: Callable | None = None,
palm: PaLM,
reward_model: RewardModel,
critic_palm: PaLM | None = None,
critic: PaLM | ImplicitPRM | None = None,
actor_critic: ActorCritic | None = None,
actor_lr = 1e-4,
critic_lr = 1e-4,
@@ -351,7 +374,7 @@ def __init__(
if not exists(actor_critic):
actor_critic = ActorCritic(
palm = palm,
critic_palm = critic_palm,
critic = critic,
actor_lora = actor_lora,
critic_lora = critic_lora,
actor_lora_r = actor_lora_r,
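
Taken together, the ppo.py changes let an ImplicitPRM stand in for the critic: critic LoRA is skipped, pooled values are forced on, no linear value head is created, and the PRM's per-token outputs are returned directly as values. A rough usage sketch follows, assuming ImplicitPRM takes the wrapped model as its first argument and using illustrative PaLM hyperparameters; check the actual constructor signatures before relying on it.

from palm_rlhf_pytorch.palm import PaLM
from palm_rlhf_pytorch.implicit_process_reward import ImplicitPRM
from palm_rlhf_pytorch.ppo import ActorCritic

palm = PaLM(num_tokens = 20000, dim = 512, depth = 12)

# an implicit PRM built around its own PaLM (assumed constructor); once trained,
# its dense per-token rewards double as the critic's value estimates, so no
# LoRA finetune scope or linear value head is created for it
prm = ImplicitPRM(PaLM(num_tokens = 20000, dim = 512, depth = 12))

actor_critic = ActorCritic(
    palm = palm,
    critic = prm  # was `critic_palm: PaLM`, now `critic: PaLM | ImplicitPRM`
)

The same critic keyword can also be handed to RLHFTrainer, which forwards it into the ActorCritic it builds when one is not supplied.
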
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'PaLM-rlhf-pytorch',
packages = find_packages(exclude=[]),
version = '0.3.7',
version = '0.3.9',
license='MIT',
description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
author = 'Phil Wang',
