Commit

Merge branch 'pl' of https://github.com/epfl-dlab/PauseToken into pl
nbaldwin98 committed Oct 22, 2024
2 parents 99b5979 + 4e7151e commit a609ab7
Showing 12 changed files with 326 additions and 9 deletions.
49 changes: 49 additions & 0 deletions configs/experiment/train/star_on_policy_pause.yaml
@@ -0,0 +1,49 @@
# @package _global_

defaults:
- override /data: gsm8k
- override /rl_algorithm: star_on_policy
- override /rl_algorithm/policy/model/language_model: pause_from_pretrained
- override /rl_algorithm/policy/model/peft_config: null #peft is already there
- override /rl_algorithm/reward: gsm8k
- override /metrics: gsm8k

name: "star on gsm8k"
run_name: "debug"
task_name: "train"



data:
  additional_transformation:
    _target_: functools.partial
    _args_:
      - ${get_method:src.utils.trainer_utils.inference_formatting_function}
    eos_token: ${get_obj_attr:${rl_algorithm.policy.model.tokenizer},[eos_token]}


trainer:
  inner_loop_timesteps: 10
  n_outer_loops: 5
  progress_bar: false
  num_val_samples: 10
  save_top_k: 3
  metric_for_best_model: "val/accuracy"
  # whether the metric for the best model is minimized or maximized (true = min, the lower the better; false = max, the higher the better)
  metric_for_best_model_mode_is_min: false



rl_algorithm:
  n_envs: 4
  loss_computed_in_forward_pass: true
  buffer:
    advantage_threshold: 0

  policy:
    model:
      language_model:
        pretrained_model_name_or_path: /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-08-29_10-23-33/final
        post_instanciation_method_calls:
          - method: unfreeze_all

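For context, the `additional_transformation` node above resolves (via `functools.partial`) to a callable that formats an example for inference and appends the tokenizer's EOS token. The real `inference_formatting_function` is not part of this diff, so the stub below is only a sketch of the shape of the resolved object, with an assumed signature and prompt format:

```python
from functools import partial

def inference_formatting_function(example, eos_token):
    # Stand-in for src.utils.trainer_utils.inference_formatting_function (not shown in this diff):
    # format one dataset example into an inference prompt and append the EOS token.
    return f"Question: {example['question']}\nAnswer:{eos_token}"

# Roughly what the Hydra node resolves to once the tokenizer's eos_token is interpolated.
additional_transformation = partial(inference_formatting_function, eos_token="</s>")
print(additional_transformation({"question": "What is 2 + 2?"}))
```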
2 changes: 1 addition & 1 deletion configs/experiment/train/star_pause.yaml
@@ -2,7 +2,7 @@

defaults:
- override /data: gsm8k
- override /rl_algorithm: star
- override /rl_algorithm: star_on_policy
- override /rl_algorithm/policy/model/language_model: pause_from_pretrained
- override /rl_algorithm/policy/model/peft_config: null #peft is already there
- override /rl_algorithm/reward: gsm8k
4 changes: 4 additions & 0 deletions configs/rl_algorithm/buffer/lm_rollout_buffer.yaml
@@ -0,0 +1,4 @@
_target_: lm_stable_baselines.buffers.LMRolloutBuffer

advantage_threshold: null
filler_token: -100
2 changes: 2 additions & 0 deletions configs/rl_algorithm/off_policy.yaml
@@ -17,6 +17,8 @@ n_envs: 1
policy_class: ${.policy._target_}
policy_kwargs: ${get_dict_except:${.policy},"_target_"}

buffer_class_keyword: 'replay_buffer'

# #fetch _target_ argument from buffer
replay_buffer_class: ${.buffer._target_}
# #fetch all arguments from buffer except _target_
50 changes: 50 additions & 0 deletions configs/rl_algorithm/on_policy.yaml
@@ -0,0 +1,50 @@
defaults:
- environment: language_model_env
- reward: default
- policy: llm_base_policy
- buffer: lm_rollout_buffer


data_collator:
  _target_: trl.DataCollatorForCompletionOnlyLM
  response_template:
    _target_: src.utils.hydra_custom_resolvers.get_module_attr
    module_and_attr: src.utils.constants.ANSWER_TEMPLATE

_target_: stable_baselines3.common.on_policy_algorithm.OnPolicyAlgorithm
n_envs: 1
## TODO: What to do here ? Should I just put this in another config file ?
policy_class: ${.policy._target_}
policy_kwargs: ${get_dict_except:${.policy},"_target_"}

buffer_class_keyword: 'rollout_buffer'

# env:
learning_rate: 1e-5
n_steps: 10
gamma: 0.99
gae_lambda: 0.999
ent_coef: 0.5
vf_coef: 0.5
max_grad_norm: 0.5
use_sde: false
sde_sample_freq: -1
rollout_buffer_class: ${.buffer._target_} # #fetch _target_ argument from buffer
rollout_buffer_kwargs: ${get_dict_except:${.buffer},"_target_"} # #fetch all arguments from buffer except _target_
stats_window_size: ${.n_steps}
tensorboard_log: null
monitor_wrapper: true

verbose: 0
seed: 42
device: "auto"
_init_setup_model: true
supported_action_spaces: null








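The `policy_kwargs` and `rollout_buffer_kwargs` entries above rely on a custom `get_dict_except` OmegaConf resolver, presumably defined alongside `get_module_attr` in `src.utils.hydra_custom_resolvers` (not shown in this diff). A minimal sketch of how such a resolver could be registered, assuming it simply drops the given key from a config node; the project's actual implementation may differ:

```python
from omegaconf import OmegaConf

# Hypothetical registration of a "drop this key" resolver.
OmegaConf.register_new_resolver(
    "get_dict_except",
    lambda cfg, key: {k: v for k, v in cfg.items() if k != key},
)

cfg = OmegaConf.create({
    "policy": {"_target_": "my.Policy", "temperature": 0.7},
    "policy_kwargs": '${get_dict_except:${policy},"_target_"}',
})
print(cfg.policy_kwargs)  # -> {'temperature': 0.7}
```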
2 changes: 1 addition & 1 deletion configs/rl_algorithm/star.yaml
@@ -1,4 +1,4 @@
defaults:
- off_policy

_target_: lm_stable_baselines.training_algorithms.STaR
_target_: lm_stable_baselines.training_algorithms.star.STaR
5 changes: 5 additions & 0 deletions configs/rl_algorithm/star_on_policy.yaml
@@ -0,0 +1,5 @@
defaults:
- on_policy

_target_: lm_stable_baselines.training_algorithms.star_on_policy.STaR
batch_size: 3
3 changes: 2 additions & 1 deletion lm_stable_baselines/buffers/__init__.py
@@ -1 +1,2 @@
from lm_stable_baselines.buffers.lm_replay_buffer import LMReplayBuffer
from lm_stable_baselines.buffers.lm_replay_buffer import LMReplayBuffer
from lm_stable_baselines.buffers.lm_rollout_buffer import LMRolloutBuffer
96 changes: 96 additions & 0 deletions lm_stable_baselines/buffers/lm_rollout_buffer.py
@@ -0,0 +1,96 @@
from stable_baselines3.common.buffers import RolloutBuffer
import warnings
from typing import Union, List, Dict, Any, Optional
import transformers
import numpy as np
import torch
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.type_aliases import RolloutBufferSamples
from lm_stable_baselines.utils import remove_filler_tokens

def double_indexing(array: np.ndarray, idx1: np.ndarray, idx2: Optional[np.ndarray] = None) -> np.ndarray:
    if idx2 is None:
        return array[idx1]
    return array[idx1][idx2]

class LMRolloutBuffer(RolloutBuffer):
    def __init__(
        self,
        *args,
        tokenizer: transformers.PreTrainedTokenizer = None,
        advantage_threshold: float = None,
        filler_token = -100,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.set_filler_token(filler_token)
        self.tokenizer = tokenizer
        self.advantage_threshold = advantage_threshold
        self.above_threshold_indices = None
        self.data_size = 0

    def set_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer

    def reset(self) -> None:
        super().reset()
        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=np.long)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.long)
        self.above_threshold_indices = None
        self.data_size = 0

    def set_filler_token(self, filler_token):
        self.filler_token = filler_token
        self.observations.fill(filler_token)
        self.actions.fill(filler_token)

    def to_torch(self, array: Union[np.ndarray, torch.Tensor, transformers.BatchEncoding], copy: bool = True) -> Union[torch.Tensor, transformers.BatchEncoding]:
        if isinstance(array, transformers.BatchEncoding):
            return {k: v.to(self.device) for k, v in array.items()}
        elif isinstance(array, torch.Tensor):
            return array.to(self.device)
        return super().to_torch(array, copy)

    def find_where_advantage_exceeds_threshold(self, advantage: np.ndarray) -> None:
        self.above_threshold_indices = np.where(advantage > self.advantage_threshold)
        self.data_size = len(self.above_threshold_indices[0])

    def sample_batch(self, batch_size, env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
        # Get the positions of the allowed indices (where the matrix is 1)
        allowed_indices = self.above_threshold_indices if self.above_threshold_indices is not None else np.arange(self.buffer_size)

        # Sample randomly from the allowed indices
        idx = np.random.choice(len(allowed_indices), size=batch_size, replace=True)
        sampled_positions = (allowed_indices[0][idx], allowed_indices[1][idx])

        obs = self.observations[sampled_positions]

        obs = self.tokenizer(
            self.tokenizer.batch_decode(
                remove_filler_tokens(obs[..., 1:].long(), self.filler_token)  # remove the first token (the bos token, tokenizer will re-add it)
            ),
            return_tensors="pt", padding=True, truncation=True
        )

        actions = self.tokenizer(
            self.tokenizer.batch_decode(
                remove_filler_tokens(self.actions[sampled_positions], self.filler_token)  # don't remove the first token (since it's an action, it didn't start with a bos token)
            ),
            return_tensors="pt", padding=True, truncation=True
        )["input_ids"][..., 1:]  # remove the first token (the bos token, actions should not have it)

        data = (
            self.observations[sampled_positions],
            self.actions[sampled_positions],
            self.values[sampled_positions].flatten(),
            self.log_probs[sampled_positions].flatten(),
            self.advantages[sampled_positions].flatten(),
            self.returns[sampled_positions].flatten(),
        )

        return RolloutBufferSamples(*tuple(map(self.to_torch, data)))

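`sample_batch` leans on `remove_filler_tokens` from `lm_stable_baselines.utils`, which is not included in this commit. A plausible minimal version, assuming it strips the filler token from each row so the ragged sequences can be batch-decoded and re-tokenized with padding (the real helper may differ):

```python
from typing import List

import torch

def remove_filler_tokens(ids: torch.Tensor, filler_token: int) -> List[List[int]]:
    # Drop filler positions row by row; the resulting ragged id lists can be fed
    # straight into tokenizer.batch_decode.
    return [row[row != filler_token].tolist() for row in ids]

# Example: filler token -100 is removed, real token ids are kept.
batch = torch.tensor([[1, 15, 27, -100, -100], [1, 42, -100, -100, -100]])
print(remove_filler_tokens(batch, -100))  # [[1, 15, 27], [1, 42]]
```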
19 changes: 15 additions & 4 deletions lm_stable_baselines/policies/llm_base_policy.py
@@ -183,10 +183,21 @@ def extract_features(self, obs: PyTorchObs, features_extractor: Optional[BaseFea


    def forward(self, obs: PyTorchObs, labels = None) -> torch.Tensor:
        feature = self.extract_features(obs)
        feature = {k: v.to(self.device) for k, v in feature.items()}
        feature["labels"] = labels.to(self.device) if labels is not None else None
        return self.lm(**feature)
        # feature = self.extract_features(obs)
        # feature = {k: v.to(self.device) for k, v in feature.items()}
        # feature["labels"] = labels.to(self.device) if labels is not None else None
        actions = self._predict(obs)
        padded_actions = actions.clone()
        padded_actions[actions == self.filler_token] = self.tokenizer.pad_token_id
        logprobs = torch.log_softmax(self.lm(padded_actions).logits, dim=-1)
        mask = (actions != self.filler_token).float()
        logprob_actions = torch.gather(logprobs, 2, padded_actions.unsqueeze(-1)).squeeze(-1)
        logprobs = (logprob_actions * mask).sum(dim=1)
        values = self.predict_values(obs)
        return actions, values, logprobs

    def predict_values(self, obs: PyTorchObs) -> torch.Tensor:
        raise NotImplementedError

    def post_predict(self, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
        # remove the input tokens from the output
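The rewritten `forward` scores the sampled actions by masking filler positions and summing gathered per-token log-probabilities. As a point of reference, a generic, self-contained version of that masked-scoring pattern for a causal LM (not this repository's code) shifts the logits by one position, since the logits at step t parameterize the distribution over the token at step t+1:

```python
import torch

def masked_sequence_logprob(logits: torch.Tensor, ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seq, vocab) from a causal LM run on `ids`; mask: 1.0 for real tokens, 0.0 for filler.
    logprobs = torch.log_softmax(logits[:, :-1], dim=-1)                   # predictions for ids[:, 1:]
    token_lp = torch.gather(logprobs, 2, ids[:, 1:].unsqueeze(-1)).squeeze(-1)
    return (token_lp * mask[:, 1:]).sum(dim=1)                             # masked sum per sequence

# Tiny smoke test with random stand-in logits.
logits = torch.randn(2, 5, 11)
ids = torch.randint(0, 11, (2, 5))
mask = torch.tensor([[1., 1., 1., 1., 0.], [1., 1., 1., 0., 0.]])
print(masked_sequence_logprob(logits, ids, mask).shape)  # torch.Size([2])
```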
98 changes: 98 additions & 0 deletions lm_stable_baselines/training_algorithms/star_on_policy.py
@@ -0,0 +1,98 @@
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
import torch
from stable_baselines3.common.type_aliases import PyTorchObs
from lm_stable_baselines.environments import LanguageModelEnv
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.type_aliases import RolloutReturn, TrainFreq
from stable_baselines3.common.buffers import RolloutBuffer, DictRolloutBuffer
from gymnasium import spaces
from typing import Optional, Union, Dict, Any, List
import numpy as np
from lm_stable_baselines.utils import add_filler_tokens
from copy import deepcopy

class STaR(OnPolicyAlgorithm):

    def __init__(self, *args, loss_computed_in_forward_pass, batch_size, **kwargs):
        super().__init__(*args, **kwargs)
        assert all([isinstance(myenv, LanguageModelEnv) for myenv in self.env.envs]), "All environments must be of type LanguageModelEnv"
        all_filler_token = [myenv.filler_token for myenv in self.env.envs]
        assert all([filler_token == all_filler_token[0] for filler_token in all_filler_token]), "All environments must have the same filler token"
        self.policy.filler_token = all_filler_token[0]
        self.rollout_buffer.set_filler_token(all_filler_token[0])
        self.env.set_filler_token(all_filler_token[0])
        self.loss_computed_in_forward_pass = loss_computed_in_forward_pass
        self.policy.predict_values = self.predict_values
        self.batch_size = batch_size

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        rollout_buffer: RolloutBuffer,
        n_rollout_steps: int,
    ) -> RolloutReturn:

        og_padding_side = self.policy.tokenizer.padding_side
        self.policy.tokenizer.padding_side = "left"
        res = super().collect_rollouts(
            env,
            callback,
            rollout_buffer,
            n_rollout_steps,
        )
        self.policy.tokenizer.padding_side = og_padding_side
        return res

    # set the forward pass of the base policy
    @staticmethod
    def predict_values(obs: PyTorchObs) -> torch.Tensor:
        # return 0 for all values
        return torch.ones(obs.shape[0]) * 0

    def train(self) -> None:
        self.policy.train()

        self._update_learning_rate(self.policy.optimizer)
        nll_losses = []

        self.rollout_buffer.find_where_advantage_exceeds_threshold(self.rollout_buffer.advantages)
        n_batches = self.rollout_buffer.data_size // self.batch_size

        for _ in range(n_batches):

            self._n_updates += 1

            data = self.rollout_buffer.sample_batch(self.batch_size, env=self._vec_normalize_env)

            if self.loss_computed_in_forward_pass:
                labels = data.next_observations["input_ids"]
                labels_list = list(labels.cpu())
                collated_labels = self.data_collator(labels_list)
                labels = collated_labels["labels"]  # check with self.policy.tokenizer.decode(labels[0][labels[0]>0])
            else:
                labels = None

            output = self.policy(data.next_observations, labels=labels)

            if self.loss_computed_in_forward_pass:
                nll_loss = output.loss
            else:
                nll_loss = self.policy.compute_nll_loss(output.logits, data.next_observations)

            nll_losses.append(nll_loss.item())

            self.policy.optimizer.zero_grad()

            nll_loss.backward()

            self.policy.optimizer.step()

        self.logger.record("train/nll_loss", np.mean(nll_losses))
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")


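`train()` builds its labels with the data collator configured in `on_policy.yaml` (trl's `DataCollatorForCompletionOnlyLM`), which masks everything up to and including the response template so the NLL loss only covers the completion. A hand-rolled sketch of that labelling step; the token ids and template are made up for the example, and the real collator additionally handles padding and warns when the template is missing:

```python
import torch

def completion_only_labels(input_ids: torch.Tensor, template: list, ignore_index: int = -100) -> torch.Tensor:
    # Copy the input ids, then mask every position up to and including the response
    # template with ignore_index so the language-modelling loss skips the prompt.
    labels = input_ids.clone()
    for row in range(input_ids.size(0)):
        ids = input_ids[row].tolist()
        for start in range(len(ids) - len(template) + 1):
            if ids[start:start + len(template)] == template:
                labels[row, : start + len(template)] = ignore_index
                break
        else:
            labels[row, :] = ignore_index  # template not found: ignore the whole row
    return labels

# Template [50, 51] plays the role of the tokenized ANSWER_TEMPLATE.
batch = torch.tensor([[10, 11, 50, 51, 12, 13]])
print(completion_only_labels(batch, [50, 51]))  # tensor([[-100, -100, -100, -100, 12, 13]])
```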
5 changes: 3 additions & 2 deletions src/utils/instantiators.py
@@ -111,9 +111,10 @@ def instantiate_rl_algorithm(rl_cfg, lm, tokenizer, environment, logger=None):
    keys_to_delete = "environment", "reward", "policy", "buffer", "n_envs"
    for key in keys_to_delete:
        del cp[key]
    buffer_class_keyword = cp.pop("buffer_class_keyword")

    cp["replay_buffer_class"] = hydra.utils.get_class(cp["replay_buffer_class"])
    cp["replay_buffer_kwargs"]["tokenizer"] = tokenizer
    cp[buffer_class_keyword + "_class"] = hydra.utils.get_class(cp[buffer_class_keyword + "_class"])
    cp[buffer_class_keyword + "_kwargs"]["tokenizer"] = tokenizer
    cp["policy"] = hydra.utils.get_class(cp.pop("policy_class"))
    cp["policy_kwargs"]["lm"] = lm
    cp["policy_kwargs"]["tokenizer"] = tokenizer