-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'pl' of https://github.com/epfl-dlab/PauseToken into pl
- Loading branch information
Showing
12 changed files
with
326 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# @package _global_ | ||
|
||
defaults: | ||
- override /data: gsm8k | ||
- override /rl_algorithm: star_on_policy | ||
- override /rl_algorithm/policy/model/language_model: pause_from_pretrained | ||
- override /rl_algorithm/policy/model/peft_config: null #peft is already there | ||
- override /rl_algorithm/reward: gsm8k | ||
- override /metrics: gsm8k | ||
|
||
name: "star on gsm8k" | ||
run_name: "debug" | ||
task_name: "train" | ||
|
||
|
||
|
||
data: | ||
additional_transformation: | ||
_target_: functools.partial | ||
_args_: | ||
- ${get_method:src.utils.trainer_utils.inference_formatting_function} | ||
eos_token: ${get_obj_attr:${rl_algorithm.policy.model.tokenizer},[eos_token]} | ||
|
||
|
||
trainer: | ||
inner_loop_timesteps: 10 | ||
n_outer_loops: 5 | ||
progress_bar: false | ||
num_val_samples: 10 | ||
save_top_k: 3 | ||
metric_for_best_model: "val/accuracy" | ||
#whether the metric for the best model is min or max (True = min (the lower the better), False = max (the higher the better)) | ||
metric_for_best_model_mode_is_min: false | ||
|
||
|
||
|
||
rl_algorithm: | ||
n_envs: 4 | ||
loss_computed_in_forward_pass: true | ||
buffer: | ||
advantage_threshold: 0 | ||
|
||
policy: | ||
model: | ||
language_model: | ||
pretrained_model_name_or_path: /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-08-29_10-23-33/final | ||
post_instanciation_method_calls: | ||
- method: unfreeze_all | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
_target_: lm_stable_baselines.buffers.LMRolloutBuffer | ||
|
||
advantage_threshold: null | ||
filler_token: -100 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
defaults: | ||
- environment: language_model_env | ||
- reward: default | ||
- policy: llm_base_policy | ||
- buffer: lm_rollout_buffer | ||
|
||
|
||
data_collator: | ||
_target_: trl.DataCollatorForCompletionOnlyLM | ||
response_template: | ||
_target_: src.utils.hydra_custom_resolvers.get_module_attr | ||
module_and_attr: src.utils.constants.ANSWER_TEMPLATE | ||
|
||
_target_: stable_baselines3.common.on_policy_algorithm.OnPolicyAlgorithm | ||
n_envs: 1 | ||
## TODO: What to do here ? Should I just put this in another config file ? | ||
policy_class: ${.policy._target_} | ||
policy_kwargs: ${get_dict_except:${.policy},"_target_"} | ||
|
||
buffer_class_keyword: 'rollout_buffer' | ||
|
||
# env: | ||
learning_rate: 1e-5 | ||
n_steps: 10 | ||
gamma: 0.99 | ||
gae_lambda: 0.999 | ||
ent_coef: 0.5 | ||
vf_coef: 0.5 | ||
max_grad_norm: 0.5 | ||
use_sde: false | ||
sde_sample_freq: -1 | ||
rollout_buffer_class: ${.buffer._target_} # #fetch _target_ argument from buffer | ||
rollout_buffer_kwargs: ${get_dict_except:${.buffer},"_target_"} # #fetch all arguments from buffer except _target_ | ||
stats_window_size: ${.n_steps} | ||
tensorboard_log: null | ||
monitor_wrapper: true | ||
|
||
verbose: 0 | ||
seed: 42 | ||
device: "auto" | ||
_init_setup_model: true | ||
supported_action_spaces: null | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
defaults: | ||
- off_policy | ||
|
||
_target_: lm_stable_baselines.training_algorithms.STaR | ||
_target_: lm_stable_baselines.training_algorithms.star.STaR |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
defaults: | ||
- on_policy | ||
|
||
_target_: lm_stable_baselines.training_algorithms.star_on_policy.STaR | ||
batch_size: 3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from lm_stable_baselines.buffers.lm_replay_buffer import LMReplayBuffer | ||
from lm_stable_baselines.buffers.lm_replay_buffer import LMReplayBuffer | ||
from lm_stable_baselines.buffers.lm_rollout_buffer import LMRolloutBuffer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from stable_baselines3.common.buffers import RolloutBuffer | ||
import warnings | ||
from typing import Union, List, Dict, Any, Optional | ||
import transformers | ||
import numpy as np | ||
import torch | ||
from stable_baselines3.common.vec_env import VecNormalize | ||
from stable_baselines3.common.type_aliases import RolloutBufferSamples | ||
from lm_stable_baselines.utils import remove_filler_tokens | ||
|
||
def double_indexing(array: np.ndarray, idx1: np.ndarray, idx2: Optional[np.ndarray] = None) -> np.ndarray: | ||
if idx2 is None: | ||
return array[idx1] | ||
return array[idx1][idx2] | ||
|
||
class LMRolloutBuffer(RolloutBuffer): | ||
def __init__( | ||
self, | ||
*args, | ||
tokenizer: transformers.PreTrainedTokenizer = None, | ||
advantage_threshold: float = None, | ||
filler_token = -100, | ||
**kwargs | ||
): | ||
super().__init__(*args, **kwargs) | ||
self.set_filler_token(filler_token) | ||
self.tokenizer = tokenizer | ||
self.advantage_threshold = advantage_threshold | ||
self.above_threshold_indices = None | ||
self.data_size = 0 | ||
|
||
def set_tokenizer(self, tokenizer): | ||
self.tokenizer = tokenizer | ||
|
||
def reset(self) -> None: | ||
super().reset() | ||
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=np.long) | ||
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.long) | ||
self.above_threshold_indices = None | ||
self.data_size = 0 | ||
|
||
|
||
def set_filler_token(self, filler_token): | ||
self.filler_token = filler_token | ||
self.observations.fill(filler_token) | ||
self.actions.fill(filler_token) | ||
|
||
|
||
def to_torch(self, array: Union[np.ndarray, torch.Tensor, transformers.BatchEncoding], copy: bool = True) -> Union[torch.Tensor, transformers.BatchEncoding]: | ||
if isinstance(array, transformers.BatchEncoding): | ||
return {k: v.to(self.device) for k,v in array.items()} | ||
elif isinstance(array, torch.Tensor): | ||
return array.to(self.device) | ||
return super().to_torch(array, copy) | ||
|
||
|
||
def find_where_advantage_exceeds_threshold(self, advantage: np.ndarray) -> None: | ||
self.above_threshold_indices = np.where(advantage > self.advantage_threshold) | ||
self.data_size = len(self.above_threshold_indices[0]) | ||
|
||
|
||
def sample_batch(self, batch_size, env: Optional[VecNormalize] = None) -> RolloutBufferSamples: | ||
# Get the positions of the allowed indices (where the matrix is 1) | ||
allowed_indices = self.above_threshold_indices if self.above_threshold_indices is not None else np.arange(self.buffer_size) | ||
|
||
# Sample randomly from the allowed indices | ||
idx = np.random.choice(len(allowed_indices), size=batch_size, replace=True) | ||
sampled_positions = (allowed_indices[0][idx], allowed_indices[1][idx]) | ||
|
||
obs = self.observations[sampled_positions] | ||
|
||
obs = self.tokenizer( | ||
self.tokenizer.batch_decode( | ||
remove_filler_tokens(obs[..., 1:].long(), self.filler_token) # remove the first token (the bos token, tokenizer will re-add it) | ||
), | ||
return_tensors="pt", padding=True, truncation=True | ||
) | ||
|
||
actions = self.tokenizer( | ||
self.tokenizer.batch_decode( | ||
remove_filler_tokens(self.actions[sampled_positions], self.filler_token) # don't remove the first token (since it's an action, it didn't start with a bos token) | ||
), | ||
return_tensors="pt", padding=True, truncation=True | ||
)["input_ids"][..., 1:] # remove the first token (the bos token, actions should not have it) | ||
|
||
data = ( | ||
self.observations[sampled_positions], | ||
self.actions[sampled_positions], | ||
self.values[sampled_positions].flatten(), | ||
self.log_probs[sampled_positions].flatten(), | ||
self.advantages[sampled_positions].flatten(), | ||
self.returns[sampled_positions].flatten(), | ||
) | ||
|
||
return RolloutBufferSamples(*tuple(map(self.to_torch, data))) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm | ||
import torch | ||
from stable_baselines3.common.type_aliases import PyTorchObs | ||
from lm_stable_baselines.environments import LanguageModelEnv | ||
from stable_baselines3.common.callbacks import BaseCallback | ||
from stable_baselines3.common.vec_env import VecEnv | ||
from stable_baselines3.common.noise import ActionNoise | ||
from stable_baselines3.common.type_aliases import RolloutReturn, TrainFreq | ||
from stable_baselines3.common.buffers import RolloutBuffer, DictRolloutBuffer | ||
from gymnasium import spaces | ||
from typing import Optional, Union, Dict, Any, List | ||
import numpy as np | ||
from lm_stable_baselines.utils import add_filler_tokens | ||
from copy import deepcopy | ||
|
||
class STaR(OnPolicyAlgorithm): | ||
|
||
def __init__(self,*args, loss_computed_in_forward_pass, batch_size ,**kwargs): | ||
super().__init__(*args, **kwargs) | ||
assert all([isinstance(myenv, LanguageModelEnv) for myenv in self.env.envs]), "All environments must be of type LanguageModelEnv" | ||
all_filler_token = [myenv.filler_token for myenv in self.env.envs] | ||
assert all([filler_token == all_filler_token[0] for filler_token in all_filler_token]), "All environments must have the same filler token" | ||
self.policy.filler_token = all_filler_token[0] | ||
self.rollout_buffer.set_filler_token(all_filler_token[0]) | ||
self.env.set_filler_token(all_filler_token[0]) | ||
self.loss_computed_in_forward_pass = loss_computed_in_forward_pass | ||
self.policy.predict_values = self.predict_values | ||
self.batch_size = batch_size | ||
|
||
def collect_rollouts( | ||
self, | ||
env: VecEnv, | ||
callback: BaseCallback, | ||
rollout_buffer: RolloutBuffer, | ||
n_rollout_steps: int, | ||
) -> RolloutReturn: | ||
|
||
og_padding_side = self.policy.tokenizer.padding_side | ||
self.policy.tokenizer.padding_side = "left" | ||
res = super().collect_rollouts( | ||
env, | ||
callback, | ||
rollout_buffer, | ||
n_rollout_steps, | ||
) | ||
self.policy.tokenizer.padding_side = og_padding_side | ||
return res | ||
|
||
# set the forward pass of the base policy | ||
@staticmethod | ||
def predict_values(obs: PyTorchObs) -> torch.Tensor: | ||
# return -1 for all values | ||
return torch.ones(obs.shape[0]) * 0 | ||
|
||
|
||
def train(self) -> None: | ||
self.policy.train() | ||
|
||
self._update_learning_rate(self.policy.optimizer) | ||
nll_losses = [] | ||
|
||
self.rollout_buffer.find_where_advantage_exceeds_threshold(self.rollout_buffer.advantages) | ||
n_batches = self.rollout_buffer.data_size // self.batch_size | ||
|
||
for _ in range(n_batches): | ||
|
||
self._n_updates += 1 | ||
|
||
data = self.rollout_buffer.sample_batch(self.batch_size, env=self._vec_normalize_env) | ||
|
||
if self.loss_computed_in_forward_pass: | ||
labels = data.next_observations["input_ids"] | ||
labels_list = list(labels.cpu()) | ||
collated_labels = self.data_collator(labels_list) | ||
labels = collated_labels["labels"] # check with self.policy.tokenizer.decode(labels[0][labels[0]>0]) | ||
else: | ||
labels = None | ||
|
||
output = self.policy(data.next_observations, labels=labels) | ||
|
||
if self.loss_computed_in_forward_pass: | ||
nll_loss = output.loss | ||
else: | ||
nll_loss = self.policy.compute_nll_loss(output.logits, data.next_observations) | ||
|
||
nll_losses.append(nll_loss.item()) | ||
|
||
self.policy.optimizer.zero_grad() | ||
|
||
nll_loss.backward() | ||
|
||
self.policy.optimizer.step() | ||
|
||
|
||
self.logger.record("train/nll_loss", np.mean(nll_losses)) | ||
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters