From f721db22f4080c7f20fef9f90f0202270c7ebf42 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 6 Jan 2025 08:39:00 -0800
Subject: [PATCH] start wiring up dense rewarding with implicit prm

---
 palm_rlhf_pytorch/implicit_process_reward.py | 12 +---
 palm_rlhf_pytorch/ppo.py                     | 61 ++++++++++++++------
 setup.py                                     |  2 +-
 3 files changed, 46 insertions(+), 29 deletions(-)

diff --git a/palm_rlhf_pytorch/implicit_process_reward.py b/palm_rlhf_pytorch/implicit_process_reward.py
index ed0df2c..2cc382c 100644
--- a/palm_rlhf_pytorch/implicit_process_reward.py
+++ b/palm_rlhf_pytorch/implicit_process_reward.py
@@ -1,3 +1,6 @@
+# Free Process Rewards without Process Labels
+# Yuan et al. https://arxiv.org/abs/2412.01981 - paper that led to Prime
+
 from __future__ import annotations
 
 from copy import deepcopy
@@ -18,9 +21,6 @@ def get_logprob_at(logits, seq):
     log_prob = log_probs.gather(-1, seq)
     return rearrange(log_prob, '... 1 -> ...')
 
-# Free Process Rewards without Process Labels
-# Yuan et al. https://arxiv.org/abs/2412.01981 - paper that led to Prime
-
 class ImplicitPRM(Module):
     """ PRM stands for process reward model, an openai paper that shows that rewarding the steps a model takes to its outcome is better than only rewarding based on final answer or outcome. basically same as when a teacher gives you some credit for showing your steps on an exam """
 
@@ -51,12 +51,6 @@ def forward(
         seq,
         labels = None
     ):
-        """
-        b - batch
-        n - sequence
-        l - logit dimension (num tokens)
-        """
-
         source_seq, target_seq = seq[:, :-1], seq[:, 1:]
         mask = target_seq >= 0 # assume any token ids < 0 to be padding
 
diff --git a/palm_rlhf_pytorch/ppo.py b/palm_rlhf_pytorch/ppo.py
index 1d0bea9..274879f 100644
--- a/palm_rlhf_pytorch/ppo.py
+++ b/palm_rlhf_pytorch/ppo.py
@@ -27,8 +27,8 @@
 
 from palm_rlhf_pytorch.palm import PaLM
 from palm_rlhf_pytorch.reward import RewardModel
+from palm_rlhf_pytorch.implicit_process_reward import ImplicitPRM
 from palm_rlhf_pytorch.utils import masked_mean, eval_decorator
-
 from accelerate import Accelerator
 
 # actor critic - PaLM with lora
@@ -47,7 +47,7 @@ class ActorCritic(Module):
     def __init__(
         self,
         palm: PaLM,
-        critic_palm: PaLM | None = None,
+        critic: PaLM | ImplicitPRM | None = None,
         pooled_values = False,
         actor_lora = True,
         critic_lora = True,
@@ -61,13 +61,26 @@ def __init__(
         super().__init__()
         self.actor_palm = palm
 
-        self.critic_palm = critic_palm
+        # detect implicit prm and auto-set some hyperparameters
+
+        critic_is_prm = isinstance(critic, ImplicitPRM)
+
+        critic_lora &= not critic_is_prm
+        pooled_values |= critic_is_prm
+
+        self.critic_is_prm = critic_is_prm
+
+        # critic
+
+        self.critic = critic
 
-        if not exists(self.critic_palm):
-            self.critic_palm = copy.deepcopy(palm)
+        if not exists(self.critic):
+            self.critic = copy.deepcopy(palm)
 
         self.actor_palm.set_dropout(actor_dropout)
-        self.critic_palm.set_dropout(critic_dropout)
+
+        if not critic_is_prm:
+            self.critic.set_dropout(critic_dropout)
 
         self.actor_lora = actor_lora
         self.critic_lora = critic_lora
@@ -79,16 +92,19 @@ def __init__(
             self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r)
 
         if self.critic_lora:
-            self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)
+            self.critic.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)
 
         self.pooled_values = pooled_values
-        self.value_head = nn.Sequential(
-            nn.Linear(palm.dim, 1),
-            Rearrange('... 1 -> ...')
-        )
+        self.value_head = nn.Identity()
+
+        if not critic_is_prm:
+            self.value_head = nn.Sequential(
+                nn.Linear(palm.dim, 1),
+                Rearrange('... 1 -> ...')
+            )
 
-        nn.init.zeros_(self.value_head[0].bias)
-        nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))
+            nn.init.zeros_(self.value_head[0].bias)
+            nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))
 
     def actor_parameters(self):
         if not self.actor_lora:
@@ -99,11 +115,14 @@ def actor_parameters(self):
         ]
 
     def critic_parameters(self):
+        if self.critic_is_prm:
+            return self.critic.parameters()
+
         if not self.actor_lora:
-            return [*self.critic_palm.parameters(), *self.value_head.parameters()]
+            return [*self.critic.parameters(), *self.value_head.parameters()]
 
         return [
-            *self.critic_palm.finetune_parameters(self.critic_lora_scope),
+            *self.critic.finetune_parameters(self.critic_lora_scope),
             *self.value_head.parameters()
         ]
 
@@ -170,7 +189,11 @@ def forward(
         if not return_values:
             return action_logits, None
 
-        critic_embeds = self.critic_palm(
+        if self.critic_is_prm:
+            values = self.critic(x)
+            return action_logits, values
+
+        critic_embeds = self.critic(
             x,
             return_only_embedding = True,
             finetune_scope = self.critic_lora_scope
@@ -287,8 +310,8 @@ def clipped_value_loss(values, rewards, old_values, clip):
 
 # rlhf trainer
 
-@beartype
 class RLHFTrainer(Module):
+    @beartype
     def __init__(
         self,
         *,
@@ -298,7 +321,7 @@ def __init__(
         tokenizer: Callable | None = None,
         palm: PaLM,
         reward_model: RewardModel,
-        critic_palm: PaLM | None = None,
+        critic: PaLM | ImplicitPRM | None = None,
         actor_critic: ActorCritic | None = None,
         actor_lr = 1e-4,
         critic_lr = 1e-4,
@@ -351,7 +374,7 @@ def __init__(
         if not exists(actor_critic):
             actor_critic = ActorCritic(
                 palm = palm,
-                critic_palm = critic_palm,
+                critic = critic,
                 actor_lora = actor_lora,
                 critic_lora = critic_lora,
                 actor_lora_r = actor_lora_r,
diff --git a/setup.py b/setup.py
index ab3ba41..3b85941 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'PaLM-rlhf-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.7',
+  version = '0.3.9',
   license='MIT',
   description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
   author = 'Phil Wang',
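
A minimal usage sketch of the new `critic` argument follows. The toy PaLM hyperparameters and the `ImplicitPRM(model)` constructor call are assumptions made for illustration only (the actual constructor lives in palm_rlhf_pytorch/implicit_process_reward.py). What the patch itself establishes: when an ImplicitPRM is passed as the critic, critic LoRA is switched off, pooled_values is switched on, the value head collapses to nn.Identity(), and ActorCritic.forward returns the implicit process rewards directly as the values.

    import copy
    import torch

    from palm_rlhf_pytorch import PaLM
    from palm_rlhf_pytorch.ppo import ActorCritic
    from palm_rlhf_pytorch.implicit_process_reward import ImplicitPRM

    # base language model - toy hyperparameters, chosen only for this sketch
    palm = PaLM(num_tokens = 256, dim = 512, depth = 8)

    # implicit PRM built around a separate copy of the base model
    # (assumption: ImplicitPRM accepts the wrapped model as its first argument
    # and keeps its own frozen reference copy internally)
    implicit_prm = ImplicitPRM(copy.deepcopy(palm))

    # per this patch, the critic may now be the ImplicitPRM itself;
    # critic_lora is disabled and pooled_values enabled automatically
    actor_critic = ActorCritic(
        palm = palm,
        critic = implicit_prm
    )

    seq = torch.randint(0, 256, (1, 128))

    # with an ImplicitPRM critic, forward skips the value head and returns
    # the dense (per-token) implicit rewards as the values
    action_logits, values = actor_critic(seq)

RLHFTrainer gains the same `critic` keyword and simply forwards it to ActorCritic, so the same ImplicitPRM instance can be passed there instead when using the full training loop.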