diff --git a/configs/experiment/train/star_on_policy_pause.yaml b/configs/experiment/train/star_on_policy_pause.yaml
index 7aba413..2e3549e 100644
--- a/configs/experiment/train/star_on_policy_pause.yaml
+++ b/configs/experiment/train/star_on_policy_pause.yaml
@@ -23,7 +23,7 @@ data:
 trainer:
-  inner_loop_timesteps: 10
+  inner_loop_timesteps: 3
   n_outer_loops: 5
   progress_bar: false
   num_val_samples: 10
diff --git a/lm_stable_baselines/buffers/lm_rollout_buffer.py b/lm_stable_baselines/buffers/lm_rollout_buffer.py
index bf3a7db..dc20847 100644
--- a/lm_stable_baselines/buffers/lm_rollout_buffer.py
+++ b/lm_stable_baselines/buffers/lm_rollout_buffer.py
@@ -22,6 +22,7 @@ def __init__(
         filler_token = -100,
         **kwargs
     ):
+        self.filler_token = filler_token
         super().__init__(*args, **kwargs)
         self.set_filler_token(filler_token)
         self.tokenizer = tokenizer
@@ -34,8 +35,8 @@ def set_tokenizer(self, tokenizer):

     def reset(self) -> None:
         super().reset()
-        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=np.long)
-        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.long)
+        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=np.long) + self.filler_token
+        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.long) + self.filler_token
         self.above_threshold_indices = None
         self.data_size = 0
@@ -55,42 +56,45 @@ def to_torch(self, array: Union[np.ndarray, torch.Tensor, transformers.BatchEnco

     def find_where_advantage_exceeds_threshold(self, advantage: np.ndarray) -> None:
+        if self.advantage_threshold is None:
+            self.advantage_threshold = - np.inf
         self.above_threshold_indices = np.where(advantage > self.advantage_threshold)
+        self.remaining_indices = None
         self.data_size = len(self.above_threshold_indices[0])
-
-
+
     def sample_batch(self, batch_size, env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
-        # Get the positions of the allowed indices (where the matrix is 1)
-        allowed_indices = self.above_threshold_indices if self.above_threshold_indices is not None else np.arange(self.buffer_size)
-
-        # Sample randomly from the allowed indices
-        idx = np.random.choice(len(allowed_indices), size=batch_size, replace=True)
-        sampled_positions = (allowed_indices[0][idx], allowed_indices[1][idx])
-
-        obs = self.observations[sampled_positions]
+        # Initialize remaining indices if it's the first pass or if we've exhausted the dataset
+        if self.remaining_indices is None or len(self.remaining_indices[0]) == 0:
+            allowed_indices = self.above_threshold_indices if self.above_threshold_indices is not None else np.arange(self.buffer_size)
+            # Shuffle the allowed indices
+            shuffled_indices = np.random.permutation(np.arange(len(allowed_indices[0])))
+            # Store shuffled indices for further sampling
+            self.remaining_indices = (allowed_indices[0][shuffled_indices], allowed_indices[1][shuffled_indices])

-        obs = self.tokenizer(
-            self.tokenizer.batch_decode(
-                remove_filler_tokens(obs[..., 1:].long(), self.filler_token) # remove the first token (the bos token, tokenizer will re-add it)
-            ),
-            return_tensors="pt", padding=True, truncation=True
-        )
+        # Sample from the remaining indices without replacement
+        num_remaining = len(self.remaining_indices[0])
+        num_to_sample = min(batch_size, num_remaining)

-        actions = self.tokenizer(
-            self.tokenizer.batch_decode(
-                remove_filler_tokens(self.actions[sampled_positions], self.filler_token) # don't remove the first token (since it's an action, it didn't start with a bos token)
-            ),
-            return_tensors="pt", padding=True, truncation=True
-        )["input_ids"][..., 1:] # remove the first token (the bos token, actions should not have it)
+        idx = np.arange(num_remaining)[:num_to_sample]
+        sampled_positions = (self.remaining_indices[0][idx], self.remaining_indices[1][idx])
+        # Remove the sampled positions from remaining indices
+        self.remaining_indices = (
+            np.delete(self.remaining_indices[0], idx),
+            np.delete(self.remaining_indices[1], idx)
+        )
+
+        return self.sample_indices(sampled_positions)
+
+    def sample_indices(self, idx, padding='right') -> RolloutBufferSamples:
+        assert idx[0].shape == idx[1].shape, "The indices must have the same shape"
         data = (
-            self.observations[sampled_positions],
-            self.actions[sampled_positions],
-            self.values[sampled_positions].flatten(),
-            self.log_probs[sampled_positions].flatten(),
-            self.advantages[sampled_positions].flatten(),
-            self.returns[sampled_positions].flatten(),
+            self.observations[idx],
+            self.actions[idx],
+            self.values[idx].flatten(),
+            self.log_probs[idx].flatten(),
+            self.advantages[idx].flatten(),
+            self.returns[idx].flatten(),
         )
         return RolloutBufferSamples(*tuple(map(self.to_torch, data)))
-
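For context on the buffer change above: `sample_batch` now draws above-threshold positions without replacement and reshuffles once the pool is exhausted. Below is a minimal, self-contained sketch of that bookkeeping with toy numpy data; the array values and names are illustrative only, not part of the patch.

```python
import numpy as np

# Toy stand-in for the (buffer_size, n_envs) advantage matrix.
advantages = np.array([[0.5, -1.0],
                       [2.0,  0.3],
                       [-0.2, 1.5]])
advantage_threshold = 0.0

# Mirrors find_where_advantage_exceeds_threshold: keep (step, env) pairs above threshold.
above = np.where(advantages > advantage_threshold)
remaining = None

for batch_size in (2, 2, 2):  # the third draw triggers a reshuffle of the pool
    if remaining is None or len(remaining[0]) == 0:
        order = np.random.permutation(len(above[0]))
        remaining = (above[0][order], above[1][order])
    take = min(batch_size, len(remaining[0]))
    picked = (remaining[0][:take], remaining[1][:take])
    remaining = (remaining[0][take:], remaining[1][take:])
    print(picked)  # each above-threshold (step, env) pair appears once per pass
```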
diff --git a/lm_stable_baselines/environments/language_model_env.py b/lm_stable_baselines/environments/language_model_env.py
index 3a732ce..8935071 100644
--- a/lm_stable_baselines/environments/language_model_env.py
+++ b/lm_stable_baselines/environments/language_model_env.py
@@ -7,6 +7,8 @@ import numpy as np
 from lm_stable_baselines.utils import remove_filler_tokens
 import warnings
+import torch
+from torch import LongTensor

 class LanguageModelEnv(Env):
     """ Environment for language models. This class is a subclass of gym.Env and is used to handle language model environments. This environment allows to sample from a dataset and compute rewards based on the model output and the ground truth.
@@ -75,12 +77,22 @@ def reprermute_dataset_id_list(cls):
         # TODO: check if this is necessary
         # NICKY:
         # I don't think we nee this. We want dataset_id_list to be a static variable that is shared across all instances of the class
-        #   We know which sample to take thanks to LanguageModelEnv.next_idx
+        # We know which sample to take thanks to LanguageModelEnv.next_idx
         # if LanguageModelEnv.n_envs != -1:
         #     self.dataset_id_list = LanguageModelEnv.dataset_id_list[self.env_idx::LanguageModelEnv.n_envs]
         # else:
         #     self.dataset_id_list = LanguageModelEnv.dataset_id_list

+    def _step(self, curr_obs, action):
+        if isinstance(curr_obs, list):
+            curr_obs.extend(action)
+        elif isinstance(curr_obs, torch.Tensor):
+            curr_obs = torch.cat([curr_obs, action], dim = 0)
+        else:
+            raise ValueError("curr_obs should be a list or a tensor")
+        return curr_obs
+
+
     def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
         """ Apply an action to the environment.
         For a language model it's simply adding the action to the current state
@@ -91,11 +103,21 @@ def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict[
         """
         clean_action = remove_filler_tokens(action, self.filler_token).squeeze(-1).tolist()
-        self.current_state.extend(clean_action)
+        self.current_state = self._step(self.current_state, clean_action)
         observation , reward, terminated, truncated, info = self._get_obs()
         return observation, reward, terminated, truncated, info

+
+    def next_observation_from_observation_and_action(self, obs: LongTensor, actions: LongTensor) -> List[List[int]]:
+        #assumption: filler tokens have been removed
+        unpadded_obs = remove_filler_tokens(obs, self.filler_token)
+        unpadded_acts = remove_filler_tokens(actions, self.filler_token)
+
+        new_observations = [self._step(observation,action) for observation, action in zip(unpadded_obs,unpadded_acts)]
+        return new_observations
+
+
     def is_terminated(self, state: List[int]):
         """
         Check if the state is terminated
@@ -216,6 +238,8 @@ def _get_obs(self):
         info = {}
         return np.array(self.current_state) , reward, is_terminated, is_truncated, info

+
+
     def render(self):
         """
         Render the current state
@@ -230,4 +254,5 @@ def close(self):
         This is critical for closing rendering windows, database or HTTP connections.
         Calling ``close`` on an already closed environment has no effect and won't raise an error.
         """
-        pass
\ No newline at end of file
+        pass
+
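The new `_step` helper simply appends an action to the current observation, dispatching on whether the observation is a Python list or a tensor. A standalone toy illustration of that behavior (mirroring the patch code rather than importing the class):

```python
import torch

def _step(curr_obs, action):
    # List observations grow in place; tensor observations are concatenated.
    if isinstance(curr_obs, list):
        curr_obs.extend(action)
    elif isinstance(curr_obs, torch.Tensor):
        curr_obs = torch.cat([curr_obs, action], dim=0)
    else:
        raise ValueError("curr_obs should be a list or a tensor")
    return curr_obs

print(_step([1, 2, 3], [4, 5]))                           # [1, 2, 3, 4, 5]
print(_step(torch.tensor([1, 2, 3]), torch.tensor([4])))  # tensor([1, 2, 3, 4])
```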
diff --git a/lm_stable_baselines/environments/vectorized_environments/lm_dummy_vec_enc.py b/lm_stable_baselines/environments/vectorized_environments/lm_dummy_vec_enc.py
index ea61de8..1f73f1e 100644
--- a/lm_stable_baselines/environments/vectorized_environments/lm_dummy_vec_enc.py
+++ b/lm_stable_baselines/environments/vectorized_environments/lm_dummy_vec_enc.py
@@ -49,3 +49,5 @@ def set_stage(self, stage: str, **kwargs):
         """
         for env in self.envs:
             env.set_stage(stage, **kwargs)
+
+
diff --git a/lm_stable_baselines/policies/llm_base_policy.py b/lm_stable_baselines/policies/llm_base_policy.py
index b92cc5d..b917902 100644
--- a/lm_stable_baselines/policies/llm_base_policy.py
+++ b/lm_stable_baselines/policies/llm_base_policy.py
@@ -186,14 +186,14 @@ def forward(self, obs: PyTorchObs, labels = None) -> torch.Tensor:
         # feature = self.extract_features(obs)
         # feature = {k: v.to(self.device) for k, v in feature.items()}
         # feature["labels"] = labels.to(self.device) if labels is not None else None
-        actions = self._predict(obs)
-        padded_actions = actions.clone()
-        padded_actions[actions == self.filler_token] = self.tokenizer.pad_token_id
-        logprobs = torch.log_softmax(self.lm(padded_actions).logits, dim = -1)
-        mask = (actions != self.filler_token).float()
-        logprob_actions = torch.gather(logprobs, 2, padded_actions.unsqueeze(-1)).squeeze(-1)
+
+        next_obs, actions, unpadded_actions = self._predict(obs).values()
+
+        logprobs = torch.log_softmax(self.lm(next_obs).logits, dim = -1)[:, (-unpadded_actions.shape[1]-1):-1, ...]
+        mask = (unpadded_actions != self.tokenizer.pad_token_id).float()
+        logprob_actions = torch.gather(logprobs, 2, unpadded_actions.unsqueeze(-1)).squeeze(-1)
         logprobs = (logprob_actions * mask).sum(dim = 1)
-        values = self.predict_values(obs)
+        values = self.predict_values(next_obs)
         return actions, values, logprobs

     def predict_values(self, obs: PyTorchObs) -> torch.Tensor:
@@ -201,14 +201,15 @@ def predict_values(self, obs: PyTorchObs) -> torch.Tensor:

     def post_predict(self, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
         #remove the input tokens from the output
-        actions = outputs[:, inputs.shape[-1]:]
+        actions = outputs[:, inputs.shape[-1]:].clone()
+        padded_actions = actions.clone()
         #replace all pad tokens with filler tokens
-        actions[actions == self.tokenizer.pad_token_id] = self.filler_token
+        padded_actions[actions == self.tokenizer.pad_token_id] = self.filler_token
         action_space_dim = self.action_space.shape[0]
-        actions = add_filler_tokens(actions, action_space_dim, self.filler_token)
+        padded_actions = add_filler_tokens(padded_actions, action_space_dim, self.filler_token)

-        return actions
+        return {'next_observation':outputs, 'actions': padded_actions, 'unpadded_actions': actions}

     def pre_predict(self, feature: PyTorchObs) -> PyTorchObs:
         pass
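On the `forward` change above: for a causal LM, the logits at position t score the token at position t+1, so the last A action tokens of `next_obs` are scored by `logits[:, -(A+1):-1, :]`. A shape-only sketch of that slicing and gather with random tensors; no real model or tokenizer is involved, and the sizes are arbitrary.

```python
import torch

batch, seq_len, vocab = 2, 7, 11
A = 3  # number of action tokens appended at the end of next_obs

next_obs = torch.randint(0, vocab, (batch, seq_len))
actions = next_obs[:, -A:]                   # the generated continuation
logits = torch.randn(batch, seq_len, vocab)  # stand-in for self.lm(next_obs).logits

# Logits at position t predict token t+1, so drop the last position and keep
# the A positions that precede each action token.
action_logits = torch.log_softmax(logits, dim=-1)[:, -(A + 1):-1, :]
logprob_actions = torch.gather(action_logits, 2, actions.unsqueeze(-1)).squeeze(-1)
print(logprob_actions.shape)  # torch.Size([2, 3]) -> one log-prob per action token
```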
diff --git a/lm_stable_baselines/training_algorithms/star_on_policy.py b/lm_stable_baselines/training_algorithms/star_on_policy.py
index 14f51ee..55b3ee8 100644
--- a/lm_stable_baselines/training_algorithms/star_on_policy.py
+++ b/lm_stable_baselines/training_algorithms/star_on_policy.py
@@ -52,6 +52,15 @@ def predict_values(obs: PyTorchObs) -> torch.Tensor:
         # return -1 for all values
         return torch.ones(obs.shape[0]) * 0

+    def process_rollouts(self, data):
+        next_obs = self.env.envs[0].next_observation_from_observation_and_action(data.observations[:,1:], data.actions)
+        #create the next observation by interacting with the environment and then tokenizing to get input_ids + attention mask
+        next_observation = self.policy.tokenizer.pad(
+            {'input_ids': next_obs},
+            return_tensors="pt",
+            padding=True,
+        )
+        return next_observation

     def train(self) -> None:
         self.policy.train()
@@ -61,27 +70,32 @@ def train(self) -> None:
         self.rollout_buffer.find_where_advantage_exceeds_threshold(self.rollout_buffer.advantages)

         n_batches = self.rollout_buffer.data_size // self.batch_size
-
+
+        self.policy.tokenizer.padding_side = "right"
         for _ in range(n_batches):
             self._n_updates += 1
-
             data = self.rollout_buffer.sample_batch(self.batch_size, env=self._vec_normalize_env)
-
+            next_observation = self.process_rollouts(data)
             if self.loss_computed_in_forward_pass:
-                labels = data.next_observations["input_ids"]
+                labels = next_observation["input_ids"]
                 labels_list = list(labels.cpu())
                 collated_labels = self.data_collator(labels_list)
                 labels = collated_labels["labels"] # check with self.policy.tokenizer.decode(labels[0][labels[0]>0])
             else:
                 labels = None
-            output = self.policy(data.next_observations, labels=labels)
+            output = self.policy.lm(input_ids=next_observation['input_ids'].to(self.device),
+                                    attention_mask=next_observation['attention_mask'].to(self.device),
+                                    labels=labels.to(self.device))
             if self.loss_computed_in_forward_pass:
                 nll_loss = output.loss
+                #if control token model you can also get these losses:
+                #control_token_loss = output.ctrl_tok_loss
+                #lm_loss = output.lm_loss
             else:
-                nll_loss = self.policy.compute_nll_loss(output.logits, data.next_observations)
+                nll_loss = self.policy.compute_nll_loss(output.logits, labels)
             nll_losses.append(nll_loss.item())
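The new `process_rollouts` re-pads the variable-length next observations into a batch with `tokenizer.pad`, using the right padding side that `train` now sets before batching. A rough sketch of that step in isolation; the `gpt2` checkpoint and the token ids are arbitrary example choices, not anything the patch prescribes.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any causal-LM tokenizer works here
tokenizer.pad_token = tokenizer.eos_token          # gpt2 has no pad token by default
tokenizer.padding_side = "right"                   # train() sets this before batching

# Variable-length "next observations" (observation tokens followed by action tokens).
next_obs = [[464, 3290, 318], [464, 2415, 3332, 319, 262]]

batch = tokenizer.pad({"input_ids": next_obs}, padding=True, return_tensors="pt")
print(batch["input_ids"].shape)    # torch.Size([2, 5])
print(batch["attention_mask"][0])  # tensor([1, 1, 1, 0, 0]) -> right padding
```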
diff --git a/src/trainer/lmsb_trainer.py b/src/trainer/lmsb_trainer.py
index 42b2745..72df58e 100644
--- a/src/trainer/lmsb_trainer.py
+++ b/src/trainer/lmsb_trainer.py
@@ -2,6 +2,7 @@ from stable_baselines3.common.base_class import BaseAlgorithm
 from lm_stable_baselines.buffers import LMReplayBuffer
 import warnings
+from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.type_aliases import TrainFreq, TrainFrequencyUnit
 from copy import deepcopy
 import numpy as np
@@ -83,44 +84,101 @@ def evaluation(self, stage: str):
         ################# PART 1: set arguments necessary for performing rollout #################
         n_steps = int(math.ceil(self.num_val_samples/self.rl_algorithm.n_envs))
-        kwargs = deepcopy(self.rl_algorithm.replay_buffer_kwargs)
-        kwargs["reward_threshold"] = None
-        buffer_size = n_steps * self.rl_algorithm.n_envs + 1 # avoids weird behavior of sb3 when buffer is full (make it easier for sampling, see below)
-        validation_replay_buffer = LMReplayBuffer(
+        buffer_name = self.rl_algorithm.buffer_class_keyword
+        buffer_name_kwargs = buffer_name + "_kwargs"
+        buffer_class = getattr(self.rl_algorithm, buffer_name).__class__
+        kwargs = deepcopy(getattr(self.rl_algorithm, buffer_name_kwargs))
+        kwargs["advantage_threshold"] = None
+
+        buffer_size = n_steps + 1 # avoids weird behavior of sb3 when buffer is full (make it easier for sampling, see below)
+
+        validation_buffer = buffer_class(
             buffer_size,
             self.rl_algorithm.observation_space,
             self.rl_algorithm.action_space,
             device = self.rl_algorithm.device,
-            n_envs = 1,
-            optimize_memory_usage = False,
+            n_envs = self.rl_algorithm.n_envs,
             **kwargs
         )
-        train_freq = TrainFreq(frequency= n_steps, unit = TrainFrequencyUnit.STEP)

         ################# PART 2: perform rollout #################
         #TODO: For the moment, this is fine because 1 step = 1 sample, but in the future, we need to change this for the correct number of samples
-        rollout = self.rl_algorithm.collect_rollouts(
-            self.rl_algorithm.env,
-            train_freq= train_freq,
-            action_noise=self.rl_algorithm.action_noise,
-            learning_starts=0,
-            replay_buffer=validation_replay_buffer,
-            log_interval=self.learn_kwargs["log_interval"],
-            callback=self.learn_kwargs["callback"]
-        )
-
-        #safety check
-        if rollout.n_episodes < self.num_val_samples:
-            raise ValueError(
-                f"Expected {self.num_val_samples} samples, but got {rollout.n_episodes} samples, this may be due to the environment not being terminated"
+        if isinstance(self.rl_algorithm, OffPolicyAlgorithm):
+            train_freq = TrainFreq(frequency= n_steps, unit = TrainFrequencyUnit.STEP)
+
+            rollout = self.rl_algorithm.collect_rollouts(
+                self.rl_algorithm.env,
+                train_freq= train_freq,
+                action_noise=self.rl_algorithm.action_noise,
+                learning_starts=0,
+                replay_buffer=validation_buffer,
+                log_interval=self.learn_kwargs["log_interval"],
+                callback=self.learn_kwargs["callback"]
             )
+
+            #safety check
+            if rollout.n_episodes < self.num_val_samples:
+                raise ValueError(
+                    f"Expected {self.num_val_samples} samples, but got {rollout.n_episodes} samples, this may be due to the environment not being terminated"
+                )
+
+            ################# PART 3: Collect rollouts from Replay Buffer #################
+            #Not sure why sb3 distinguishes these cases but it's necessary to do so to read the samples correctly in the proper order
+            # I could have done just one case (since I set optimize_memory_usage to True above), but I wanted to show the logic of sb3
+            samps_order = np.arange(self.num_val_samples) - 1 if validation_buffer.optimize_memory_usage else np.arange(self.num_val_samples)
+            val_samps = validation_buffer._get_samples(samps_order)
+
+
+
+            texts = decode_and_strip_pad_tokens(
+                val_samps.next_observations["input_ids"],
+                self.rl_algorithm.policy.tokenizer.pad_token_id,
+                self.rl_algorithm.policy.tokenizer
+            )
+
+            input_texts = decode_and_strip_pad_tokens(
+                val_samps.observations["input_ids"],
+                self.rl_algorithm.policy.tokenizer.pad_token_id,
+                self.rl_algorithm.policy.tokenizer
+            )
+
+            predicted_outputs = decode_and_strip_pad_tokens(
+                val_samps.actions,
+                self.rl_algorithm.policy.tokenizer.pad_token_id,
+                self.rl_algorithm.policy.tokenizer
+            )
+
+        else:
+            rollout = self.rl_algorithm.collect_rollouts(
+                self.rl_algorithm.env,
+                callback=self.learn_kwargs["callback"],
+                rollout_buffer=validation_buffer,
+                n_rollout_steps=n_steps+1
+            )
+            samps_ids = np.where(np.ones((self.rl_algorithm.n_envs, n_steps)) == 1)
+            val_samps = validation_buffer._get_samples(samps_ids)
+
+            next_obs = self.rl_algorithm.process_rollouts(val_samps)
+
+            texts = decode_and_strip_pad_tokens(
+                next_obs["input_ids"],
+                self.rl_algorithm.policy.tokenizer.pad_token_id,
+                self.rl_algorithm.policy.tokenizer
+            )
+
+            input_texts = decode_and_strip_pad_tokens(
+                val_samps.observations["input_ids"],
+                self.rl_algorithm.policy.tokenizer.pad_token_id,
+                self.rl_algorithm.policy.tokenizer
+            )
+
+            predicted_outputs = decode_and_strip_pad_tokens(
+                val_samps.actions,
+                self.rl_algorithm.policy.tokenizer.pad_token_id,
+                self.rl_algorithm.policy.tokenizer
+            )
+

-        ################# PART 3: Collect rollouts from Replay Buffer #################
-        #Not sure why sb3 distiguishes theses cases but it's necessary to do so to read the samples correctly in the proper order
-        # I could have done just one case (since i set optimize_memory_usage to True above), but I wanted to show the logic of sb3
-        samps_order = np.arange(self.num_val_samples) - 1 if validation_replay_buffer.optimize_memory_usage else np.arange(self.num_val_samples)
-        val_samps = validation_replay_buffer._get_samples(samps_order)
-
         gts = self.rl_algorithm.env.envs[0].get_ground_truths(
             stage=stage,
             idxs = list(range(self.num_val_samples))
@@ -128,23 +186,7 @@ def evaluation(self, stage: str):
         reses = []

-        texts = decode_and_strip_pad_tokens(
-            val_samps.next_observations["input_ids"],
-            self.rl_algorithm.policy.tokenizer.pad_token_id,
-            self.rl_algorithm.policy.tokenizer
-        )

-        input_texts = decode_and_strip_pad_tokens(
-            val_samps.observations["input_ids"],
-            self.rl_algorithm.policy.tokenizer.pad_token_id,
-            self.rl_algorithm.policy.tokenizer
-        )
-
-        predicted_outputs = decode_and_strip_pad_tokens(
-            val_samps.actions,
-            self.rl_algorithm.policy.tokenizer.pad_token_id,
-            self.rl_algorithm.policy.tokenizer
-        )

         ################# PART 4: Compute metrics #################
diff --git a/src/utils/instantiators.py b/src/utils/instantiators.py
index 75348ec..a5739a4 100644
--- a/src/utils/instantiators.py
+++ b/src/utils/instantiators.py
@@ -138,5 +138,6 @@ def instantiate_rl_algorithm(rl_cfg, lm, tokenizer, environment, logger=None):
     if logger is not None:
         rl_alg.set_logger(logger)
     rl_alg.data_collator = data_collator
+    rl_alg.buffer_class_keyword = buffer_class_keyword
     return rl_alg
\ No newline at end of file
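Finally, the `buffer_class_keyword` attribute set in `instantiate_rl_algorithm` is what lets `evaluation` build a validation buffer of the right class for either an on-policy or an off-policy algorithm. A minimal sketch of that indirection under toy assumptions; `DummyAlgo` and `DummyBuffer` are invented stand-ins, not classes from the repository.

```python
from copy import deepcopy

class DummyBuffer:
    def __init__(self, size, **kwargs):
        self.size, self.kwargs = size, kwargs

class DummyAlgo:
    # instantiate_rl_algorithm records which buffer attribute the algorithm uses,
    # e.g. "rollout_buffer" for on-policy or "replay_buffer" for off-policy.
    buffer_class_keyword = "rollout_buffer"
    rollout_buffer = DummyBuffer(8, advantage_threshold=0.5)
    rollout_buffer_kwargs = {"advantage_threshold": 0.5}

algo = DummyAlgo()
buffer_name = algo.buffer_class_keyword
kwargs = deepcopy(getattr(algo, buffer_name + "_kwargs"))
kwargs["advantage_threshold"] = None              # validation keeps every sample
buffer_class = getattr(algo, buffer_name).__class__
validation_buffer = buffer_class(4, **kwargs)
print(buffer_class.__name__, validation_buffer.kwargs)  # DummyBuffer {'advantage_threshold': None}
```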