Merge branch 'pl' of https://github.com/epfl-dlab/PauseToken into pl

nbaldwin98 committed Oct 23, 2024
2 parents 5511be6 + e30c14f commit 7a7bc8c
Showing 8 changed files with 184 additions and 95 deletions.
2 changes: 1 addition & 1 deletion configs/experiment/train/star_on_policy_pause.yaml
@@ -23,7 +23,7 @@ data:


trainer:
inner_loop_timesteps: 10
inner_loop_timesteps: 3
n_outer_loops: 5
progress_bar: false
num_val_samples: 10
66 changes: 35 additions & 31 deletions lm_stable_baselines/buffers/lm_rollout_buffer.py
@@ -22,6 +22,7 @@ def __init__(
filler_token = -100,
**kwargs
):
self.filler_token = filler_token
super().__init__(*args, **kwargs)
self.set_filler_token(filler_token)
self.tokenizer = tokenizer
@@ -34,8 +35,8 @@ def set_tokenizer(self, tokenizer):

def reset(self) -> None:
super().reset()
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=np.long)
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.long)
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=np.long) + self.filler_token
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.long) + self.filler_token
self.above_threshold_indices = None
self.data_size = 0

@@ -55,42 +56,45 @@ def to_torch(self, array: Union[np.ndarray, torch.Tensor, transformers.BatchEnco


def find_where_advantage_exceeds_threshold(self, advantage: np.ndarray) -> None:
if self.advantage_threshold is None:
self.advantage_threshold = - np.inf
self.above_threshold_indices = np.where(advantage > self.advantage_threshold)
self.remaining_indices = None
self.data_size = len(self.above_threshold_indices[0])



def sample_batch(self, batch_size, env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
# Get the positions of the allowed indices (where the matrix is 1)
allowed_indices = self.above_threshold_indices if self.above_threshold_indices is not None else np.arange(self.buffer_size)

# Sample randomly from the allowed indices
idx = np.random.choice(len(allowed_indices), size=batch_size, replace=True)
sampled_positions = (allowed_indices[0][idx], allowed_indices[1][idx])

obs = self.observations[sampled_positions]
# Initialize remaining indices if it's the first pass or if we've exhausted the dataset
if self.remaining_indices is None or len(self.remaining_indices[0]) == 0:
allowed_indices = self.above_threshold_indices if self.above_threshold_indices is not None else np.arange(self.buffer_size)
# Shuffle the allowed indices
shuffled_indices = np.random.permutation(np.arange(len(allowed_indices[0])))
# Store shuffled indices for further sampling
self.remaining_indices = (allowed_indices[0][shuffled_indices], allowed_indices[1][shuffled_indices])

obs = self.tokenizer(
self.tokenizer.batch_decode(
remove_filler_tokens(obs[..., 1:].long(), self.filler_token) # remove the first token (the bos token, tokenizer will re-add it)
),
return_tensors="pt", padding=True, truncation=True
)
# Sample from the remaining indices without replacement
num_remaining = len(self.remaining_indices[0])
num_to_sample = min(batch_size, num_remaining)

actions = self.tokenizer(
self.tokenizer.batch_decode(
remove_filler_tokens(self.actions[sampled_positions], self.filler_token) # don't remove the first token (since it's an action, it didn't start with a bos token)
),
return_tensors="pt", padding=True, truncation=True
)["input_ids"][..., 1:] # remove the first token (the bos token, actions should not have it)
idx = np.arange(num_remaining)[:num_to_sample]
sampled_positions = (self.remaining_indices[0][idx], self.remaining_indices[1][idx])

# Remove the sampled positions from remaining indices
self.remaining_indices = (
np.delete(self.remaining_indices[0], idx),
np.delete(self.remaining_indices[1], idx)
)

return self.sample_indices(sampled_positions)

def sample_indices(self, idx, padding='right') -> RolloutBufferSamples:
assert idx[0].shape == idx[1].shape, "The indices must have the same shape"
data = (
self.observations[sampled_positions],
self.actions[sampled_positions],
self.values[sampled_positions].flatten(),
self.log_probs[sampled_positions].flatten(),
self.advantages[sampled_positions].flatten(),
self.returns[sampled_positions].flatten(),
self.observations[idx],
self.actions[idx],
self.values[idx].flatten(),
self.log_probs[idx].flatten(),
self.advantages[idx].flatten(),
self.returns[idx].flatten(),
)

return RolloutBufferSamples(*tuple(map(self.to_torch, data)))
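
Note: the rewritten sample_batch above switches from sampling with replacement to epoch-style sampling, where the above-threshold positions are shuffled once and handed out in batches without replacement until the pool is exhausted. A minimal, self-contained sketch of that pattern (illustrative only; make_batches is a hypothetical name, not part of the commit):

import numpy as np

def make_batches(rows, cols, batch_size, rng=None):
    """Yield (rows, cols) index batches without replacement, reshuffling only when exhausted."""
    rng = rng or np.random.default_rng()
    order = rng.permutation(len(rows))
    remaining = (rows[order], cols[order])
    while len(remaining[0]) > 0:
        take = min(batch_size, len(remaining[0]))
        yield remaining[0][:take], remaining[1][:take]
        remaining = (remaining[0][take:], remaining[1][take:])

# Example: positions (buffer index, env index) where the advantage exceeded the threshold
rows = np.array([0, 2, 5, 7, 9])
cols = np.array([0, 0, 1, 0, 1])
for batch_rows, batch_cols in make_batches(rows, cols, batch_size=2):
    print(batch_rows, batch_cols)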

31 changes: 28 additions & 3 deletions lm_stable_baselines/environments/language_model_env.py
@@ -7,6 +7,8 @@
import numpy as np
from lm_stable_baselines.utils import remove_filler_tokens
import warnings
import torch
from torch import LongTensor
class LanguageModelEnv(Env):
""" Environment for language models. This class is a subclass of gym.Env and is used to handle language model environments.
This environment allows sampling from a dataset and computing rewards based on the model output and the ground truth.
@@ -75,12 +77,22 @@ def reprermute_dataset_id_list(cls):
# TODO: check if this is necessary
# NICKY:
# I don't think we need this. We want dataset_id_list to be a static variable that is shared across all instances of the class
#   We know which sample to take thanks to LanguageModelEnv.next_idx
# We know which sample to take thanks to LanguageModelEnv.next_idx
# if LanguageModelEnv.n_envs != -1:
# self.dataset_id_list = LanguageModelEnv.dataset_id_list[self.env_idx::LanguageModelEnv.n_envs]
# else:
# self.dataset_id_list = LanguageModelEnv.dataset_id_list

def _step(self, curr_obs, action):
if isinstance(curr_obs, list):
curr_obs.extend(action)
elif isinstance(curr_obs, torch.Tensor):
curr_obs = torch.cat([curr_obs, action], dim = 0)
else:
raise ValueError("curr_obs should be a list or a tensor")
return curr_obs


def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
""" Apply an action to the environment. For a language model it's simply adding the action to the current state
@@ -91,11 +103,21 @@ def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict[
"""

clean_action = remove_filler_tokens(action, self.filler_token).squeeze(-1).tolist()
self.current_state.extend(clean_action)
self.current_state = self._step(self.current_state, clean_action)
observation , reward, terminated, truncated, info = self._get_obs()

return observation, reward, terminated, truncated, info

def next_observation_from_observation_and_action(self, obs: LongTensor, actions: LongTensor) -> List[List[int]]:
#assumption: filler tokens have been removed
unpadded_obs = remove_filler_tokens(obs, self.filler_token)
unpadded_acts = remove_filler_tokens(actions, self.filler_token)

new_observations = [self._step(observation,action) for observation, action in zip(unpadded_obs,unpadded_acts)]

return new_observations


def is_terminated(self, state: List[int]):
""" Check if the state is terminated
@@ -216,6 +238,8 @@ def _get_obs(self):
info = {}
return np.array(self.current_state) , reward, is_terminated, is_truncated, info



def render(self):
""" Render the current state
@@ -230,4 +254,5 @@ def close(self):
This is critical for closing rendering windows, database or HTTP connections.
Calling ``close`` on an already closed environment has no effect and won't raise an error.
"""
pass
pass
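
Note: the new _step helper above appends an action to the current observation whether the state is kept as a Python list of token ids or as a 1-D tensor. A standalone sketch of the same dispatch (illustrative only; step_state is a hypothetical name, not part of the commit):

import torch

def step_state(curr_obs, action):
    if isinstance(curr_obs, list):
        curr_obs.extend(action)  # lists are extended in place
    elif isinstance(curr_obs, torch.Tensor):
        curr_obs = torch.cat([curr_obs, action], dim=0)  # tensors get a new concatenated tensor
    else:
        raise ValueError("curr_obs should be a list or a tensor")
    return curr_obs

print(step_state([1, 2, 3], [4, 5]))                              # [1, 2, 3, 4, 5]
print(step_state(torch.tensor([1, 2, 3]), torch.tensor([4, 5])))  # tensor([1, 2, 3, 4, 5])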

@@ -49,3 +49,5 @@ def set_stage(self, stage: str, **kwargs):
"""
for env in self.envs:
env.set_stage(stage, **kwargs)


23 changes: 12 additions & 11 deletions lm_stable_baselines/policies/llm_base_policy.py
@@ -186,29 +186,30 @@ def forward(self, obs: PyTorchObs, labels = None) -> torch.Tensor:
# feature = self.extract_features(obs)
# feature = {k: v.to(self.device) for k, v in feature.items()}
# feature["labels"] = labels.to(self.device) if labels is not None else None
actions = self._predict(obs)
padded_actions = actions.clone()
padded_actions[actions == self.filler_token] = self.tokenizer.pad_token_id
logprobs = torch.log_softmax(self.lm(padded_actions).logits, dim = -1)
mask = (actions != self.filler_token).float()
logprob_actions = torch.gather(logprobs, 2, padded_actions.unsqueeze(-1)).squeeze(-1)

next_obs, actions, unpadded_actions = self._predict(obs).values()

logprobs = torch.log_softmax(self.lm(next_obs).logits, dim = -1)[:, (-unpadded_actions.shape[1]-1):-1, ...]
mask = (unpadded_actions != self.tokenizer.pad_token_id).float()
logprob_actions = torch.gather(logprobs, 2, unpadded_actions.unsqueeze(-1)).squeeze(-1)
logprobs = (logprob_actions * mask).sum(dim = 1)
values = self.predict_values(obs)
values = self.predict_values(next_obs)
return actions, values, logprobs

def predict_values(self, obs: PyTorchObs) -> torch.Tensor:
raise NotImplementedError

def post_predict(self, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
#remove the input tokens from the output
actions = outputs[:, inputs.shape[-1]:]
actions = outputs[:, inputs.shape[-1]:].clone()
padded_actions = actions.clone()
#replace all pad tokens with filler tokens
actions[actions == self.tokenizer.pad_token_id] = self.filler_token
padded_actions[actions == self.tokenizer.pad_token_id] = self.filler_token

action_space_dim = self.action_space.shape[0]
actions = add_filler_tokens(actions, action_space_dim, self.filler_token)
padded_actions = add_filler_tokens(padded_actions, action_space_dim, self.filler_token)

return actions
return {'next_observation':outputs, 'actions': padded_actions, 'unpadded_actions': actions}

def pre_predict(self, feature: PyTorchObs) -> PyTorchObs:
pass
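
Note: the reworked forward above scores the generated action by slicing the logits that predict the action tokens, gathering each token's log-probability, and summing under a padding mask. A minimal sketch of that computation with made-up shapes and a made-up pad id (illustrative only, not the policy's actual code):

import torch

batch, act_len, vocab = 2, 4, 10
pad_token_id = 0

logits = torch.randn(batch, act_len, vocab)             # logits at the positions predicting the action tokens
actions = torch.tensor([[3, 5, 2, pad_token_id],
                        [7, 1, pad_token_id, pad_token_id]])

logprobs = torch.log_softmax(logits, dim=-1)             # (batch, act_len, vocab)
token_logprobs = torch.gather(logprobs, 2, actions.unsqueeze(-1)).squeeze(-1)  # (batch, act_len)
mask = (actions != pad_token_id).float()                 # ignore padded action positions
sequence_logprob = (token_logprobs * mask).sum(dim=1)    # (batch,)
print(sequence_logprob.shape)                            # torch.Size([2])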
26 changes: 20 additions & 6 deletions lm_stable_baselines/training_algorithms/star_on_policy.py
@@ -52,6 +52,15 @@ def predict_values(obs: PyTorchObs) -> torch.Tensor:
# return 0 for all values
return torch.ones(obs.shape[0]) * 0

def process_rollouts(self, data):
next_obs = self.env.envs[0].next_observation_from_observation_and_action(data.observations[:,1:], data.actions)
#create the next observation by interacting with the environment and then tokenizing to get input_ids + attention mask
next_observation = self.policy.tokenizer.pad(
{'input_ids': next_obs},
return_tensors="pt",
padding=True,
)
return next_observation

def train(self) -> None:
self.policy.train()
@@ -61,27 +70,32 @@ def train(self) -> None:

self.rollout_buffer.find_where_advantage_exceeds_threshold(self.rollout_buffer.advantages)
n_batches = self.rollout_buffer.data_size // self.batch_size


self.policy.tokenizer.padding_side = "right"
for _ in range(n_batches):

self._n_updates += 1

data = self.rollout_buffer.sample_batch(self.batch_size, env=self._vec_normalize_env)

next_observation = self.process_rollouts(data)
if self.loss_computed_in_forward_pass:
labels = data.next_observations["input_ids"]
labels = next_observation["input_ids"]
labels_list = list(labels.cpu())
collated_labels = self.data_collator(labels_list)
labels = collated_labels["labels"] # check with self.policy.tokenizer.decode(labels[0][labels[0]>0])
else:
labels = None

output = self.policy(data.next_observations, labels=labels)
output = self.policy.lm(input_ids=next_observation['input_ids'].to(self.device),
attention_mask=next_observation['attention_mask'].to(self.device),
labels=labels.to(self.device) if labels is not None else None)

if self.loss_computed_in_forward_pass:
nll_loss = output.loss
#if control token model you can also get these losses:
#control_token_loss = output.ctrl_tok_loss
#lm_loss = output.lm_loss
else:
nll_loss = self.policy.compute_nll_loss(output.logits, data.next_observations)
nll_loss = self.policy.compute_nll_loss(output.logits, labels)

nll_losses.append(nll_loss.item())

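
Note: the new process_rollouts relies on the tokenizer's pad() to turn variable-length lists of token ids into a batch with input_ids and an attention_mask. A minimal sketch of that call (illustrative only; the "gpt2" checkpoint and the example ids are assumptions, not taken from the repository):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token   # gpt2 ships without a pad token
tokenizer.padding_side = "right"            # matches the padding side set in train()

next_obs = [[464, 2068, 7586], [464, 2068]]  # unequal lengths
batch = tokenizer.pad({"input_ids": next_obs}, return_tensors="pt", padding=True)
print(batch["input_ids"].shape, batch["attention_mask"].shape)  # both torch.Size([2, 3])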