From 8f7e364ffa292ec3e31ee6580742f1f3b68404d3 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 5 Mar 2025 11:15:27 -0800 Subject: [PATCH 1/6] Update [ghstack-poisoned] --- tutorials/sphinx-tutorials/llm.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 tutorials/sphinx-tutorials/llm.py diff --git a/tutorials/sphinx-tutorials/llm.py b/tutorials/sphinx-tutorials/llm.py new file mode 100644 index 00000000000..6467e7f6687 --- /dev/null +++ b/tutorials/sphinx-tutorials/llm.py @@ -0,0 +1,17 @@ +""" +Interacting with LLMs in TorchRL +================================ + +**Author**: `Vincent Moens `_ + +.. _gs_first_training: + +.. note:: To run this tutorial in a notebook, add an installation cell + at the beginning containing: + + .. code-block:: + + !pip install tensordict + !pip install torchrl + +""" From b403b60678958125a80f81bb17accf08150c0b4d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 14 Mar 2025 15:34:43 +0000 Subject: [PATCH 2/6] Update [ghstack-poisoned] --- sota-implementations/post-training/grpo.py | 65 ++++--- .../post-training/grpo_utils.py | 171 ++++++++++++++---- torchrl/envs/common.py | 1 - torchrl/envs/custom/llm.py | 6 +- torchrl/envs/transforms/llm.py | 21 ++- torchrl/envs/transforms/transforms.py | 18 +- torchrl/modules/llm/transformers_policy.py | 1 - torchrl/modules/llm/vllm_policy.py | 30 ++- torchrl/objectives/ppo.py | 37 ++-- 9 files changed, 248 insertions(+), 102 deletions(-) diff --git a/sota-implementations/post-training/grpo.py b/sota-implementations/post-training/grpo.py index 7248cec9ab6..0dba78dd5c7 100644 --- a/sota-implementations/post-training/grpo.py +++ b/sota-implementations/post-training/grpo.py @@ -6,16 +6,21 @@ import torch from datasets import load_dataset +from grpo_utils import ( + HF2vLLMLocalWeightUpdater, + PrepareQuestion, + ShapedCorrectnessReward, +) from tensordict import TensorDict +from torch.utils._pytree import tree_map from torch.utils.data import DataLoader from torchrl.collectors import SyncDataCollector from torchrl.data import LazyStackStorage, ReplayBuffer, SamplerWithoutReplacement -from torchrl.envs import DataLoadingPrimer, KLRewardTransform, LLMEnv, StepCounter, Tokenizer -from torchrl.modules import from_hf_transformers -from torchrl.objectives import ClipPPOLoss, ReinforceLoss -from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel -from grpo_utils import ShapedCorrectnessReward, PrepareQuestion -from torch.utils._pytree import tree_map +from torchrl.envs import DataLoadingPrimer, KLRewardTransform, LLMEnv, StepCounter +from torchrl.modules import from_hf_transformers, from_vllm +from torchrl.objectives import ClipPPOLoss +from transformers import AutoTokenizer, GPT2LMHeadModel +from vllm import LLM parser = ArgumentParser() parser.add_argument("--dataset", type=str, default="gsm8k") @@ -25,23 +30,23 @@ parser.add_argument("--steps_per_batch", type=int, default=16) parser.add_argument("--optim_batch_size", type=int, default=4) + def compute_mc_advantage(trajectories): # Get the question answer = trajectories["answer"] # Identify indices where the answers match answer_ids = tree_map(lambda string: hash(string), answer) answer_ids = torch.tensor(answer_ids) - print("answer_ids", answer_ids) unique_qs = answer_ids.view(-1).unique() trajectories["advantage"] = trajectories["next", "reward"] * 0 for u in unique_qs: - idx = answer_ids == u + idx = answer_ids == u rewards = trajectories[idx]["next", "reward"] rewards = (rewards - rewards.mean()) / rewards.std().clamp(min=1e-4) 
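+        # group-relative (GRPO-style) advantage: the normalized rewards double as
+        # the advantage for each group of completions sharing a question; std is
+        # clamped so a group of identical rewards does not divide by zero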
- print("rewards", rewards) trajectories.set_at_("advantage", rewards, idx) return trajectories + if __name__ == "__main__": args = parser.parse_args() # Create env instance: @@ -56,13 +61,13 @@ def collate_fn(batch): # LLM tokenizer = AutoTokenizer.from_pretrained("gpt2") - model = GPT2LMHeadModel(GPT2Config()) - + # inference_model = GPT2LMHeadModel(GPT2Config()) + inference_model = LLM("gpt2") tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "left" # Env - dataloader = DataLoader( + dataloader = DataLoader( # noqa: TOR401 train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn ) env = LLMEnv.from_dataloader( @@ -80,11 +85,8 @@ def collate_fn(batch): # Finally, we want the env to stop after the first step env.append_transform(StepCounter(max_steps=1)) - print("env", env) - print(env.reset()) - - policy = from_hf_transformers( - model, + policy = from_vllm( + inference_model, tokenizer=tokenizer, from_text=False, generate=True, @@ -95,7 +97,8 @@ def collate_fn(batch): env.append_transform(ShapedCorrectnessReward(tokenizer=tokenizer)) # Ref model - ref_model = GPT2LMHeadModel(GPT2Config()) + ref_model = GPT2LMHeadModel.from_pretrained("gpt2") + TensorDict.from_module(ref_model).data.to_module(ref_model) ref_model = from_hf_transformers( ref_model, tokenizer=tokenizer, @@ -103,19 +106,32 @@ def collate_fn(batch): generate=False, return_log_probs=True, ) - env.append_transform(KLRewardTransform(actor=ref_model, coef=0.1, log_prob_key="log_probs")) + env.append_transform( + KLRewardTransform(actor=ref_model, coef=0.1, log_prob_key="log_probs") + ) # replay buffer - rb = ReplayBuffer(storage=LazyStackStorage(args.steps_per_batch), sampler=SamplerWithoutReplacement(), batch_size=args.optim_batch_size) + rb = ReplayBuffer( + storage=LazyStackStorage(args.steps_per_batch), + sampler=SamplerWithoutReplacement(), + batch_size=args.optim_batch_size, + ) # Collector + train_model = GPT2LMHeadModel.from_pretrained("gpt2") collector = SyncDataCollector( - env, policy, frames_per_batch=args.steps_per_batch, total_frames=1_000_000, + env, + policy, + frames_per_batch=args.steps_per_batch, + total_frames=1_000_000, + local_weights_updater=HF2vLLMLocalWeightUpdater( + hf_model=train_model, vllm_model=inference_model + ), ) # Loss module policy_traning = from_hf_transformers( - model, + train_model, tokenizer=tokenizer, from_text=False, generate=False, @@ -139,14 +155,13 @@ def collate_fn(batch): for trajs in collector: trajs = trajs.reshape(-1) - print('trajs from collector', trajs) trajs = compute_mc_advantage(trajs) rb.extend(trajs) - for i in range(args.epochs): + for _ in range(args.epochs): for batch in rb: - print('running loss with batch', batch) loss = loss_fn(batch) loss_val = loss.mean(reduce=True) loss_val.backward() optim.step() optim.zero_grad() + collector.update_policy_weights_() diff --git a/sota-implementations/post-training/grpo_utils.py b/sota-implementations/post-training/grpo_utils.py index 8063658d34c..8f1b29ac1cb 100644 --- a/sota-implementations/post-training/grpo_utils.py +++ b/sota-implementations/post-training/grpo_utils.py @@ -5,12 +5,79 @@ from __future__ import annotations import torch +from tensordict import NestedKey, TensorDict, TensorDictBase from tensordict.tensorclass import NonTensorData, NonTensorStack -from torchrl.envs import Transform -from torchrl.data import Composite, TensorSpec, Unbounded from tensordict.utils import _zip_strict -from tensordict import TensorDictBase, TensorDict -from tensordict import NestedKey +from 
torch import nn
+
+from torchrl.collectors import LocalWeightUpdaterBase
+from torchrl.data import Composite, TensorSpec, Unbounded
+from torchrl.envs import Transform
+
+
+class HF2vLLMLocalWeightUpdater(LocalWeightUpdaterBase):
+    hf_params: TensorDictBase | None = None
+    vllm_params: TensorDictBase | None = None
+
+    def __init__(self, hf_model: nn.Module, vllm_model: vllm.LLM):  # noqa
+        self.vllm_model = vllm_model
+        self.hf_model = hf_model
+
+    def _get_server_weights(self) -> TensorDictBase:
+        # Get weights from the hf model
+        if self.hf_params is None:
+            self.hf_params = TensorDict.from_module(self.hf_model).data.lock_()
+        return self.hf_params
+
+    def _get_local_weights(self) -> TensorDictBase:
+        if self.vllm_params is None:
+            self.vllm_params = TensorDict.from_module(
+                self.vllm_model.llm_engine.model_executor.driver_worker.model_runner.inference_model
+            ).data.lock_()
+        return self.vllm_params
+
+    def _maybe_map_weights(
+        self, server_weights: TensorDictBase, local_weights: TensorDictBase
+    ) -> TensorDictBase:
+        return self.format_hf_weights(td=server_weights, to_vllm=True)
+
+    @classmethod
+    def format_hf_weights(
+        cls,
+        model: nn.Module | None = None,
+        td: TensorDictBase | None = None,
+        *,
+        to_vllm=False,
+        update_model=False,
+    ):
+        if td is None:
+            if model is None:
+                raise TypeError("A model or a tensordict is required.")
+            td = TensorDict.from_module(model)
+        if to_vllm and update_model:
+            raise TypeError(
+                "Cannot update model with to_vllm=True as the weight format has changed."
+            )
+        for k in list(td.keys(True)):
+            if k[-1] == "q_proj":
+                keys = [k, k[:-1] + ("k_proj",), k[:-1] + ("v_proj",)]
+                qkv = torch.stack([td[_k].data for _k in keys])
+                if not to_vllm:
+                    splits = qkv.chunk(3)
+                    for _k, split in zip(keys, splits):
+                        td.pop(_k)
+                        td.set(_k, split.apply(lambda x: torch.nn.Parameter(x)))
+                else:
+                    qkv = qkv.apply(lambda x: torch.nn.Parameter(x))
+                    td.set(k[:-1] + ("qkv_proj",), qkv)
+        if update_model:
+            if model is None:
+                raise TypeError("The model must be provided to be updated.")
+            td.to_module(model)
+        return td
+
+
 BASE_PROMPT = (
     "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. "
     "The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
     "i.e., <think> reasoning process here </think> <answer> answer here </answer>. User: %s. 
Assistant: " ) + class PrepareQuestion(Transform): - def __init__(self, in_keys: list[NestedKey] | None = None, out_keys: list[NestedKey] | None = None): + def __init__( + self, + in_keys: list[NestedKey] | None = None, + out_keys: list[NestedKey] | None = None, + ): if in_keys is None: in_keys = ["text"] if out_keys is None: @@ -32,95 +104,124 @@ def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict.set(out_key, self._modify_str(string)) return tensordict - def _modify_str(self, obs: str | list[str] | NonTensorData | NonTensorStack) -> NonTensorData | NonTensorStack: + def _modify_str( + self, obs: str | list[str] | NonTensorData | NonTensorStack + ) -> NonTensorData | NonTensorStack: if isinstance(obs, NonTensorData): return self._modify_str(obs.data) if isinstance(obs, NonTensorStack): return self._modify_str(obs.tolist()) if isinstance(obs, list): - return NonTensorStack( - *[BASE_PROMPT % obs for obs in obs] - ) + return NonTensorStack(*[BASE_PROMPT % obs for obs in obs]) return NonTensorData(BASE_PROMPT % obs) def _apply_transform(self, obs: torch.Tensor) -> None: return obs + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): if out_key != in_key: observation_spec[out_key] = observation_spec[in_key].clone() return observation_spec + class ShapedCorrectnessReward(Transform): - def __init__(self, tokenizer, in_keys: list[NestedKey] | None=None, out_keys: list[NestedKey] | None = None): + def __init__( + self, + tokenizer, + in_keys: list[NestedKey] | None = None, + out_keys: list[NestedKey] | None = None, + ): super().__init__() self.tokenizer = tokenizer if in_keys is None: in_keys = ["text", "answer"] if not isinstance(in_keys, list) or len(in_keys) != 2: - raise ValueError("ShapedCorrectnessReward requires in_keys to be of type list and have 2 elements.") + raise ValueError( + "ShapedCorrectnessReward requires in_keys to be of type list and have 2 elements." 
+ ) if out_keys is None: - out_keys = ["reward_answer", "reward_think", "reward_right", "reward_contained", "reward", "success"] + out_keys = [ + "reward_answer", + "reward_think", + "reward_right", + "reward_contained", + "reward", + "success", + ] super().__init__(in_keys, out_keys) def _step( self, tensordict: TensorDictBase, next_tensordict: TensorDictBase ) -> TensorDictBase: from xml.etree import ElementTree as ET + # Get the completion responses = next_tensordict[self.in_keys[0]] # batch_size, grpo_size, L answers = next_tensordict[self.in_keys[1]] # batch_size, grpo_size if isinstance(responses, torch.Tensor): - if responses.ndim == 3: + if responses.ndim == 3: batch_size, grpo_size, _ = responses.shape # decode - text_completion = self.tokenizer.decode( - responses.flatten(0, 1).tolist() - ) + text_completion = self.tokenizer.decode(responses.flatten(0, 1).tolist()) else: text_completion = responses # Decomposed reward tds = [] for answer, compl in zip(answers, text_completion): try: - cot, potential_answer = self.extract_tags("" + compl) #.replace("<<", "").replace(">>", "")) + cot, potential_answer = self.extract_tags( + "" + compl + ) # .replace("<<", "").replace(">>", "")) except ET.ParseError: cot, potential_answer = ("", "") - tds.append(self.single_shaped_correctness_reward(potential_answer, cot)) + tds.append( + self.single_shaped_correctness_reward(answer, potential_answer, cot) + ) tds = torch.stack(tds) - if isinstance(responses, torch.Tensor) and responses.ndim == 3: + if isinstance(responses, torch.Tensor) and responses.ndim == 3: tds = tds.reshape(batch_size, grpo_size) tds = tds.apply(lambda t: t.unsqueeze(-1)) return next_tensordict.update(tds) def transform_reward_spec(self, reward_spec: Composite) -> Composite: shape = reward_spec.shape + (1,) - reward_spec.update(Composite( - reward_answer=Unbounded(shape), - reward_think=Unbounded(shape), - reward_right=Unbounded(shape), - reward_contained=Unbounded(shape), - reward=Unbounded(shape), - success=Unbounded(shape, dtype=torch.bool), - )) + reward_spec.update( + Composite( + reward_answer=Unbounded(shape), + reward_think=Unbounded(shape), + reward_right=Unbounded(shape), + reward_contained=Unbounded(shape), + reward=Unbounded(shape), + success=Unbounded(shape, dtype=torch.bool), + ) + ) return reward_spec @classmethod - def single_shaped_correctness_reward(cls, answer: str, cot: str) -> TensorDict: + def single_shaped_correctness_reward( + cls, true_answer: str, potential_answer: list[str], cot: list[str] + ) -> TensorDict: - reward_answer = 5.0 * (len(answer) == 1) + reward_answer = 5.0 * (len(potential_answer) == 1) reward_think = 5.0 * (len(cot) == 1) # One of the answer tags has the right answer - reward_right = 20.0 * (any(attempt == answer for attempt in answer)) + reward_right = 20.0 * ( + any(attempt == true_answer for attempt in potential_answer) + ) # One of the answer tags contains the right answer (might be e.g. 
$20 instead of 20) - reward_contained = 10.0 * (any((answer in attempt) for attempt in answer)) + reward_contained = 10.0 * ( + any((true_answer in attempt) for attempt in potential_answer) + ) - success = len(answer) > 0 and answer[-1] == answer + success = len(potential_answer) > 0 and potential_answer[-1] == true_answer # Compose the rewards - reward = 100.0 * float(success) + (reward_answer + reward_think + reward_contained + reward_right) * (1- float(success)) + reward = 100.0 * float(success) + ( + reward_answer + reward_think + reward_contained + reward_right + ) * (1 - float(success)) rewards = TensorDict( reward_answer=reward_answer, @@ -133,7 +234,7 @@ def single_shaped_correctness_reward(cls, answer: str, cot: str) -> TensorDict: return rewards @staticmethod - def extract_tags(text: str) -> Tuple[str, str]: + def extract_tags(text: str) -> tuple[str, str]: """ Parse XML-like tags from text. Returns a dictionary with keys 'think' and 'answer'. The values are lists of strings, with each string being the content of a tag. @@ -143,7 +244,7 @@ def extract_tags(text: str) -> Tuple[str, str]: xml_string = f"{text}" try: root = ET.fromstring(xml_string) - except ET.ParseError as e: + except ET.ParseError: return ("", "") return ( diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index dba7411ca8c..cdfbe5c19e3 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -3383,7 +3383,6 @@ def _rollout_stop_early( else: tensordict.clear_device_() # In case policy(..) does not modify in-place - no-op for TensorDict and related - print('policy input', tensordict) tensordict.update(policy(tensordict)) if auto_cast_to_device: if env_device is not None: diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py index 54ffca0278f..8cb2d3cde84 100644 --- a/torchrl/envs/custom/llm.py +++ b/torchrl/envs/custom/llm.py @@ -18,6 +18,8 @@ from tensordict.tensorclass import NonTensorData, NonTensorStack from tensordict.utils import _zip_strict from torch.utils.data import DataLoader + +from torchrl._utils import _replace_last from torchrl.data.map.hash import SipHash from torchrl.data.tensor_specs import ( Bounded, @@ -236,7 +238,7 @@ def from_dataloader( cls, dataloader: DataLoader, *, - tokenizer: Tokenizer | None = None, + tokenizer: transformers.PretrainedTokenizerBase | None = None, # noqa token_key: NestedKey | None = None, str_key: NestedKey | None = None, attention_key: NestedKey | None = None, @@ -347,6 +349,7 @@ def from_dataloader( if action_key is None: action_key = cls._DEFAULT_ACTION_STR_KEY tokenizer_transform = Tokenizer( + tokenizer=tokenizer, in_keys=[str_key], out_keys=[token_key], # Assume that the tokens are named according to _DEFAULT_ACTION_TOKENS_KEY @@ -358,6 +361,7 @@ def from_dataloader( ) else: tokenizer_transform = Tokenizer( + tokenizer=tokenizer, in_keys=[str_key], out_keys=[token_key], call_before_reset=True, diff --git a/torchrl/envs/transforms/llm.py b/torchrl/envs/transforms/llm.py index d004ae47fbc..2470c7440b0 100644 --- a/torchrl/envs/transforms/llm.py +++ b/torchrl/envs/transforms/llm.py @@ -6,7 +6,7 @@ from collections import deque from collections.abc import Mapping -from copy import copy, deepcopy +from copy import copy from typing import Any, Callable, Iterable, Literal import torch @@ -17,13 +17,15 @@ TensorDictBase, unravel_key, ) -from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams, ProbabilisticTensorDictSequential +from tensordict.nn import ( + ProbabilisticTensorDictModule, + 
ProbabilisticTensorDictSequential, +) from tensordict.utils import _zip_strict, is_seq_of_nested_key -from torch import nn from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform -from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param +from torchrl.envs.transforms.utils import _set_missing_tolerance from torchrl.envs.utils import make_composite_from_td @@ -564,7 +566,7 @@ def __init__( out_keys=None, requires_grad=False, log_prob_key: NestedKey = "sample_log_prob", - action_key: NestedKey = "action", + action_key: NestedKey = "action", ): if in_keys is None: in_keys = self.DEFAULT_IN_KEYS @@ -650,12 +652,17 @@ def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase: next_tensordict.set(self.out_keys[0], self.parent.reward_spec.zero()) return next_tensordict # with self.frozen_params.to_module(self.functional_actor): - if isinstance(self.functional_actor, (ProbabilisticTensorDictModule, ProbablisticTensorDictSequential)): + if isinstance( + self.functional_actor, + (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential), + ): dist = self.functional_actor.get_dist(next_tensordict.copy()) # get the log_prob given the original model log_prob = dist.log_prob(action) else: - log_prob = self.functional_actor(next_tensordict.copy()).get(self.sample_log_prob_key) + log_prob = self.functional_actor(next_tensordict.copy()).get( + self.sample_log_prob_key + ) reward_key = self.in_keys[0] reward = next_tensordict.get("next").get(reward_key) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 14231871635..776a5601106 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -5385,7 +5385,6 @@ def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase: out_key, observation, ) - print('next tensordict', next_tensordict) elif not self.missing_tolerance or out_key not in next_tensordict.keys( True ): @@ -5434,7 +5433,6 @@ def call_tokenizer_fn(self, value: str | list[str]): if self.return_attention_mask: attention_mask = attention_mask.to(device) if self.return_attention_mask: - print('out, attention_mask', out.shape, attention_mask.shape) return out, attention_mask return out @@ -5459,7 +5457,9 @@ def transform_input_spec(self, input_spec: Composite) -> Composite: elif in_key in input_spec["full_action_spec"].keys(False, True): spec = input_spec["full_action_spec"] else: - raise KeyError(f"The input keys {in_key} wasn't found in the env input specs.") + raise KeyError( + f"The input keys {in_key} wasn't found in the env input specs." 
+ ) local_spec = spec.pop(in_key) new_shape = spec.shape if self.max_length is None: @@ -5501,12 +5501,12 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec ) attention_mask_keys.add(attention_mask_key) observation_spec[attention_mask_key] = Bounded( - 0, - 2, - shape=new_shape, - device=observation_spec[in_key].device, - dtype=observation_spec[in_key].dtype, - ) + 0, + 2, + shape=new_shape, + device=observation_spec[in_key].device, + dtype=observation_spec[in_key].dtype, + ) return observation_spec diff --git a/torchrl/modules/llm/transformers_policy.py b/torchrl/modules/llm/transformers_policy.py index 05896f89a2c..e977fc18be6 100644 --- a/torchrl/modules/llm/transformers_policy.py +++ b/torchrl/modules/llm/transformers_policy.py @@ -91,7 +91,6 @@ def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase: - "forward" """ tokens_out = td["tokens_response", "input_ids"] - print('tokens_out', tokens_out.shape) seq_len = tokens_out.shape[-1] del td["forward", "past_key_values"] diff --git a/torchrl/modules/llm/vllm_policy.py b/torchrl/modules/llm/vllm_policy.py index daab91c76d0..06b3d099856 100644 --- a/torchrl/modules/llm/vllm_policy.py +++ b/torchrl/modules/llm/vllm_policy.py @@ -10,6 +10,8 @@ import torch from tensordict import ( from_dataclass, + lazy_stack, + LazyStackedTensorDict, maybe_dense_stack, NestedKey, NonTensorData, @@ -151,7 +153,7 @@ def from_vllm( out_keys=["tokens_in"], method_kwargs=tokenizer_kwargs, strict=True, - inplace=False, + inplace="empty", ) else: module_dict["encode"] = Mod( @@ -164,7 +166,7 @@ def from_vllm( in_keys=[text_key, "text_response"], out_keys=["tokens_in", "tokens_response"], strict=True, - inplace=False, + inplace="empty", ) def select(x, y): @@ -196,7 +198,7 @@ def stack_for_logprobs(tokens, tokens_response, attention_mask=None): ("tokens_in", "attention_mask"), ], strict=False, - inplace=False, + inplace="empty", ) else: module_dict["move_inputs"] = Mod( @@ -205,7 +207,7 @@ def stack_for_logprobs(tokens, tokens_response, attention_mask=None): out_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")], # It's ok if there's no mask strict=False, - inplace=False, + inplace="empty", ) def to_list(tokens, attention_mask): @@ -261,13 +263,23 @@ def to_list(tokens, attention_mask): strict=True, ) - def get_output_tokens_and_log_probs(td): + padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0] + + def get_output_tokens_and_log_probs(td, padding_value=padding_value): td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"]) + if td.ndim and not isinstance(td, LazyStackedTensorDict): + td = lazy_stack(list(td.unbind(0))) if generate: # When not generate, we don't want to overwrite this - td["tokens_response"] = td["tokens_out"].outputs.token_ids + tokens_response_td = td["tokens_out"].outputs._tensordict.select( + "token_ids", "logprobs", strict=False + ) + tokens_response_td.rename_key_("token_ids", "tokens_response") + # td["tokens_response"] = outputs.token_ids if return_log_probs: - td["log_probs"] = td["tokens_out"].outputs.logprobs.unsqueeze(-1) + tokens_response_td.rename_key_("logprobs", "log_probs") + # td["log_probs"] = outputs.logprobs.unsqueeze(-1) + td.update(tokens_response_td) elif not generate: td["prompt_logprobs"] = td["tokens_out"].prompt_logprobs.unsqueeze(-1) return td @@ -313,7 +325,7 @@ def translate_lps(tokens_response, x): "text_response", ], strict=False, - inplace=False, + inplace="empty", ) else: module_dict["format"] = Mod( @@ -321,7 +333,7 @@ def 
translate_lps(tokens_response, x): in_keys=["log_probs", "tokens_response"], out_keys=["log_probs", "tokens_response"], strict=False, - inplace=False, + inplace="empty", ) return Seq(module_dict, inplace=True) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 69d3349a6b6..22191e049da 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -17,13 +17,14 @@ TensorDictParams, ) from tensordict.nn import ( - TensorDictModuleBase, composite_lp_aggregate, + composite_lp_aggregate, CompositeDistribution, dispatch, ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, set_composite_lp_aggregate, TensorDictModule, + TensorDictModuleBase, ) from tensordict.utils import NestedKey from torch import distributions as d @@ -349,11 +350,15 @@ def __init__( if critic is not None: critic_network = critic del critic - if actor_network is None or (critic_network is None and critic_coef not in (None, 0.0)): + if actor_network is None or ( + critic_network is None and critic_coef not in (None, 0.0) + ): raise TypeError( "Missing positional arguments actor_network or critic_network." ) - critic_coef = 1.0 if critic_coef is None and critic_network is not None else critic_coef + critic_coef = ( + 1.0 if critic_coef is None and critic_network is not None else critic_coef + ) if reduction is None: reduction = "mean" @@ -523,7 +528,10 @@ def _get_entropy( def _get_cur_log_prob(self, tensordict): - if isinstance(self.actor_network, (ProbabilisticTensorDictSequential, ProbabilisticTensorDictModule)): + if isinstance( + self.actor_network, + (ProbabilisticTensorDictSequential, ProbabilisticTensorDictModule), + ): with self.actor_network_params.to_module( self.actor_network ) if self.functional else contextlib.nullcontext(): @@ -591,8 +599,6 @@ def _log_weight( if is_tensor_collection(log_prob): log_prob = _sum_td_features(log_prob) log_prob.view_as(prev_log_prob) - print('log_prob', log_prob.shape) - print('prev_log_prob', prev_log_prob.shape) log_weight = (log_prob - prev_log_prob).unsqueeze(-1) if is_tensor_collection(log_weight): log_weight = _sum_td_features(log_weight) @@ -1377,16 +1383,19 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: def reset(self) -> None: self.beta = self._beta_init + class GRPO(ClipPPOLoss): - def __init__(self, - actor_network: TensorDictModuleBase, - # Default value of LLMData - log_prob_key="log_probs", + """TODO""" + def __init__( + self, + actor_network: TensorDictModuleBase, + # Default value of LLMData + log_prob_key="log_probs", ): super().__init__( - actor_network=actor_network, - critic_network=None, - critic_coef=0.0, - functional=False, + actor_network=actor_network, + critic_network=None, + critic_coef=0.0, + functional=False, ) self.set_keys(log_prob_key=log_prob_key) From 1c53e09adc42433faf030b84a537e56204e558ec Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 18 Mar 2025 11:19:03 -0700 Subject: [PATCH 3/6] Update [ghstack-poisoned] --- test/mocking_classes.py | 52 ++ test/test_actors.py | 69 ++- test/test_cost.py | 71 ++- test/test_env.py | 187 +++--- torchrl/envs/custom/llm.py | 48 +- torchrl/envs/transforms/llm.py | 120 ++-- torchrl/envs/transforms/transforms.py | 94 ++- torchrl/modules/distributions/discrete.py | 1 + torchrl/modules/llm/__init__.py | 3 +- torchrl/modules/llm/common.py | 48 ++ torchrl/modules/llm/transformers_policy.py | 615 ++++++++++++------- torchrl/modules/llm/vllm_policy.py | 664 +++++++++++++++------ torchrl/objectives/ppo.py | 92 ++- 13 files changed, 1445 insertions(+), 619 
deletions(-) create mode 100644 torchrl/modules/llm/common.py diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 5f03e773591..7defb919f21 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -2459,3 +2459,55 @@ def _step( self.parent.device, ) return next_tensordict + + +class DummyStrDataLoader: + def __init__(self, batch_size=0): + self.batch_size = batch_size + + def generate_random_string(self, length=10): + """Generate a random string of a given length.""" + return "".join(random.choice(string.ascii_lowercase) for _ in range(length)) + + def __iter__(self): + return self + + def __next__(self): + if self.batch_size == 0: + return self.generate_random_string() + else: + return [self.generate_random_string() for _ in range(self.batch_size)] + + +class DummyTensorDataLoader: + def __init__(self, batch_size=0, max_length=10, padding=False): + self.batch_size = batch_size + self.max_length = max_length + self.padding = padding + + def generate_random_tensor(self): + """Generate a tensor of random int64 values.""" + length = random.randint(1, self.max_length) + return torch.tensor( + [random.randint(0, 100) for _ in range(length)], dtype=torch.int64 + ) + + def pad_tensor(self, tensor): + """Pad a tensor to the maximum length.""" + padding_length = self.max_length - len(tensor) + return torch.cat((torch.zeros(padding_length, dtype=torch.int64), tensor)) + + def __iter__(self): + return self + + def __next__(self): + if self.batch_size == 0: + tensor = self.generate_random_tensor() + return self.pad_tensor(tensor) if self.padding else tensor + else: + tensors = [self.generate_random_tensor() for _ in range(self.batch_size)] + if self.padding: + tensors = [self.pad_tensor(tensor) for tensor in tensors] + return torch.stack(tensors) + else: + return tensors diff --git a/test/test_actors.py b/test/test_actors.py index bd0e3b5ed6f..0d63cf116b7 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -947,9 +947,10 @@ class TestLLMActor: def test_from_hf_transformers( self, from_text, generate, return_log_probs, tokens, attention_mask ): + torch.manual_seed(0) from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel - model_name = "distilbert-base-uncased" # or "minilm" or "albert-tiny" + # model_name = "distilbert-base-uncased" # or "minilm" or "albert-tiny" # Load the model and tokenizer # model = AutoModel.from_pretrained(model_name) # tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -1004,6 +1005,7 @@ def test_from_hf_transformers( def test_from_vllm( self, from_text, generate, return_log_probs, tokens, attention_mask ): + torch.manual_seed(0) from vllm import LLM model = LLM(model="facebook/opt-125m") @@ -1031,6 +1033,7 @@ def _make_data( generate, from_text, has_logits, + batch_size=1, text_response=None, tokens_response=None, ): @@ -1048,7 +1051,9 @@ def _make_data( else: text_response = NonTensorStack(text_response) lp_kwargs.update({"text_response": text_response}) - tdin = LLMData(text=NonTensorStack("a text"), **lp_kwargs, batch_size=1) + tdin = LLMData( + text=NonTensorStack("a text"), **lp_kwargs, batch_size=batch_size + ) else: if not generate: if tokens_response is None: @@ -1057,7 +1062,10 @@ def _make_data( tokens_response = torch.randint(1024, shape_response) lp_kwargs.update({"tokens_response": tokens_response}) tdin = LLMData( - tokens=tokens, attention_mask=attention_mask, **lp_kwargs, batch_size=1 + tokens=tokens, + attention_mask=attention_mask, + **lp_kwargs, + batch_size=batch_size ) return tdin @@ -1079,15 
+1087,21 @@ def _run_check( elif from_text and not generate: assert tdin.text_response is not None + tdin.copy() td = m(tdin) assert td is tdin assert isinstance(td, LLMData) if from_text and generate: assert td.text_response is not None - if generate and (attention_mask is not None or from_text): - assert td.attention_mask is not None, (generate, generate, from_text) - else: - assert td.attention_mask is None, (generate, from_text) + + # TODO: vLLM may produce an attention mask when hf does not - explore consistency! + # if generate and (from_text or tdincopy.attention_mask is not None): + # assert td.attention_mask is not None, (generate, from_text, tdincopy.attention_mask is not None) + # if isinstance(td.attention_mask, torch.Tensor): + # assert td.attention_mask.shape == td.tokens.shape + # else: + # assert td.attention_mask is None, (generate, from_text) + if not generate: # logprobs are computed on text response of tokens_response assert td.text_response is not None or td.tokens_response is not None @@ -1097,7 +1111,7 @@ def _run_check( if generate: if return_log_probs: assert td.log_probs is not None - assert td.log_probs.shape[-2] == td.tokens_response.shape[-1] + assert td.log_probs.shape[-1] == td.tokens_response.shape[-1] else: assert td.log_probs is None @@ -1113,6 +1127,42 @@ def _run_check( != td.tokens[..., : td.tokens_response.shape[-1]] ).any(), (generate, from_text) + @pytest.mark.parametrize( + "from_text, tokens, attention_mask", + [ + ( + False, + torch.randint(1024, (1, 10)), + torch.ones(1, 10, dtype=torch.int64), + ), + (False, torch.randint(1024, (1, 10)), None), + (True, None, None), + ], + ) + def test_from_hf_logprobs(self, from_text, tokens, attention_mask): + torch.manual_seed(0) + from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel + + tokenizer = AutoTokenizer.from_pretrained("gpt2") + model = GPT2LMHeadModel(GPT2Config()).eval() + + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + m_generate = from_hf_transformers( + model, + tokenizer=tokenizer, + from_text=from_text, + generate=True, + return_log_probs=True, + ) + m_logprobs = from_hf_transformers( + model, tokenizer=tokenizer, from_text=from_text, generate=False + ) + self._check_lps( + m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False + ) + @pytest.mark.parametrize( "from_text, tokens, attention_mask", [ @@ -1126,6 +1176,7 @@ def _run_check( ], ) def test_from_vllm_logprobs(self, from_text, tokens, attention_mask): + torch.manual_seed(0) from vllm import LLM model = LLM(model="facebook/opt-125m") @@ -1162,6 +1213,8 @@ def _check_lps( text_response=td_generate.text_response, ) td_logprobs = model_logprobs(tdin_logprobs) + assert td_generate.log_probs.shape == td_generate.tokens_response.shape + assert td_logprobs.log_probs.shape == td_generate.tokens_response.shape torch.testing.assert_close( td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2 ) diff --git a/test/test_cost.py b/test/test_cost.py index 3fd1fad62da..79ca215ec7d 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -145,7 +145,10 @@ get_available_devices, get_default_devices, ) - from pytorch.rl.test.mocking_classes import ContinuousActionConvMockEnv + from pytorch.rl.test.mocking_classes import ( + ContinuousActionConvMockEnv, + DummyStrDataLoader, + ) else: from _utils_internal import ( # noqa _call_value_nets, @@ -153,7 +156,7 @@ get_available_devices, get_default_devices, ) - from mocking_classes import ContinuousActionConvMockEnv + from 
mocking_classes import ContinuousActionConvMockEnv, DummyStrDataLoader _has_functorch = True try: @@ -16659,6 +16662,70 @@ def forward(self, td, mode): assert exploration_type() == ExplorationType.RANDOM +class TestPPO4LLMs: + @pytest.mark.parametrize("from_text", [True, False]) + def test_hf(self, from_text): + from torchrl.envs import LLMEnv, Transform + from torchrl.modules import from_hf_transformers + from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM + + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") + tokenizer.pad_token = tokenizer.eos_token + + model = OPTForCausalLM(OPTConfig()) + policy_inference = from_hf_transformers( + model, tokenizer=tokenizer, generate=True, from_text=from_text + ) + policy_train = from_hf_transformers( + model, tokenizer=tokenizer, generate=False, from_text=False + ) + for p in policy_train.parameters(): + assert p.requires_grad + # Create some fake data + dl = DummyStrDataLoader(batch_size=32) + llm_env = LLMEnv.from_dataloader( + dl, + tokenizer=tokenizer if not from_text else None, + batch_size=(32,), + str2str=True, + ) + + class RewardTransform(Transform): + def _step(self, td, next_td): + next_td["reward"] = torch.randn_like( + td["tokens_response"], dtype=torch.float + ).unsqueeze(-1) + return next_td + + def transform_reward_spec(self, reward_spec): + return reward_spec.set( + "reward", Unbounded((*reward_spec.shape, -1, 1), dtype=torch.float) + ) + + llm_env = llm_env.append_transform(RewardTransform()) + with torch.no_grad(): + data = llm_env.rollout(3, policy_inference) + data = data.view(-1) + assert data["tokens_response"].shape[-1] == 20 + # Make some fake advantages: + data["advantage"] = torch.randn_like(data["next", "reward"]) + + loss = ClipPPOLoss( + actor_network=policy_train, + ) + loss_vals = loss(data) + + assert "loss_objective" in loss_vals + assert "loss_entropy" in loss_vals + assert loss_vals["loss_objective"].requires_grad + assert loss_vals["loss_entropy"].requires_grad + assert "clip_fraction" in loss_vals + assert "kl_approx" in loss_vals + assert "entropy" in loss_vals + assert "ESS" in loss_vals + assert "loss_critic" not in loss_vals + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_env.py b/test/test_env.py index dfcc5a5e87d..1f6a5312048 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -12,8 +12,8 @@ import pickle import random import re -import string from collections import defaultdict +from contextlib import nullcontext from functools import partial from sys import platform from typing import Any, Optional @@ -33,7 +33,7 @@ TensorDictBase, ) from tensordict.nn import TensorDictModuleBase -from tensordict.tensorclass import NonTensorStack, TensorClass +from tensordict.tensorclass import NonTensorData, NonTensorStack, TensorClass from tensordict.utils import _unravel_key_to_tuple from torch import nn @@ -133,6 +133,8 @@ DiscreteActionConvMockEnvNumpy, DiscreteActionVecMockEnv, DummyModelBasedEnvBase, + DummyStrDataLoader, + DummyTensorDataLoader, EnvThatDoesNothing, EnvWithDynamicSpec, EnvWithMetadata, @@ -174,6 +176,8 @@ DiscreteActionConvMockEnvNumpy, DiscreteActionVecMockEnv, DummyModelBasedEnvBase, + DummyStrDataLoader, + DummyTensorDataLoader, EnvThatDoesNothing, EnvWithDynamicSpec, EnvWithMetadata, @@ -4578,58 +4582,7 @@ def set_capture(self): yield None return - class DummyDataLoader: - def __init__(self, batch_size=0): - self.batch_size = 
batch_size - - def generate_random_string(self, length=10): - """Generate a random string of a given length.""" - return "".join(random.choice(string.ascii_lowercase) for _ in range(length)) - - def __iter__(self): - return self - - def __next__(self): - if self.batch_size == 0: - return self.generate_random_string() - else: - return [self.generate_random_string() for _ in range(self.batch_size)] - - class DummyTensorDataLoader: - def __init__(self, batch_size=0, max_length=10, padding=False): - self.batch_size = batch_size - self.max_length = max_length - self.padding = padding - - def generate_random_tensor(self): - """Generate a tensor of random int64 values.""" - length = random.randint(1, self.max_length) - return torch.tensor( - [random.randint(0, 100) for _ in range(length)], dtype=torch.int64 - ) - - def pad_tensor(self, tensor): - """Pad a tensor to the maximum length.""" - padding_length = self.max_length - len(tensor) - return torch.cat((torch.zeros(padding_length, dtype=torch.int64), tensor)) - - def __iter__(self): - return self - - def __next__(self): - if self.batch_size == 0: - tensor = self.generate_random_tensor() - return self.pad_tensor(tensor) if self.padding else tensor - else: - tensors = [ - self.generate_random_tensor() for _ in range(self.batch_size) - ] - if self.padding: - tensors = [self.pad_tensor(tensor) for tensor in tensors] - return torch.stack(tensors) - else: - return tensors - + @pytest.mark.skipif(not _has_transformers, reason="test requires transformers") @pytest.mark.parametrize( "str2str,stack_method", [ @@ -4649,7 +4602,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size): ) if str2str: primer = DataLoadingPrimer( - dataloader=self.DummyDataLoader(batch_size=batch_size), + dataloader=DummyStrDataLoader(batch_size=batch_size), data_keys=[LLMEnv._DEFAULT_STR_KEY], example_data="a string!", ) @@ -4657,9 +4610,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size): if stack_method is None: stack_method = as_padded_tensor primer = DataLoadingPrimer( - dataloader=self.DummyTensorDataLoader( - batch_size=batch_size, padding=True - ), + dataloader=DummyTensorDataLoader(batch_size=batch_size, padding=True), data_keys=[LLMEnv._DEFAULT_TOKEN_KEY], data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)], stack_method=stack_method, @@ -4674,25 +4625,39 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size): else: env.check_env_specs(break_when_any_done="both") + @pytest.mark.skipif(not _has_transformers, reason="test requires transformers") + @pytest.mark.parametrize("tokenizer", [True, False]) @pytest.mark.parametrize( - "str2str,stack_method", + "str2str,no_stack,stack_method", [ - [True, None], - [False, "as_padded_tensor"], - # TODO: a bit experimental, fails with check_env_specs - # [False, "as_nested_tensor"], - [False, None], + [True, True, None], + [True, False, None], + [False, False, "as_padded_tensor"], + [False, False, None], ], ) @pytest.mark.parametrize("batched", [True, False]) @pytest.mark.parametrize("device", [None, "cpu"]) @pytest.mark.parametrize("batch_size", [0, 4]) def test_llm_from_dataloader( - self, str2str, batched, stack_method, device, batch_size + self, + str2str, + batched, + stack_method, + device, + batch_size, + tokenizer, + no_stack, ): + from transformers import AutoTokenizer + + if tokenizer: + tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + else: + tokenizer = None if str2str: kwargs = { - "dataloader": 
self.DummyDataLoader(batch_size=batch_size), + "dataloader": DummyStrDataLoader(batch_size=batch_size), "data_keys": [LLMEnv._DEFAULT_STR_KEY], "example_data": "a string!", } @@ -4700,7 +4665,7 @@ def test_llm_from_dataloader( if stack_method is None: stack_method = as_padded_tensor kwargs = { - "dataloader": self.DummyTensorDataLoader( + "dataloader": DummyTensorDataLoader( padding=True, batch_size=batch_size ), "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY], @@ -4712,7 +4677,8 @@ def test_llm_from_dataloader( "str2str": str2str, "device": device, "has_attention": False, - "no_stack": False, + "no_stack": no_stack, + "tokenizer": tokenizer, } ) env = LLMEnv.from_dataloader(**kwargs) @@ -4725,12 +4691,17 @@ def test_llm_from_dataloader( if batch_size > 0: def policy(td): - if str2str: + if str2str and tokenizer is None: if not td.shape: - td[LLMEnv._DEFAULT_ACTION_STR_KEY] = "" + td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorData( + "", device=device + ) else: td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack( - *["" for _ in range(td.shape[0])] + *[ + NonTensorData("", device=device) + for _ in range(td.shape[0]) + ] ) else: td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones( @@ -4742,34 +4713,48 @@ def policy(td): # Tell the env that we want 3 sub-envs r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[3])) assert r.ndim == 2 - if str2str: + if str2str and tokenizer is None: assert isinstance(r[0, 0][LLMEnv._DEFAULT_STR_KEY], str) assert isinstance(r[0, 1][LLMEnv._DEFAULT_STR_KEY], str) - assert ( - r[0, 0][LLMEnv._DEFAULT_STR_KEY] - == r[0, 1][LLMEnv._DEFAULT_STR_KEY][ - : -len(r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY]) - ] - ) - assert ( - r[0, 1][LLMEnv._DEFAULT_STR_KEY] - == r[0, 2][LLMEnv._DEFAULT_STR_KEY][ - : -len(r[0, 1][LLMEnv._DEFAULT_ACTION_STR_KEY]) - ] - ) - assert ( - r[-1, 0][LLMEnv._DEFAULT_STR_KEY] - == r[-1, 1][LLMEnv._DEFAULT_STR_KEY][ - : -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_STR_KEY]) - ] - ) - assert ( - r[-1, 1][LLMEnv._DEFAULT_STR_KEY] - == r[-1, 2][LLMEnv._DEFAULT_STR_KEY][ - : -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_STR_KEY]) - ] - ) - else: + should_fail = no_stack + if should_fail: + ctx = pytest.raises(AssertionError) + else: + ctx = nullcontext() + with ctx: + assert ( + r[0, 0][LLMEnv._DEFAULT_STR_KEY] + == r[0, 1][LLMEnv._DEFAULT_STR_KEY][ + : -len(r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY]) + ] + ), ( + r[0, 0][LLMEnv._DEFAULT_STR_KEY], + r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY], + r[0, 0]["next", LLMEnv._DEFAULT_STR_KEY], + r[0, 1][LLMEnv._DEFAULT_STR_KEY], + ) + with ctx: + assert ( + r[0, 1][LLMEnv._DEFAULT_STR_KEY] + == r[0, 2][LLMEnv._DEFAULT_STR_KEY][ + : -len(r[0, 1][LLMEnv._DEFAULT_ACTION_STR_KEY]) + ] + ) + with ctx: + assert ( + r[-1, 0][LLMEnv._DEFAULT_STR_KEY] + == r[-1, 1][LLMEnv._DEFAULT_STR_KEY][ + : -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_STR_KEY]) + ] + ) + with ctx: + assert ( + r[-1, 1][LLMEnv._DEFAULT_STR_KEY] + == r[-1, 2][LLMEnv._DEFAULT_STR_KEY][ + : -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_STR_KEY]) + ] + ) + elif tokenizer is None: assert ( r[0, 0][LLMEnv._DEFAULT_TOKEN_KEY] == r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1] @@ -4809,7 +4794,7 @@ def test_llm_from_dataloader_repeats( ): if str2str: kwargs = { - "dataloader": self.DummyDataLoader(batch_size=batch_size), + "dataloader": DummyStrDataLoader(batch_size=batch_size), "data_keys": [LLMEnv._DEFAULT_STR_KEY], "example_data": "a string!", "repeats": repeats, @@ -4818,7 +4803,7 @@ def test_llm_from_dataloader_repeats( if stack_method is None: stack_method = as_padded_tensor kwargs = { - 
"dataloader": self.DummyTensorDataLoader( + "dataloader": DummyTensorDataLoader( padding=True, batch_size=batch_size ), "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY], @@ -4951,7 +4936,7 @@ def test_done_and_reward( ) if str2str else contextlib.nullcontext(): if str2str: kwargs = { - "dataloader": self.DummyDataLoader(batch_size=batch_size), + "dataloader": DummyStrDataLoader(batch_size=batch_size), "data_keys": [LLMEnv._DEFAULT_STR_KEY], "example_data": "a string!", "repeats": repeats, @@ -4962,7 +4947,7 @@ def test_done_and_reward( if stack_method is None: stack_method = as_padded_tensor kwargs = { - "dataloader": self.DummyTensorDataLoader( + "dataloader": DummyTensorDataLoader( padding=True, batch_size=batch_size ), "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY], diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py index 8cb2d3cde84..63f9db8e2d6 100644 --- a/torchrl/envs/custom/llm.py +++ b/torchrl/envs/custom/llm.py @@ -154,12 +154,19 @@ def __init__( self.full_observation_spec_unbatched = Composite( { self.str_key: NonTensor( - example_data="a string", batched=True, shape=() + example_data="a string", + batched=True, + shape=(), + device=device, ) } ) self.full_action_spec_unbatched = Composite( - {action_key: NonTensor(example_data="a string", batched=True, shape=())} + { + action_key: NonTensor( + example_data="a string", batched=True, shape=(), device=device + ) + } ) else: if vocab_size is None: @@ -217,8 +224,8 @@ def __init__( if not self.assign_done: # Use single done self.full_done_spec_unbatched = Composite( - done=Unbounded(shape=(1,), dtype=torch.bool), - terminated=Unbounded(shape=(1,), dtype=torch.bool), + done=Unbounded(shape=(1,), dtype=torch.bool, device=device), + terminated=Unbounded(shape=(1,), dtype=torch.bool, device=device), ) elif self.str2str: raise STR2STR_ERR @@ -226,11 +233,11 @@ def __init__( # Use single done self.full_done_spec_unbatched = Composite( tokens_data=Composite( - done=Unbounded(shape=(-1,), dtype=torch.bool), - terminated=Unbounded(shape=(-1,), dtype=torch.bool), + done=Unbounded(shape=(-1,), dtype=torch.bool, device=device), + terminated=Unbounded(shape=(-1,), dtype=torch.bool, device=device), ), - done=Unbounded(shape=(1,), dtype=torch.bool), - terminated=Unbounded(shape=(1,), dtype=torch.bool), + done=Unbounded(shape=(1,), dtype=torch.bool, device=device), + terminated=Unbounded(shape=(1,), dtype=torch.bool, device=device), ) @classmethod @@ -346,6 +353,7 @@ def from_dataloader( if tokenizer is not None: if str2str: + # In this case, the tokenizer is appended to the env after each step if action_key is None: action_key = cls._DEFAULT_ACTION_STR_KEY tokenizer_transform = Tokenizer( @@ -360,6 +368,7 @@ def from_dataloader( missing_tolerance=False, ) else: + # In this case, the tokenizer acts before reset and that's all tokenizer_transform = Tokenizer( tokenizer=tokenizer, in_keys=[str_key], @@ -387,6 +396,7 @@ def from_dataloader( example_data=example_data, stack_method=stack_method, repeats=repeats, + device=device, ) env = LLMEnv( str2str=str2str, @@ -409,12 +419,12 @@ def from_dataloader( return env.append_transform(primer) @staticmethod - def _check_obs_act_and_cat(obs, action): + def _check_obs_act_and_cat(obs, action, *, device): if not isinstance(obs, str): raise TypeError(f"Observation must be a string, got {type(obs)}.") if not isinstance(action, str): raise TypeError(f"Action must be a string, got {type(action)}.") - return obs + action + return NonTensorData(obs + action, device=device) def _step( self, @@ -466,10 +476,11 @@ 
def _make_next_obs( self, tensordict: TensorDictBase, nex_td: TensorDictBase ) -> TensorDictBase: if self.no_stack: - if self.str2str: - raise NotImplementedError action = tensordict.get(self.action_key) - nex_td.set(self.token_key, action) + if self.str2str: + nex_td.set(self.str_key, action) + else: + nex_td.set(self.token_key, action) if self.has_attention: attention_mask = tensordict.get(self.attention_key) n = action.shape[-1] - attention_mask.shape[-1] @@ -495,11 +506,13 @@ def _make_next_obs( "The tensordict is batchless, yet the action and/or observations are not " f"strings but {type(action)} and {type(obs)}, respectivly." ) - observation = self._check_obs_act_and_cat(obs, action) + observation = self._check_obs_act_and_cat( + obs, action, device=self.device + ) else: observation = NonTensorStack( *[ - self._check_obs_act_and_cat(_obs, _action) + self._check_obs_act_and_cat(_obs, _action, device=self.device) for (_obs, _action) in _zip_strict(obs, action) ] ) @@ -552,6 +565,11 @@ def check_str(): f"torchrl.envs.DataLoadingPrimer) is appended to the env transforms." ) td_reset = tensordict.copy() + if td_reset.device != self.device: + if self.device is None: + td_reset.clear_device_() + else: + td_reset = td_reset.to(self.device) tensordict = self._maybe_make_done(tensordict, td_reset) if self.as_llm_data: raise NotImplementedError() diff --git a/torchrl/envs/transforms/llm.py b/torchrl/envs/transforms/llm.py index 2470c7440b0..28d49d08013 100644 --- a/torchrl/envs/transforms/llm.py +++ b/torchrl/envs/transforms/llm.py @@ -6,26 +6,22 @@ from collections import deque from collections.abc import Mapping -from copy import copy +from copy import copy, deepcopy from typing import Any, Callable, Iterable, Literal import torch -from tensordict import ( - maybe_dense_stack, - NestedKey, - TensorDict, - TensorDictBase, - unravel_key, -) +from tensordict import lazy_stack, NestedKey, TensorDict, TensorDictBase, unravel_key from tensordict.nn import ( ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, + TensorDictParams, ) from tensordict.utils import _zip_strict, is_seq_of_nested_key +from torch import nn from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform -from torchrl.envs.transforms.utils import _set_missing_tolerance +from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param from torchrl.envs.utils import make_composite_from_td @@ -366,14 +362,14 @@ def __init__( use_buffer: bool | None = None, auto_batch_size: bool = True, repeats: int | None = None, + device: torch.device | None = None, ): self.dataloader = dataloader if repeats is None: repeats = 0 self.repeats = repeats - if ( - getattr(dataloader, "batch_size", 1) > 1 and use_buffer is None - ) or repeats > 0: + batch_size = getattr(dataloader, "batch_size", 0) + if (batch_size > 1 and use_buffer is None) or repeats > 0: use_buffer = True self.use_buffer = use_buffer @@ -381,13 +377,11 @@ def __init__( self._queue = deque() # No auto_batch_size if we know we have a single element - self.auto_batch_size = auto_batch_size and ( - getattr(dataloader, "batch_size", 1) > 0 - ) + self.auto_batch_size = auto_batch_size and (batch_size > 0) self.endless_dataloader = self._endless_iter(self.dataloader) if stack_method is None: - stack_method = maybe_dense_stack + stack_method = lazy_stack elif stack_method == "as_nested_tensor": stack_method = as_nested_tensor elif stack_method == 
"as_padded_tensor": @@ -407,12 +401,16 @@ def __init__( for data_key, data_spec in _zip_strict(data_keys, data_specs) } ) + if batch_size: + primers = batch_size.expand(batch_size) self.data_keys = data_keys elif primers is None: self.data_keys = data_keys # We can get the primer from the dataloader itself data = self._load_from_dataloader() primers = make_composite_from_td(data, dynamic_shape=True) + if batch_size: + primers = primers.expand(batch_size) self._queue.insert(0, data) if data_keys is None: self.data_keys = list(primers.keys(True, True)) @@ -426,6 +424,7 @@ def __init__( expand_specs=None, single_default_value=True, call_before_env_reset=True, + device=device, ) self._reset_key = "_reset" @@ -434,10 +433,14 @@ def _endless_iter(self, obj): while True: yield from obj + # def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase: + # td = super()._reset_env_preprocess(tensordict) + # return lazy_stack(list(td.unbind(0))) + # def _load_from_dataloader(self, reset: torch.Tensor | None = None): """Loads a single element from the dataloader, or alternatively from the buffer. - If `reset` is passed, the one element per reset will be loaded. + If `reset` is passed, then one element per reset will be loaded. """ if reset is not None: if not reset.any(): @@ -446,8 +449,16 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None): loaded = [self._load_from_dataloader() for i in range(reset.sum())] return self.stack_method(loaded) + primers = getattr(self, "primers", None) + if primers is not None: + device = self.primers.device + else: + device = None + if self.use_buffer and len(self._queue) > 0: result = self._queue.popleft() + if result.device != device: + result = result.to(device) return result data = next(self.endless_dataloader) @@ -456,7 +467,10 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None): # TODO: one could rename the keys too if isinstance(data, Mapping): out = TensorDict.from_dict( - data, auto_batch_size=self.auto_batch_size, batch_dims=1 + data, + auto_batch_size=self.auto_batch_size, + batch_dims=1, + device=device, ) elif self.data_keys is None: raise RuntimeError( @@ -469,12 +483,14 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None): {k: val for k, val in _zip_strict(self.data_keys, data)}, auto_batch_size=self.auto_batch_size, batch_dims=1, + device=device, ) elif len(self.data_keys) == 1: out = TensorDict.from_dict( {self.data_keys[0]: data}, auto_batch_size=self.auto_batch_size, batch_dims=1, + device=device, ) else: raise ValueError( @@ -567,6 +583,7 @@ def __init__( requires_grad=False, log_prob_key: NestedKey = "sample_log_prob", action_key: NestedKey = "action", + functional: bool = True, ): if in_keys is None: in_keys = self.DEFAULT_IN_KEYS @@ -592,32 +609,39 @@ def __init__( # update the in_keys for dispatch etc self.in_keys = self.in_keys + actor.in_keys + self.functional = functional # check that the model has parameters - # params = TensorDict.from_module(actor) - # with params.apply( - # _stateless_param, device="meta", filter_empty=False - # ).to_module(actor): - # # copy a stateless actor - # self.__dict__["functional_actor"] = deepcopy(actor) - self.__dict__["functional_actor"] = actor - - # we need to register these params as buffer to have `to` and similar - # methods work properly - - # def _make_detached_param(x): - # - # if isinstance(x, nn.Parameter): - # # we need an nn.Parameter since some modules (RNN) require nn.Parameters - # return nn.Parameter(x.data.clone(), 
requires_grad=requires_grad) - # elif x.requires_grad: - # raise ValueError( - # "Encountered a value that requires gradients but is not an nn.Parameter instance." - # ) - # return x.clone() - # self.frozen_params = params.apply(_make_detached_param, filter_empty=False) - # if requires_grad: - # # includes the frozen params/buffers in the module parameters/buffers - # self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True) + if functional: + params = TensorDict.from_module(actor) + with params.apply( + _stateless_param, device="meta", filter_empty=False + ).to_module(actor): + # copy a stateless actor + self.__dict__["functional_actor"] = deepcopy(actor) + + # we need to register these params as buffer to have `to` and similar + # methods work properly + + def _make_detached_param(x): + + if isinstance(x, nn.Parameter): + # we need an nn.Parameter since some modules (RNN) require nn.Parameters + return nn.Parameter(x.data.clone(), requires_grad=requires_grad) + elif x.requires_grad: + raise ValueError( + "Encountered a value that requires gradients but is not an nn.Parameter instance." + ) + return x.clone() + + self.frozen_params = params.apply(_make_detached_param, filter_empty=False) + if requires_grad: + # includes the frozen params/buffers in the module parameters/buffers + self.frozen_params = TensorDictParams( + self.frozen_params, no_convert=True + ) + + else: + self.__dict__["functional_actor"] = actor # self._buffers["actor_params"] = params.clone().detach() @@ -651,11 +675,17 @@ def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase: if self.out_keys[0] != ("reward",) and self.parent is not None: next_tensordict.set(self.out_keys[0], self.parent.reward_spec.zero()) return next_tensordict - # with self.frozen_params.to_module(self.functional_actor): - if isinstance( + + if self.functional: + with self.frozen_params.to_module(self.functional_actor): + dist = self.functional_actor.get_dist(next_tensordict.clone(False)) + # get the log_prob given the original model + log_prob = dist.log_prob(action) + elif isinstance( self.functional_actor, (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential), ): + # with self.frozen_params.to_module(self.functional_actor): dist = self.functional_actor.get_dist(next_tensordict.copy()) # get the log_prob given the original model log_prob = dist.log_prob(action) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 776a5601106..a58c483d2b8 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -5313,8 +5313,8 @@ class Tokenizer(UnaryTransform): def __init__( self, - in_keys: Sequence[NestedKey], - out_keys: Sequence[NestedKey], + in_keys: Sequence[NestedKey] | None = None, + out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, *, @@ -5385,14 +5385,39 @@ def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase: out_key, observation, ) - elif not self.missing_tolerance or out_key not in next_tensordict.keys( - True + elif ( + self.missing_tolerance + and self.return_attention_mask + and out_key in next_tensordict.keys(True) ): + attention_key = _replace_last(out_key, "attention_mask") + if attention_key not in next_tensordict: + next_tensordict[attention_key] = torch.ones_like( + next_tensordict.get(out_key) + ) + elif not self.missing_tolerance: raise KeyError( f"{self}: '{in_key}' not found in tensordict {next_tensordict}" ) 
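+        # at this point every in_key has been tokenized or tolerated as missing,
+        # and attention masks (when requested) sit alongside the token entries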
return next_tensordict + @dispatch(source="in_keys", dest="out_keys") + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): + data = tensordict.get(in_key, None) + if data is not None: + data = self._apply_transform(data) + if self.return_attention_mask: + data, attention_mask = data + tensordict.set( + _replace_last(out_key, "attention_mask"), + attention_mask, + ) + tensordict.set(out_key, data) + elif not self.missing_tolerance: + raise KeyError(f"'{in_key}' not found in tensordict {tensordict}") + return tensordict + def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase: if self.call_before_reset: with _set_missing_tolerance(self, True): @@ -5416,7 +5441,7 @@ def call_tokenizer_fn(self, value: str | list[str]): out = self.tokenizer.encode(value, return_tensors="pt", **kwargs)[0] # TODO: incorporate attention mask if self.return_attention_mask: - attention_mask = torch.ones_like(out, dtype=torch.bool) + attention_mask = torch.ones_like(out, dtype=torch.int64) else: kwargs["padding"] = ( self.padding if self.max_length is None else "max_length" @@ -5445,9 +5470,32 @@ def call_tokenizer_inv_fn(self, value: Tensor): out = self.tokenizer.batch_decode( value, skip_special_tokens=self.skip_special_tokens ) + device = self._str_device if isinstance(out, list): - return NonTensorStack(*out) - return NonTensorData(out) + result = NonTensorStack(*out) + if device: + result = result.to(device) + return result + return NonTensorData(out, device=device) + + @property + def _str_device(self): + parent = self.parent + if parent is None: + return None + if self.in_keys: + in_key = self.in_keys[0] + elif self.in_keys_inv: + in_key = self.in_keys_inv[0] + else: + return None + if in_key in parent.observation_keys: + return parent.full_observation_spec[in_key].device + if in_key in parent.action_keys: + return parent.full_action_spec[in_key].device + if in_key in parent.state_keys: + return parent.full_state_spec[in_key].device + return None def transform_input_spec(self, input_spec: Composite) -> Composite: # We need to cap the spec to generate valid random strings @@ -5461,6 +5509,9 @@ def transform_input_spec(self, input_spec: Composite) -> Composite: f"The input keys {in_key} wasn't found in the env input specs." 
) local_spec = spec.pop(in_key) + local_dtype = local_spec.dtype + if local_dtype is None or local_dtype.is_floating_point: + local_dtype = torch.int64 new_shape = spec.shape if self.max_length is None: # Then we can't tell what the shape will be @@ -5472,7 +5523,7 @@ def transform_input_spec(self, input_spec: Composite) -> Composite: self.tokenizer.vocab_size, shape=new_shape, device=local_spec.device, - dtype=local_spec.dtype, + dtype=local_dtype, ) return input_spec @@ -5484,12 +5535,24 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec attention_mask_keys = set() for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): new_shape = observation_spec.shape + torch.Size((-1,)) + try: + in_spec = observation_spec[in_key] + obs_dtype = in_spec.dtype + device = in_spec.device + except KeyError: + # In some cases (eg, the tokenizer is applied during reset on data that + # originates from a dataloader) we don't have an in_spec + in_spec = None + obs_dtype = None + device = observation_spec.device + if obs_dtype is None or obs_dtype.is_floating_point: + obs_dtype = torch.int64 observation_spec[out_key] = Bounded( 0, self.tokenizer.vocab_size, shape=new_shape, - device=observation_spec[in_key].device, - dtype=observation_spec[in_key].dtype, + device=device, + dtype=obs_dtype, ) if self.return_attention_mask: attention_mask_key = _replace_last(out_key, "attention_mask") @@ -5500,12 +5563,15 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec "entries are unique." ) attention_mask_keys.add(attention_mask_key) + attention_dtype = obs_dtype + if attention_dtype is None or attention_dtype.is_floating_point: + attention_dtype = torch.int64 observation_spec[attention_mask_key] = Bounded( 0, 2, shape=new_shape, - device=observation_spec[in_key].device, - dtype=observation_spec[in_key].dtype, + device=device, + dtype=attention_dtype, ) return observation_spec @@ -6124,7 +6190,7 @@ def __init__( kwargs = primers if not isinstance(kwargs, Composite): shape = kwargs.pop("shape", None) - device = kwargs.pop("device", None) + device = self.device if "batch_size" in kwargs.keys(): extra_kwargs = {"batch_size": kwargs.pop("batch_size")} else: @@ -6197,7 +6263,7 @@ def reset_key(self, value): @property def device(self): device = self._device - if device is None and self.parent is not None: + if device is None and hasattr(self, "parent") and self.parent is not None: device = self.parent.device self._device = device return device diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py index 8e9cda99b3c..58f3b1affd8 100644 --- a/torchrl/modules/distributions/discrete.py +++ b/torchrl/modules/distributions/discrete.py @@ -185,6 +185,7 @@ class MaskedCategorical(D.Categorical): padding_value: The padding value in the mask tensor. When sparse_mask == True, the padding_value will be ignored. + Examples: >>> torch.manual_seed(0) >>> logits = torch.randn(4) / 100 # almost equal probabilities >>> mask = torch.tensor([True, False, True, True]) diff --git a/torchrl/modules/llm/__init__.py b/torchrl/modules/llm/__init__.py index 5d70748aeff..559133bea44 100644 --- a/torchrl/modules/llm/__init__.py +++ b/torchrl/modules/llm/__init__.py @@ -3,7 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from .common import CategoricalSequential from .transformers_policy import from_hf_transformers from .vllm_policy import from_vllm -__all__ = ["from_hf_transformers", "from_vllm"] +__all__ = ["from_hf_transformers", "from_vllm", "CategoricalSequential"] diff --git a/torchrl/modules/llm/common.py b/torchrl/modules/llm/common.py new file mode 100644 index 00000000000..a168c16cec8 --- /dev/null +++ b/torchrl/modules/llm/common.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from typing import List + +from tensordict import NestedKey, TensorDictBase +from tensordict.nn import ( + ProbabilisticTensorDictModule, + ProbabilisticTensorDictSequential, + TensorDictSequential, +) +from torch import distributions as D +from torch.distributions import Categorical + + +class CategoricalSequential(ProbabilisticTensorDictSequential): + """A ProbabilisticTensorDictSequential subclass meant to work with LLMs. + + .. seealso:: :class:`~tensordict.nn.ProbabilisticTensorDictSequential` class. + + """ + + def get_dist( + self, + tensordict: TensorDictBase, + tensordict_out: TensorDictBase | None = None, + **kwargs, + ) -> D.Distribution: + td_out = self(tensordict.copy()) + return Categorical(td_out.get("logits")) + + # Sampling is taken care of by the sub-modules + forward = TensorDictSequential.forward + + @property + def log_prob_keys(self): + return ["log_probs"] + + log_prob_key = ProbabilisticTensorDictModule.log_prob_key + + @property + def dist_params_keys(self) -> List[NestedKey]: + raise NotImplementedError + + @property + def dist_sample_keys(self) -> List[NestedKey]: + return ["tokens_response"] diff --git a/torchrl/modules/llm/transformers_policy.py b/torchrl/modules/llm/transformers_policy.py index e977fc18be6..f3894d6693e 100644 --- a/torchrl/modules/llm/transformers_policy.py +++ b/torchrl/modules/llm/transformers_policy.py @@ -6,15 +6,11 @@ import torch -from tensordict import NestedKey, TensorDictBase -from tensordict.nn import ( - TensorDictModule as Mod, - TensorDictModuleBase, - TensorDictSequential as Seq, - WrapModule, -) +from tensordict import NestedKey, TensorDict, TensorDictBase +from tensordict.nn import TensorDictModule as Mod, TensorDictModuleBase, WrapModule from tensordict.tensorclass import NonTensorData, NonTensorStack from torchrl.data.llm import LLMData +from torchrl.modules.llm.common import CategoricalSequential def _maybe_clear_device(td): @@ -32,7 +28,7 @@ def _maybe_set_device(td): return td.to(device) -def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase: +def log_probs_generate(td: TensorDictBase) -> TensorDictBase: """Computes the log_probs from a Transformer formatted TensorDict. 
    Required keys in tensordict:
@@ -58,16 +54,17 @@
 
     del td["tokens_out", "past_key_values"]
 
-    scores = dict(td["tokens_out", "scores"].items())
-    scores = torch.stack(
-        [scores[str(k)] for k in range(len(scores))], 1
+    logits = dict(td["tokens_out", "logits"].items())
+    logits = torch.stack(
+        [logits[str(k)] for k in range(len(logits))], 1
     )  # shape (B, seq-len, vocab_size)
-    logits = scores - scores.logsumexp(dim=-1, keepdim=True)
-    td["logits"] = scores[..., -seq_len:, :]
-    del td["tokens_out", "scores"]
+    logits = logits.log_softmax(dim=-1)
+    td["logits"] = logits[..., -seq_len:, :]
+    del td["tokens_out", "logits"]
+
     tokens = tokens_out[..., -seq_len:]  # shape (B, seq-len)
     log_probs = logits.gather(-1, tokens.unsqueeze(-1))
-    td["log_probs"] = log_probs
+    td["log_probs"] = log_probs.squeeze(-1)
     return td
@@ -90,19 +87,19 @@ def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase:
     - "forward", "past_key_values"
     - "forward"
     """
-    tokens_out = td["tokens_response", "input_ids"]
-    seq_len = tokens_out.shape[-1]
+    tokens = td["tokens_response", "input_ids"]
+    seq_len = tokens.shape[-1]
+
+    logits = td["forward", "logits"]
+    logits = logits.log_softmax(dim=-1)
+    logits = logits[..., -seq_len - 1 : -1, :]
+    td["logits"] = logits
+    del td["forward"]
+    td["tokens_response"] = td["tokens_response", "input_ids"]
 
-    del td["forward", "past_key_values"]
+    log_probs = logits.gather(-1, tokens.unsqueeze(-1))
+    td["log_probs"] = log_probs.squeeze(-1)
 
-    scores = td["forward", "logits"]
-    scores = scores[..., -seq_len:, :]
-    logits = scores - scores.logsumexp(dim=-1, keepdim=True)
-    td["logits"] = scores
-    del td["forward"]
-    scores.shape[1]
-    log_probs = logits.gather(-1, tokens_out.unsqueeze(-1))
-    td["log_probs"] = log_probs
+
     return td
@@ -183,89 +180,274 @@ def from_hf_transformers(
     # Keys:
     text_key: NestedKey = "text"
+    if from_text:
+        if generate:
+            func = _from_hf_generate_text
+        else:
+            func = _from_hf_lp_text
+    else:
+        if generate:
+            func = _from_hf_generate_tokens
+        else:
+            func = _from_hf_lp_tokens
+    module_dict = func(
+        device=device,
+        tokenizer_kwargs=tokenizer_kwargs,
+        text_key=text_key,
+        tokenizer=tokenizer,
+        kwargs=kwargs,
+        return_log_probs=return_log_probs,
+        model=model,
+    )
+    return CategoricalSequential(module_dict, inplace=True)
+
+
+def _from_hf_generate_text(
+    *, device, tokenizer_kwargs, text_key, tokenizer, kwargs, return_log_probs, model
+) -> dict:
+    module_dict = {}
+    if device:
+        module_dict["clear_device"] = _maybe_clear_device
+    if not tokenizer_kwargs:
+        tokenizer_kwargs = {}
+    if not tokenizer_kwargs.setdefault("return_attention_mask", True):
+        raise RuntimeError
+    if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt":
+        raise RuntimeError
+    # TODO: add other paddings
+    if tokenizer_kwargs.setdefault("padding", True) not in (True,):
+        raise RuntimeError
+    if tokenizer_kwargs.setdefault("padding_side", "left") != "left":
+        raise RuntimeError
+
+    module_dict["encode"] = Mod(
+        tokenizer,
+        in_keys=[text_key],
+        out_keys=["tokens_in"],
+        method_kwargs=tokenizer_kwargs,
+        strict=True,
+        # We don't need the text after this
+        inplace=False,
+    )
+
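As a compact illustration of what `log_probs_generate` and `log_probs_from_logits` above compute: the logits are normalized once with `log_softmax`, and the log-probability of each realized token is then pulled out with `gather`. When a response is scored with a plain forward pass, the slice is shifted by one position because the logits at position i predict token i + 1; in the generate path the per-step logits are already aligned with the sampled tokens, hence the unshifted `[..., -seq_len:, :]` slice there. A minimal sketch with made-up shapes:

    import torch

    B, T, V = 2, 5, 11                   # batch, response length, vocab size
    logits = torch.randn(B, T + 3, V)    # logits over prompt + response
    response = torch.randint(V, (B, T))  # realized response tokens

    logp = logits.log_softmax(dim=-1)
    # position i predicts token i + 1, hence the one-step shift
    logp = logp[..., -T - 1 : -1, :]
    token_logp = logp.gather(-1, response.unsqueeze(-1)).squeeze(-1)  # (B, T)

+    if device: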
module_dict["to_dest_device"] = Mod( + lambda tensor: tensor.to(device), + in_keys=["tokens_in"], + out_keys=["tokens_in"], + strict=True, + ) + + if not kwargs: + kwargs = {} + if return_log_probs: + if not kwargs.setdefault("output_logits", True): + raise RuntimeError + if not kwargs.setdefault("return_dict_in_generate", True): + raise RuntimeError + if ( + kwargs.setdefault("tokenizer", tokenizer) is not tokenizer + and tokenizer is not None + ): + raise RuntimeError + + module_dict["generate"] = Mod( + model, + method="generate", + method_kwargs=kwargs, + in_keys={ + "input_ids": ("tokens_in", "input_ids"), + "attention_mask": ("tokens_in", "attention_mask"), + }, + out_keys=["tokens_out"], + out_to_in_map=True, + strict=True, + ) + + # Keep only the new tokens + def remove_input_seq(tokens_in, tokens_out, tokenizer=tokenizer): + result = tokens_out[..., tokens_in.shape[-1] :] + return result + + module_dict["remove_input_seq"] = Mod( + remove_input_seq, + in_keys=[("tokens_in", "input_ids"), ("tokens_out", "sequences")], + out_keys=[("tokens_out", "sequences")], + strict=True, + ) + if return_log_probs: + module_dict["extract_log_probs"] = WrapModule( + log_probs_generate, + in_keys=[("tokens_out", "sequences"), ("tokens_out", "scores")], + out_keys=["logits", "log_probs"], + ) + module_dict["decode"] = Mod( + tokenizer.batch_decode, + in_keys=[("tokens_out", "sequences")], + out_keys=["text_out"], + strict=True, + ) + if device: + module_dict["to_source_device"] = _maybe_set_device + + module_dict["rebuild"] = Mod( + lambda *x: x, + in_keys=[ + ("tokens_out", "sequences"), + ("tokens_in", "input_ids"), + ("tokens_in", "attention_mask"), + "text_out", + "log_probs", + "logits", + ], + out_keys=[ + "tokens_response", + "tokens", + "attention_mask", + "text_response", + "log_probs", + "logits", + ], + # There may not be log_probs and logits + strict=False, + inplace=False, + ) + return module_dict + + +def _from_hf_generate_tokens( + *, device, tokenizer_kwargs, text_key, tokenizer, kwargs, return_log_probs, model +) -> dict: + # Keys: token_key: NestedKey = "tokens" attention_mask_key: NestedKey = "attention_mask" module_dict = {} if device: module_dict["clear_device"] = _maybe_clear_device - if from_text: - if not tokenizer_kwargs: - tokenizer_kwargs = {} - if not tokenizer_kwargs.setdefault("return_attention_mask", True): - raise RuntimeError - if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt": - raise RuntimeError - # TODO: add other paddings - if tokenizer_kwargs.setdefault("padding", True) not in (True,): - raise RuntimeError - if tokenizer_kwargs.setdefault("padding_side", "left") != "left": + module_dict["format"] = Mod( + lambda *x: x, + in_keys=[token_key, attention_mask_key], + out_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")], + strict=False, + # We don't need the text after this + inplace=False, + ) + + if device: + module_dict["to_dest_device"] = Mod( + lambda tensor: tensor.to(device), + in_keys=["tokens_in"], + out_keys=["tokens_in"], + strict=True, + ) + + if not kwargs: + kwargs = {} + if return_log_probs: + # if not kwargs.setdefault("output_scores", True): + # raise RuntimeError + if not kwargs.setdefault("output_logits", True): raise RuntimeError + if not kwargs.setdefault("return_dict_in_generate", True): + raise RuntimeError + if ( + kwargs.setdefault("tokenizer", tokenizer) is not tokenizer + and tokenizer is not None + ): + raise RuntimeError + + module_dict["generate"] = Mod( + model, + method="generate", + 
method_kwargs=kwargs, + in_keys={ + "input_ids": ("tokens_in", "input_ids"), + "attention_mask": ("tokens_in", "attention_mask"), + }, + out_keys=["tokens_out"], + out_to_in_map=True, + strict=True, + ) - if generate: - module_dict["encode"] = Mod( - tokenizer, - in_keys=[text_key], - out_keys=["tokens_in"], - method_kwargs=tokenizer_kwargs, - strict=True, - # We don't need the text after this - inplace=False, - ) - else: - module_dict["encode"] = Mod( - # TODO: make this work with many strings - # Tokenize both strings, and only the first - lambda x, y: ( - tokenizer([_x + _y for _x, _y in zip(x, y)], **tokenizer_kwargs), - tokenizer(x, **tokenizer_kwargs), - ), - in_keys=[text_key, "text_response"], - out_keys=["tokens_in", "tokens_response"], - strict=True, - inplace=False, - ) + # Keep only the new tokens + def remove_input_seq(tokens_in, tokens_out): + # TODO: remove this assert + assert (tokens_out[..., : tokens_in.shape[-1]] == tokens_in).all() + result = tokens_out[..., tokens_in.shape[-1] :] + return result + + module_dict["remove_input_seq"] = Mod( + remove_input_seq, + in_keys=[("tokens_in", "input_ids"), ("tokens_out", "sequences")], + out_keys=[("tokens_out", "sequences")], + strict=True, + ) + if return_log_probs: + module_dict["extract_log_probs"] = WrapModule( + log_probs_generate, + in_keys=[("tokens_out", "sequences"), ("tokens_out", "logits")], + out_keys=["logits", "log_probs"], + ) + if device: + module_dict["to_source_device"] = _maybe_set_device + module_dict["rebuild"] = Mod( + lambda *x: x, + in_keys=[("tokens_out", "sequences"), "log_probs", "logits"], + out_keys=["tokens_response", "log_probs", "logits"], + inplace=False, + ) + return CategoricalSequential(module_dict, inplace=True) - def select(x, y): - return x.apply(lambda _x, _y: _x[..., _y.shape[-1] :], y) - module_dict["stack_response"] = Mod( - # Remove the init from the total tokens to get only the response tokens - select, - in_keys=["tokens_in", "tokens_response"], - out_keys=["tokens_response"], - strict=True, +def _from_hf_lp_text( + *, device, tokenizer_kwargs, text_key, tokenizer, kwargs, return_log_probs, model +) -> dict: + # Keys: + text_key: NestedKey = "text" + + module_dict = {} + if device: + module_dict["clear_device"] = _maybe_clear_device + if not tokenizer_kwargs: + tokenizer_kwargs = {} + if not tokenizer_kwargs.setdefault("return_attention_mask", True): + raise RuntimeError + if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt": + raise RuntimeError + # TODO: add other paddings + if tokenizer_kwargs.setdefault("padding", True) not in (True,): + raise RuntimeError + if tokenizer_kwargs.setdefault("padding_side", "left") != "left": + raise RuntimeError + + def encode(input_txt: list[str], output_txt: list[str]): + input_tokens = TensorDict.from_dict( + tokenizer( + [_x + _y for _x, _y in zip(input_txt, output_txt)], + **tokenizer_kwargs, ) - elif not generate: - - def stack_for_logprobs(tokens, tokens_response, attention_mask=None): - tokens = torch.cat([tokens, tokens_response], -1) - if attention_mask is not None: - attention_mask = torch.cat( - [attention_mask, attention_mask.new_ones(tokens_response.shape)], -1 - ) - return tokens, tokens_response, attention_mask - - module_dict["stack_response"] = Mod( - stack_for_logprobs, - in_keys=["tokens", "tokens_response", "attention_mask"], - out_keys=[ - ("tokens_in", "input_ids"), - ("tokens_response", "input_ids"), - ("tokens_in", "attention_mask"), - ], - strict=False, - inplace=False, ) - else: - module_dict["format"] = Mod( - 
lambda *x: x, - in_keys=[token_key, attention_mask_key], - out_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")], - strict=False, - # We don't need the text after this - inplace=False, + # TODO: if not generating, we should use the tokens generates before + # because encode-decode isn't strictly a bijection so encoding the response + # only may lead to surprises + input_only_tokens = TensorDict.from_dict( + tokenizer(input_txt, **tokenizer_kwargs) ) + output_tokens = input_tokens.apply( + lambda x, y: x[:, y.shape[1] :], input_only_tokens + ) + return input_tokens, output_tokens + + module_dict["encode"] = Mod( + # TODO: make this work with many strings + # Tokenize both strings, and only the first + encode, + in_keys=[text_key, "text_response"], + out_keys=["tokens_in", "tokens_response"], + strict=True, + inplace=False, + ) if device: module_dict["to_dest_device"] = Mod( @@ -275,133 +465,126 @@ def stack_for_logprobs(tokens, tokens_response, attention_mask=None): strict=True, ) - if generate: - if not kwargs: - kwargs = {} - if return_log_probs: - if not kwargs.setdefault("output_scores", True): - raise RuntimeError - if not kwargs.setdefault("return_dict_in_generate", True): - raise RuntimeError - if ( - kwargs.setdefault("tokenizer", tokenizer) is not tokenizer - and tokenizer is not None - ): - raise RuntimeError - - module_dict["generate"] = Mod( - model, - method="generate", - method_kwargs=kwargs, - in_keys={ - "input_ids": ("tokens_in", "input_ids"), - "attention_mask": ("tokens_in", "attention_mask"), - }, - out_keys=["tokens_out"], - out_to_in_map=True, - strict=True, + if not kwargs: + kwargs = {} + if not kwargs.setdefault("return_dict", True): + raise RuntimeError + if return_log_probs not in (True, None): + raise RuntimeError( + "return_log_probs should be True or None when not generating." 
) + module_dict["get_logprobs"] = Mod( + model, + method_kwargs=kwargs, + in_keys={ + "input_ids": ("tokens_in", "input_ids"), + "attention_mask": ("tokens_in", "attention_mask"), + }, + out_keys=["forward"], + out_to_in_map=True, + strict=True, + ) + module_dict["extract_log_probs"] = WrapModule( + log_probs_from_logits, + in_keys=[ + ("tokens_in", "input_ids"), + ("forward", "logits"), + ("tokens_response", "logits"), + ], + out_keys=["logits", "log_probs"], + ) + if device: + module_dict["to_source_device"] = _maybe_set_device + module_dict["rebuild"] = Mod( + lambda *x: x, + in_keys=["log_probs", "logits", "tokens_response"], + out_keys=["log_probs", "logits", "tokens_response"], + inplace=False, + strict=True, + ) + return module_dict - # Keep only the new tokens - def remove_input_seq(tokens_in, tokens_out): - return tokens_out[..., tokens_in.shape[-1] :] - module_dict["remove_input_seq"] = Mod( - remove_input_seq, - in_keys=[("tokens_in", "input_ids"), ("tokens_out", "sequences")], - out_keys=[("tokens_out", "sequences")], - strict=True, - ) - if return_log_probs: - module_dict["extract_log_probs"] = WrapModule( - log_probs_from_scores, - in_keys=[("tokens_out", "sequences"), ("tokens_out", "scores")], - out_keys=["logits", "log_probs"], - ) - if from_text: - module_dict["decode"] = Mod( - tokenizer.batch_decode, - in_keys=[("tokens_out", "sequences")], - out_keys=["text_out"], - strict=True, - ) - if device: - module_dict["to_source_device"] = _maybe_set_device - - module_dict["rebuild"] = Mod( - lambda *x: x, - in_keys=[ - ("tokens_out", "sequences"), - ("tokens_in", "input_ids"), - ("tokens_in", "attention_mask"), - "text_out", - "log_probs", - "logits", - ], - out_keys=[ - "tokens_response", - "tokens", - "attention_mask", - "text_response", - "log_probs", - "logits", - ], - # There may not be log_probs and logits - strict=False, - inplace=False, - ) - else: - if device: - module_dict["to_source_device"] = _maybe_set_device - module_dict["rebuild"] = Mod( - lambda *x: x, - in_keys=[("tokens_out", "sequences"), "log_probs", "logits"], - out_keys=["tokens_response", "log_probs", "logits"], - inplace=False, - ) - else: - if not kwargs: - kwargs = {} - if not kwargs.setdefault("return_dict", True): - raise RuntimeError - if return_log_probs not in (True, None): - raise RuntimeError( - "return_log_probs should be True or None when not generating." 
+def _from_hf_lp_tokens( + *, device, tokenizer_kwargs, text_key, tokenizer, kwargs, return_log_probs, model +) -> dict: + module_dict = {} + if device: + module_dict["clear_device"] = _maybe_clear_device + + def stack_for_logprobs(tokens, tokens_response, attention_mask=None): + tokens = torch.cat([tokens, tokens_response], -1) + if attention_mask is not None: + attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones(tokens_response.shape)], -1 ) - module_dict["get_logprobs"] = Mod( - model, - method_kwargs=kwargs, - in_keys={ - "input_ids": ("tokens_in", "input_ids"), - "attention_mask": ("tokens_in", "attention_mask"), - }, - out_keys=["forward"], - out_to_in_map=True, + # We need to carry over tokens_response to retrieve the shape of the response + return tokens, tokens_response, attention_mask + + module_dict["stack_response"] = Mod( + stack_for_logprobs, + in_keys=["tokens", "tokens_response", "attention_mask"], + out_keys=[ + ("tokens_in", "input_ids"), + ("tokens_response", "input_ids"), + ("tokens_in", "attention_mask"), + ], + strict=False, + inplace=False, + # We can also work with lists, just ask for lists through as_list=True + get_kwargs={ + "as_padded_tensor": True, + "padding_side": "left", + }, + ) + + if device: + module_dict["to_dest_device"] = Mod( + lambda tensor: tensor.to(device), + in_keys=["tokens_in"], + out_keys=["tokens_in"], strict=True, ) - module_dict["extract_log_probs"] = WrapModule( - log_probs_from_logits, - in_keys=[("tokens_in", "input_ids"), ("forward", "logits")], - out_keys=["logits", "log_probs"], + + if not kwargs: + kwargs = {} + if not kwargs.setdefault("return_dict", True): + raise RuntimeError + if return_log_probs not in (True, None): + raise RuntimeError( + "return_log_probs should be True or None when not generating." 
) - if device: - module_dict["to_source_device"] = _maybe_set_device - if from_text: - module_dict["rebuild"] = Mod( - lambda *x: x, - in_keys=["log_probs", "logits", "tokens_response"], - out_keys=["log_probs", "logits", "tokens_response"], - inplace=False, - ) - else: - module_dict["rebuild"] = Mod( - lambda *x: x, - in_keys=["log_probs", "logits"], - out_keys=["log_probs", "logits"], - inplace=False, - ) + module_dict["get_logprobs"] = Mod( + model, + method_kwargs=kwargs, + in_keys={ + "input_ids": ("tokens_in", "input_ids"), + "attention_mask": ("tokens_in", "attention_mask"), + # "labels": ("tokens_in", "input_ids"), + }, + out_keys=["forward"], + out_to_in_map=True, + strict=True, + ) + module_dict["extract_log_probs"] = WrapModule( + log_probs_from_logits, + in_keys=[ + ("tokens_in", "input_ids"), + ("forward", "logits"), + ("tokens_response", "logits"), + ], + out_keys=["logits", "log_probs"], + ) + if device: + module_dict["to_source_device"] = _maybe_set_device + module_dict["rebuild"] = Mod( + lambda *x: x, + in_keys=["log_probs", "logits"], + out_keys=["log_probs", "logits"], + inplace=False, + ) - return Seq(module_dict, inplace=True) + return module_dict if __name__ == "__main__": @@ -412,7 +595,7 @@ def remove_input_seq(tokens_in, tokens_out): tokenizer = AutoTokenizer.from_pretrained("gpt2") tokenizer.pad_token = tokenizer.eos_token - model = GPT2LMHeadModel(GPT2Config()) + model = GPT2LMHeadModel(GPT2Config()).eval() tokenizer.padding_side = "left" diff --git a/torchrl/modules/llm/vllm_policy.py b/torchrl/modules/llm/vllm_policy.py index 06b3d099856..cac19a57dea 100644 --- a/torchrl/modules/llm/vllm_policy.py +++ b/torchrl/modules/llm/vllm_policy.py @@ -17,24 +17,21 @@ NonTensorData, NonTensorStack, TensorClass, + TensorDict, ) from tensordict.nn import ( TensorDictModule as Mod, TensorDictModuleBase, TensorDictSequential as Seq, + WrapModule, ) -from tensordict.utils import _zip_strict +from tensordict.utils import _zip_strict, expand_as_right from torchrl.data import LLMData _has_vllm = importlib.util.find_spec("vllm") -if _has_vllm: - import vllm - - CompletionOutput_tc = from_dataclass(vllm.outputs.CompletionOutput) -else: - CompletionOutput_tc = None +CompletionOutput_tc = None def _maybe_clear_device(td): @@ -63,6 +60,7 @@ def from_vllm( generate: bool = True, generate_kwargs: dict | None = None, tokenizer_kwargs: dict | None = None, + pad_output: bool = True, ) -> TensorDictModuleBase: """Creates a TensorDictModule from a vLLM model. @@ -119,6 +117,72 @@ def from_vllm( Transformers library. 
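+
+    ``pad_output`` controls whether the response tokens and log-probs are
+    padded into dense tensors (``True``, the default) or kept as ragged,
+    per-sample sequences (``False``).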
""" + # TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks + if tokenizer is None: + tokenizer = model.get_tokenizer() + + # retrieve the padding value - we use this to make the log-probs of pad token = 1 + padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0] + + if from_text: + if generate: + func = _from_vllm_generate_text + else: + func = _from_vllm_logprobs_text + else: + if generate: + func = _from_vllm_generate_tokens + else: + func = _from_vllm_logprobs_tokens + module_dict = func( + tokenizer=tokenizer, + model=model, + device=device, + padding_value=padding_value, + generate_kwargs=generate_kwargs, + tokenizer_kwargs=tokenizer_kwargs, + return_log_probs=return_log_probs, + pad_output=pad_output, + ) + return Seq(module_dict, inplace=True) + + +def to_list(tokens, attention_mask): + """Converts a tensor of integer in a masked list (of lists) of integers.""" + if isinstance(tokens, torch.Tensor): + # TODO: make this an ND NonTensorStack + parent = [] + queue = collections.deque() + if attention_mask is None: + attention_mask = torch.ones_like(tokens) + queue.append((tokens, attention_mask.bool(), parent)) + while queue: + token, amask, _parent = queue.popleft() + if token.ndim == 1: + _parent.extend(token[amask].tolist()) + else: + _parent.extend([[] for _ in range(token.shape[0])]) + queue.extend( + [ + (t, m, local_parent) + for t, m, local_parent in zip(token, amask, _parent) + ] + ) + tokens = parent + return NonTensorStack(*tokens) + + +def _from_vllm_generate_text( + *, + tokenizer, + model, + device, + padding_value, + generate_kwargs, + tokenizer_kwargs, + return_log_probs, + pad_output, +): try: from vllm import SamplingParams except ImportError: @@ -128,111 +192,160 @@ def from_vllm( token_key: NestedKey = ("tokens",) attention_mask_key: NestedKey = ("attention_mask",) - # TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks - if tokenizer is None: - tokenizer = model.get_tokenizer() module_dict = {} if device: module_dict["clear_device"] = _maybe_clear_device - if from_text: - if not tokenizer_kwargs: - tokenizer_kwargs = {} - if not tokenizer_kwargs.setdefault("return_attention_mask", True): - raise RuntimeError - if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt": - raise RuntimeError - if tokenizer_kwargs.setdefault("padding", True) not in (True,): - raise RuntimeError - if tokenizer_kwargs.setdefault("padding_side", "left") != "left": - raise RuntimeError + if not tokenizer_kwargs: + tokenizer_kwargs = {} + if not tokenizer_kwargs.setdefault("return_attention_mask", True): + raise RuntimeError + if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt": + raise RuntimeError + if tokenizer_kwargs.setdefault("padding", True) not in (True,): + raise RuntimeError + if tokenizer_kwargs.setdefault("padding_side", "left") != "left": + raise RuntimeError + + def tokenize(td): + out = TensorDict(batch_size=td.batch_size, device=td.device) + tokens_in = TensorDict.from_dict( + tokenizer(td.get(text_key), **tokenizer_kwargs) + ) + out.set("tokens_in", tokens_in) + return out - if generate: - module_dict["encode"] = Mod( - tokenizer, - in_keys=[text_key], - out_keys=["tokens_in"], - method_kwargs=tokenizer_kwargs, - strict=True, - inplace="empty", - ) - else: - module_dict["encode"] = Mod( - # TODO: make this work with many strings - # Tokenize both strings, and only the first - lambda x, y: ( - tokenizer([_x + _y for _x, _y in zip(x, y)], 
**tokenizer_kwargs), - tokenizer(x, **tokenizer_kwargs), - ), - in_keys=[text_key, "text_response"], - out_keys=["tokens_in", "tokens_response"], - strict=True, - inplace="empty", - ) + module_dict["encode"] = WrapModule( + tokenize, + in_keys=[text_key], + out_keys=["tokens_in"], + ) - def select(x, y): - return x.apply(lambda _x, _y: _x[..., _y.shape[-1] :], y) + module_dict["to_list"] = Mod( + to_list, + in_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")], + out_keys=[("tokens_in", "input_ids_list")], + strict=False, + ) - module_dict["stack_response"] = Mod( - # Remove the init from the total tokens to get only the response tokens - select, - in_keys=["tokens_in", "tokens_response"], - out_keys=["tokens_response"], - strict=True, - ) - elif not generate: + if generate_kwargs is None: + generate_kwargs = {} + generate_kwargs.setdefault("detokenize", False) + generate_kwargs.setdefault("prompt_logprobs", False) + generate_kwargs.setdefault("logprobs", return_log_probs) + sampling_params = SamplingParams(**generate_kwargs) - def stack_for_logprobs(tokens, tokens_response, attention_mask=None): - tokens = torch.cat([tokens, tokens_response], -1) - if attention_mask is not None: - attention_mask = torch.cat( - [attention_mask, attention_mask.new_ones(tokens_response.shape)], -1 - ) - return tokens, tokens_response, attention_mask - - module_dict["stack_response"] = Mod( - stack_for_logprobs, - in_keys=["tokens", "tokens_response", "attention_mask"], - out_keys=[ - ("tokens_in", "input_ids"), - ("tokens_response", "input_ids"), - ("tokens_in", "attention_mask"), - ], - strict=False, - inplace="empty", - ) - else: - module_dict["move_inputs"] = Mod( - lambda *x: x, - in_keys=["tokens", "attention_mask"], - out_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")], - # It's ok if there's no mask - strict=False, - inplace="empty", + module_dict["generate"] = Mod( + model, + method="generate", + method_kwargs={"sampling_params": sampling_params}, + in_keys={ + "prompt_token_ids": ("tokens_in", "input_ids_list"), + }, + out_keys=["tokens_out"], + out_to_in_map=True, + strict=True, + ) + + def get_output_tokens_and_log_probs(td, padding_value=padding_value): + td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"]) + if pad_output and td.ndim and not isinstance(td, LazyStackedTensorDict): + td = lazy_stack(list(td.unbind(0))) + # When not generate, we don't want to overwrite this + tokens_response_td = td["tokens_out"].outputs._tensordict.select( + "token_ids", "logprobs", strict=False ) + if pad_output: + tokens_response_td = tokens_response_td.densify( + layout=torch.strided + ).to_padded_tensor(padding=padding_value) + tokens_response_td.rename_key_("token_ids", "tokens_response") + if return_log_probs: + padded_values = tokens_response_td["tokens_response"] == padding_value + tokens_response_td.rename_key_("logprobs", "log_probs") + if padded_values.any(): + lps = tokens_response_td["log_probs"] + lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0) + tokens_response_td["log_probs"] = lps + td.update(tokens_response_td) + return td - def to_list(tokens, attention_mask): - """Converts a tensor of integer in a masked list (of lists) of integers.""" - if isinstance(tokens, torch.Tensor): - # TODO: make this an ND NonTensorStack - parent = [] - queue = collections.deque() - if attention_mask is None: - attention_mask = torch.ones_like(tokens) - queue.append((tokens, attention_mask.bool(), parent)) - while queue: - token, amask, _parent 
= queue.popleft() - if token.ndim == 1: - _parent.extend(token[amask].tolist()) - else: - _parent.extend([[] for _ in range(token.shape[0])]) - queue.extend( - [ - (t, m, local_parent) - for t, m, local_parent in zip(token, amask, _parent) - ] - ) - tokens = parent - return NonTensorStack(*tokens) + module_dict["get_output_tokens_and_log_probs"] = get_output_tokens_and_log_probs + + module_dict["decode"] = Mod( + tokenizer.batch_decode, + in_keys=["tokens_response"], + out_keys=["text_response"], + ) + + if device: + module_dict["to_source_device"] = _maybe_set_device + + in_keys = [ + "log_probs", + "tokens_response", + ("tokens_in", "input_ids"), + ("tokens_in", "attention_mask"), + "text_response", + ] + out_keys = [ + "log_probs", + "tokens_response", + token_key, + attention_mask_key, + "text_response", + ] + + def format_td(td): + td = td.select(*in_keys, strict=False) + td.rename_key_(("tokens_in", "input_ids"), token_key) + td.rename_key_(("tokens_in", "attention_mask"), attention_mask_key) + del td["tokens_in"] + return td + + module_dict["format"] = WrapModule( + format_td, + in_keys=in_keys, + out_keys=out_keys, + ) + + return module_dict + + +def _from_vllm_generate_tokens( + *, + tokenizer, + model, + device, + padding_value, + generate_kwargs, + tokenizer_kwargs, + return_log_probs, + pad_output, +): + try: + from vllm import SamplingParams + except ImportError: + raise ImportError("Please install `vllm` to use `from_vllm`.") + + token_key: NestedKey = ("tokens",) + attention_mask_key: NestedKey = ("attention_mask",) + + module_dict = {} + if device: + module_dict["clear_device"] = _maybe_clear_device + + def move_input(td): + result = TensorDict(batch_size=td.batch_size, device=td.device) + result["tokens_in"] = result.new_empty() + result["tokens_in", "input_ids"] = td.get("tokens") + result["tokens_in", "attention_mask"] = td.get("attention_mask") + return result + + module_dict["move_inputs"] = WrapModule( + move_input, + in_keys=["tokens", "attention_mask"], + out_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")], + ) module_dict["to_list"] = Mod( to_list, @@ -242,13 +355,10 @@ def to_list(tokens, attention_mask): ) if generate_kwargs is None: - generate_kwargs = { - "detokenize": False, - "prompt_logprobs": not generate, - "logprobs": return_log_probs, - } - if not generate: - generate_kwargs["max_tokens"] = 1 + generate_kwargs = {} + generate_kwargs.setdefault("detokenize", False) + generate_kwargs.setdefault("prompt_logprobs", False) + generate_kwargs.setdefault("logprobs", return_log_probs) sampling_params = SamplingParams(**generate_kwargs) module_dict["generate"] = Mod( @@ -263,80 +373,292 @@ def to_list(tokens, attention_mask): strict=True, ) - padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0] - def get_output_tokens_and_log_probs(td, padding_value=padding_value): td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"]) - if td.ndim and not isinstance(td, LazyStackedTensorDict): + if pad_output and td.ndim and not isinstance(td, LazyStackedTensorDict): td = lazy_stack(list(td.unbind(0))) - if generate: - # When not generate, we don't want to overwrite this - tokens_response_td = td["tokens_out"].outputs._tensordict.select( - "token_ids", "logprobs", strict=False - ) - tokens_response_td.rename_key_("token_ids", "tokens_response") - # td["tokens_response"] = outputs.token_ids - if return_log_probs: - tokens_response_td.rename_key_("logprobs", "log_probs") - # td["log_probs"] = outputs.logprobs.unsqueeze(-1) - 
td.update(tokens_response_td) - elif not generate: - td["prompt_logprobs"] = td["tokens_out"].prompt_logprobs.unsqueeze(-1) + # When not generate, we don't want to overwrite this + tokens_response_td = td["tokens_out"].outputs._tensordict.select( + "token_ids", "logprobs", strict=False + ) + if pad_output: + tokens_response_td = tokens_response_td.densify( + layout=torch.strided + ).to_padded_tensor(padding=padding_value) + tokens_response_td.rename_key_("token_ids", "tokens_response") + if return_log_probs: + padded_values = tokens_response_td["tokens_response"] == padding_value + tokens_response_td.rename_key_("logprobs", "log_probs") + if padded_values.any(): + lps = tokens_response_td["log_probs"] + lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0) + tokens_response_td["log_probs"] = lps + td.update(tokens_response_td) return td module_dict["get_output_tokens_and_log_probs"] = get_output_tokens_and_log_probs - if not generate: + if device: + module_dict["to_source_device"] = _maybe_set_device + + in_keys = [ + "log_probs", + "tokens_response", + ("tokens_in", "input_ids"), + ("tokens_in", "attention_mask"), + "text_response", + ] + out_keys = [ + "log_probs", + "tokens_response", + token_key, + attention_mask_key, + "text_response", + ] + + def format_td(td): + td = td.select(*in_keys, strict=False) + td.rename_key_(("tokens_in", "input_ids"), token_key) + td.rename_key_(("tokens_in", "attention_mask"), attention_mask_key) + del td["tokens_in"] + return td + + module_dict["format"] = WrapModule( + format_td, + in_keys=in_keys, + out_keys=out_keys, + ) + + return module_dict - def translate_lps(tokens_response, x): - # we disregard the tokens from the prompt to focus on those of the response - return x[..., -tokens_response.shape[-1] :, :] - module_dict["translate_lps"] = Mod( - translate_lps, - in_keys=[("tokens_response", "input_ids"), "prompt_logprobs"], - out_keys=["log_probs"], +def _from_vllm_logprobs_text( + *, + tokenizer, + model, + device, + padding_value, + generate_kwargs, + tokenizer_kwargs, + return_log_probs, + pad_output, +): + try: + from vllm import SamplingParams + except ImportError: + raise ImportError("Please install `vllm` to use `from_vllm`.") + + text_key: NestedKey = ("text",) + + module_dict = {} + if device: + module_dict["clear_device"] = _maybe_clear_device + if not tokenizer_kwargs: + tokenizer_kwargs = {} + if not tokenizer_kwargs.setdefault("return_attention_mask", True): + raise RuntimeError + if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt": + raise RuntimeError + if tokenizer_kwargs.setdefault("padding", True) not in (True,): + raise RuntimeError + if tokenizer_kwargs.setdefault("padding_side", "left") != "left": + raise RuntimeError + + def tokenize(td): + out = TensorDict(batch_size=td.batch_size, device=td.device) + text_prompt = td.get(text_key) + text_response = td.get("text_response") + tokens_in = tokenizer( + [_x + _y for _x, _y in zip(text_prompt, text_response)], **tokenizer_kwargs ) - elif from_text: - module_dict["decode"] = Mod( - tokenizer.batch_decode, - in_keys=["tokens_response"], - out_keys=["text_response"], + tokens_prompt = tokenizer(text_prompt, **tokenizer_kwargs) + tokens_in = TensorDict.from_dict(tokens_in) + out["tokens_in"] = tokens_in + tokens_response = tokens_in.apply( + lambda total_tokens, input_tokens: total_tokens[:, input_tokens.shape[1] :], + TensorDict.from_dict(tokens_prompt), ) + out["tokens_response"] = tokens_response + return out + + module_dict["encode"] = WrapModule( + # TODO: 
make this work with many strings + tokenize, + in_keys=[text_key, "text_response"], + out_keys=["tokens_in", "tokens_response"], + ) + + module_dict["to_list"] = Mod( + to_list, + in_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")], + out_keys=[("tokens_in", "input_ids_list")], + strict=False, + ) + + if generate_kwargs is None: + generate_kwargs = {} + generate_kwargs.setdefault("detokenize", False) + generate_kwargs.setdefault("prompt_logprobs", True) + generate_kwargs.setdefault("logprobs", return_log_probs) + generate_kwargs["max_tokens"] = 1 + sampling_params = SamplingParams(**generate_kwargs) + + module_dict["generate"] = Mod( + model, + method="generate", + method_kwargs={"sampling_params": sampling_params}, + in_keys={ + "prompt_token_ids": ("tokens_in", "input_ids_list"), + }, + out_keys=["tokens_out"], + out_to_in_map=True, + strict=True, + ) + + def get_output_tokens_and_log_probs(td, padding_value=padding_value): + td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"]) + if pad_output and td.ndim and not isinstance(td, LazyStackedTensorDict): + td = lazy_stack(list(td.unbind(0))) + td["prompt_logprobs"] = td["tokens_out"].prompt_logprobs + return td + + module_dict["get_output_tokens_and_log_probs"] = get_output_tokens_and_log_probs + + def translate_lps(tokens_response, x): + # we disregard the tokens from the prompt to focus on those of the response + padded = tokens_response == padding_value + lps = x[..., -tokens_response.shape[-1] :] + lps = torch.where(~padded, lps, 0.0) + return lps + + module_dict["translate_lps"] = Mod( + translate_lps, + in_keys=[("tokens_response", "input_ids"), "prompt_logprobs"], + out_keys=["log_probs"], + ) if device: module_dict["to_source_device"] = _maybe_set_device - if generate: - module_dict["format"] = Mod( - lambda *x: x, - in_keys=[ - "log_probs", - "tokens_response", - ("tokens_in", "input_ids"), - ("tokens_in", "attention_mask"), - "text_response", - ], - out_keys=[ - "log_probs", - "tokens_response", - token_key, - attention_mask_key, - "text_response", - ], - strict=False, - inplace="empty", - ) - else: - module_dict["format"] = Mod( - lambda *x: x, - in_keys=["log_probs", "tokens_response"], - out_keys=["log_probs", "tokens_response"], - strict=False, - inplace="empty", - ) + module_dict["format"] = Mod( + lambda *x: x, + in_keys=["log_probs", ("tokens_response", "input_ids")], + out_keys=["log_probs", "tokens_response"], + strict=False, + inplace="empty", + ) - return Seq(module_dict, inplace=True) + return module_dict + + +def _from_vllm_logprobs_tokens( + *, + tokenizer, + model, + device, + padding_value, + generate_kwargs, + tokenizer_kwargs, + return_log_probs, + pad_output, +): + try: + from vllm import SamplingParams + except ImportError: + raise ImportError("Please install `vllm` to use `from_vllm`.") + + module_dict = {} + if device: + module_dict["clear_device"] = _maybe_clear_device + + def stack_for_logprobs(td): + tokens = td.get("tokens") + tokens_response = td.get("tokens_response") + attention_mask = td.get("attention_mask") + + tokens = torch.cat([tokens, tokens_response], -1) + if attention_mask is not None: + attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones(tokens_response.shape)], -1 + ) + result = TensorDict(batch_size=td.batch_size, device=td.device) + result.set(("tokens_in", "input_ids"), tokens) + result.set(("tokens_response", "input_ids"), tokens_response) + if attention_mask is not None: + result.set(("tokens_in", "attention_mask"), 
attention_mask) + return result + + module_dict["stack_response"] = WrapModule( + stack_for_logprobs, + in_keys=["tokens", "tokens_response", "attention_mask"], + out_keys=[ + ("tokens_in", "input_ids"), + ("tokens_response", "input_ids"), + ("tokens_in", "attention_mask"), + ], + ) + + module_dict["to_list"] = Mod( + to_list, + in_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")], + out_keys=[("tokens_in", "input_ids_list")], + strict=False, + ) + + if generate_kwargs is None: + generate_kwargs = {} + generate_kwargs.setdefault("detokenize", False) + generate_kwargs.setdefault("prompt_logprobs", True) + generate_kwargs.setdefault("logprobs", return_log_probs) + generate_kwargs["max_tokens"] = 1 + sampling_params = SamplingParams(**generate_kwargs) + + module_dict["generate"] = Mod( + model, + method="generate", + method_kwargs={"sampling_params": sampling_params}, + in_keys={ + "prompt_token_ids": ("tokens_in", "input_ids_list"), + }, + out_keys=["tokens_out"], + out_to_in_map=True, + strict=True, + ) + + def get_output_tokens_and_log_probs(td, padding_value=padding_value): + td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"]) + if pad_output and td.ndim and not isinstance(td, LazyStackedTensorDict): + td = lazy_stack(list(td.unbind(0))) + td["prompt_logprobs"] = td["tokens_out"].prompt_logprobs + return td + + module_dict["get_output_tokens_and_log_probs"] = get_output_tokens_and_log_probs + + def translate_lps(tokens_response, lps): + # we disregard the tokens from the prompt to focus on those of the response + padded = tokens_response == padding_value + lps = lps[..., -tokens_response.shape[-1] :] + lps = torch.where(~padded, lps, 0.0) + return lps + + module_dict["translate_lps"] = Mod( + translate_lps, + in_keys=[("tokens_response", "input_ids"), "prompt_logprobs"], + out_keys=["log_probs"], + ) + + if device: + module_dict["to_source_device"] = _maybe_set_device + + module_dict["format"] = Mod( + lambda *x: x, + in_keys=["log_probs", "tokens_response"], + out_keys=["log_probs", "tokens_response"], + strict=False, + inplace="empty", + ) + + return module_dict class _RequestOutput_tc(TensorClass["nocast"]): @@ -353,6 +675,12 @@ class _RequestOutput_tc(TensorClass["nocast"]): num_cached_tokens: str def __post_init__(self): + global CompletionOutput_tc + if CompletionOutput_tc is None: + import vllm + + CompletionOutput_tc = from_dataclass(vllm.outputs.CompletionOutput) + def postproc(output): def get_logprob(output): t = [] diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 22191e049da..748b6ac3c2b 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -24,7 +24,6 @@ ProbabilisticTensorDictSequential, set_composite_lp_aggregate, TensorDictModule, - TensorDictModuleBase, ) from tensordict.utils import NestedKey from torch import distributions as d @@ -350,15 +349,18 @@ def __init__( if critic is not None: critic_network = critic del critic + + if critic_coef is None and critic_network is not None: + critic_coef = 1.0 + elif critic_coef in (None, 0) and critic_network is not None: + critic_coef = None + if actor_network is None or ( critic_network is None and critic_coef not in (None, 0.0) ): raise TypeError( "Missing positional arguments actor_network or critic_network." 
            )
 
-        critic_coef = (
-            1.0 if critic_coef is None and critic_network is not None else critic_coef
-        )
 
         if reduction is None:
             reduction = "mean"
@@ -555,15 +557,20 @@ def _get_cur_log_prob(self, tensordict):
                     f"tensordict stored {self.tensor_keys.action} requires grad."
                 )
             log_prob = dist.log_prob(action)
         else:
-            with self.actor_network_params.to_module(
-                self.actor_network
-            ) if self.functional else contextlib.nullcontext():
-                td = self.actor_network(tensordict)
-                log_prob = td.get(self.tensor_keys.sample_log_prob)
-                # TODO: decustomize this
-                dist = torch.distributions.Categorical(td.get("logits"))
-                is_composite = False
+            raise NotImplementedError(
+                "Only probabilistic modules from tensordict.nn are currently supported. "
+                "If you need to implement a custom logic to retrieve the log-probs (to compute "
+                "the PPO objective) or the distribution (for the PPO entropy), please augment "
+                f"the {type(self).__name__} class by implementing your own logic in _get_cur_log_prob."
+            )
+            # with self.actor_network_params.to_module(
+            #     self.actor_network
+            # ) if self.functional else contextlib.nullcontext():
+            #     td = self.actor_network(tensordict)
+            #     log_prob = td.get(self.tensor_keys.sample_log_prob)
+            #     dist = torch.distributions.Categorical(td.get("logits"))
+            #     is_composite = False
         return log_prob, dist, is_composite
 
     def _log_weight(
@@ -913,7 +920,7 @@ def __init__(
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coef: float = 0.01,
-        critic_coef: float = 1.0,
+        critic_coef: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
         normalize_advantage_exclude_dims: tuple[int] = (),
@@ -980,6 +987,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         tensordict = tensordict.clone(False)
         advantage = tensordict.get(self.tensor_keys.advantage, None)
         if advantage is None:
+            if self.critic_network is None:
+                raise RuntimeError(
+                    "Critic network is not specified, cannot compute advantage within forward."
+                )
             self.value_estimator(
                 tensordict,
                 params=self._cached_critic_network_params_detached,
@@ -1184,7 +1195,7 @@ def __init__(
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coef: float = 0.01,
-        critic_coef: float = 1.0,
+        critic_coef: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
         normalize_advantage_exclude_dims: tuple[int] = (),
@@ -1240,24 +1251,21 @@ def _set_in_keys(self):
         _maybe_add_or_extend_key(keys, self.tensor_keys.terminated, "next")
 
         # Get the parameter keys from the actor dist
-        # actor_dist_module = None
-        # for module in self.actor_network.modules():
-        #     # Ideally we should combine them if there is more than one
-        #     if isinstance(module, ProbabilisticTensorDictModule):
-        #         if actor_dist_module is not None:
-        #             raise RuntimeError(
-        #                 "Actors with one and only one distribution are currently supported "
-        #                 f"in {type(self).__name__}. If you need to use more than one "
-        #                 f"distributions over the action space please submit an issue "
-        #                 f"on github."
- # ) - # actor_dist_module = module - # if actor_dist_module is None: - # if hasattr(self.actor_network, "in_keys"): - # actor_dist_module = self.actor_network - # else: - # raise RuntimeError("Could not find the probabilistic module in the actor.") - keys += list(self.actor_network.in_keys) + actor_dist_module = None + for module in self.actor_network.modules(): + # Ideally we should combine them if there is more than one + if isinstance(module, ProbabilisticTensorDictModule): + if actor_dist_module is not None: + raise RuntimeError( + "Actors with one and only one distribution are currently supported " + f"in {type(self).__name__}. If you need to use more than one " + f"distributions over the action space please submit an issue " + f"on github." + ) + actor_dist_module = module + if actor_dist_module is None: + raise RuntimeError("Could not find the probabilistic module in the actor.") + keys += list(actor_dist_module.in_keys) self._in_keys = list(set(keys)) @property @@ -1382,20 +1393,3 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: def reset(self) -> None: self.beta = self._beta_init - - -class GRPO(ClipPPOLoss): - """TODO""" - def __init__( - self, - actor_network: TensorDictModuleBase, - # Default value of LLMData - log_prob_key="log_probs", - ): - super().__init__( - actor_network=actor_network, - critic_network=None, - critic_coef=0.0, - functional=False, - ) - self.set_keys(log_prob_key=log_prob_key) From e22e7f027c954763a1e41a57c79a32a3c08d5403 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 18 Mar 2025 11:36:07 -0700 Subject: [PATCH 4/6] Update [ghstack-poisoned] --- sota-implementations/post-training/grpo.py | 8 ++++---- torchrl/envs/transforms/llm.py | 2 ++ torchrl/envs/transforms/transforms.py | 2 ++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/sota-implementations/post-training/grpo.py b/sota-implementations/post-training/grpo.py index 0dba78dd5c7..d313f54357d 100644 --- a/sota-implementations/post-training/grpo.py +++ b/sota-implementations/post-training/grpo.py @@ -97,7 +97,7 @@ def collate_fn(batch): env.append_transform(ShapedCorrectnessReward(tokenizer=tokenizer)) # Ref model - ref_model = GPT2LMHeadModel.from_pretrained("gpt2") + ref_model = GPT2LMHeadModel.from_pretrained("gpt2").eval() TensorDict.from_module(ref_model).data.to_module(ref_model) ref_model = from_hf_transformers( ref_model, @@ -118,7 +118,7 @@ def collate_fn(batch): ) # Collector - train_model = GPT2LMHeadModel.from_pretrained("gpt2") + train_model = GPT2LMHeadModel.from_pretrained("gpt2").eval() collector = SyncDataCollector( env, policy, @@ -130,7 +130,7 @@ def collate_fn(batch): ) # Loss module - policy_traning = from_hf_transformers( + policy_training = from_hf_transformers( train_model, tokenizer=tokenizer, from_text=False, @@ -138,7 +138,7 @@ def collate_fn(batch): return_log_probs=True, ) loss_fn = ClipPPOLoss( - actor_network=policy_traning, + actor_network=policy_training, critic_network=None, critic_coef=0.0, functional=False, diff --git a/torchrl/envs/transforms/llm.py b/torchrl/envs/transforms/llm.py index 28d49d08013..c757052551c 100644 --- a/torchrl/envs/transforms/llm.py +++ b/torchrl/envs/transforms/llm.py @@ -371,6 +371,8 @@ def __init__( batch_size = getattr(dataloader, "batch_size", 0) if (batch_size > 1 and use_buffer is None) or repeats > 0: use_buffer = True + if repeats: + batch_size = batch_size * repeats self.use_buffer = use_buffer if self.use_buffer: diff --git a/torchrl/envs/transforms/transforms.py 
b/torchrl/envs/transforms/transforms.py
index a58c483d2b8..ecbdd2485c0 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -6432,6 +6432,6 @@ def _reset_func(
         resets = self.default_value(reset=_reset)
         tensordict_reset.update(resets)
 
         for key, spec in self.primers.items(True, True):
             if not self._validated:
                 self._validate_value_tensor(tensordict_reset.get(key), spec)

From ff208bb9236cd79c1dc9f231572846dad8077129 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Tue, 18 Mar 2025 12:04:32 -0700
Subject: [PATCH 5/6] Update

[ghstack-poisoned]
---
 sota-implementations/post-training/grpo.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sota-implementations/post-training/grpo.py b/sota-implementations/post-training/grpo.py
index d313f54357d..3ffdc101344 100644
--- a/sota-implementations/post-training/grpo.py
+++ b/sota-implementations/post-training/grpo.py
@@ -15,7 +15,7 @@
 from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
 from torchrl.collectors import SyncDataCollector
-from torchrl.data import LazyStackStorage, ReplayBuffer, SamplerWithoutReplacement
+from torchrl.data import LazyStackStorage, RayReplayBuffer, ReplayBuffer, SamplerWithoutReplacement
 from torchrl.envs import DataLoadingPrimer, KLRewardTransform, LLMEnv, StepCounter
 from torchrl.modules import from_hf_transformers, from_vllm
 from torchrl.objectives import ClipPPOLoss

From e019e255774076769de757f3a63c226e364a72a5 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Wed, 19 Mar 2025 17:21:58 -0700
Subject: [PATCH 6/6] Update

[ghstack-poisoned]
---
 torchrl/collectors/collectors.py              |  38 +-
 torchrl/collectors/weight_update.py           |  59 ++-
 torchrl/data/utils.py                         |  10 +-
 torchrl/envs/common.py                        |   5 +-
 torchrl/envs/custom/llm.py                    |  37 +-
 torchrl/envs/transforms/llm.py                | 101 ++++-
 torchrl/envs/transforms/transforms.py         |   4 +-
 torchrl/envs/utils.py                         |   6 +-
 torchrl/modules/llm/common.py                 |   6 +-
 torchrl/modules/llm/vllm_policy.py            | 344 ++++++++++++------
 torchrl/modules/tensordict_module/common.py   |   6 +-
 torchrl/modules/tensordict_module/sequence.py |   5 +
 torchrl/objectives/ppo.py                     |   1 -
 13 files changed, 452 insertions(+), 170 deletions(-)

diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 3ff35dbb560..93dd39bea3c 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -152,8 +152,28 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
     trust_policy: bool
     compiled_policy: bool
     cudagraphed_policy: bool
-    local_weights_updater: LocalWeightUpdaterBase | None = None
-    remote_weights_updater: RemoteWeightUpdaterBase | None = None
+    _local_weights_updater: LocalWeightUpdaterBase | None = None
+    _remote_weights_updater: RemoteWeightUpdaterBase | None = None
+
+    @property
+    def local_weights_updater(self) -> LocalWeightUpdaterBase:
+        return self._local_weights_updater
+
+    @local_weights_updater.setter
+    def local_weights_updater(self, value: LocalWeightUpdaterBase | None):
+        if value is not None:
+            value.register_collector(self)
+        self._local_weights_updater = value
+
+    @property
+    def remote_weights_updater(self) -> RemoteWeightUpdaterBase:
+        return self._remote_weights_updater
+
+    @remote_weights_updater.setter
+    def remote_weights_updater(self, value: RemoteWeightUpdaterBase | None):
+        if value is not None:
+            value.register_collector(self)
+        self._remote_weights_updater = value
 
     def _get_policy_and_device(
         self,
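The updater hooks above follow a simple contract: assigning an updater to a collector calls `register_collector` on it, and `update_policy_weights_()` then routes weight synchronization through the updater. A minimal sketch of a custom remote updater, assuming that implementing the four abstract hooks is sufficient (the toy "sync" below just reports what it received):

    from tensordict import TensorDict, TensorDictBase
    from torchrl.collectors.weight_update import RemoteWeightUpdaterBase

    class PrintingWeightUpdater(RemoteWeightUpdaterBase):
        """Toy updater: pulls weights from the registered collector's policy
        and 'syncs' a single worker by printing what it got."""

        def all_worker_ids(self) -> list[int]:
            return [0]

        def _get_server_weights(self) -> TensorDictBase:
            # the registered collector is reachable through the weakref-backed
            # `collector` property
            return TensorDict.from_module(self.collector.policy).data

        def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
            return server_weights

        def _sync_weights_with_worker(self, worker_id, server_weights) -> None:
            print(f"worker {worker_id}: got {len(server_weights.keys(True, True))} leaves")

    # collector.remote_weights_updater = PrintingWeightUpdater()
    # collector.update_policy_weights_()  # routed through the updater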
@@ -1515,7 +1535,7 @@ def __repr__(self) -> str: f"\nexploration={self.exploration_type})" ) return string - except AttributeError: + except Exception: return f"{type(self).__name__}(not_init)" @@ -1831,6 +1851,7 @@ def __init__( self.local_weights_updater = local_weights_updater self.policy = policy + self.policy_factory = policy_factory remainder = 0 if total_frames is None or total_frames < 0: @@ -2012,6 +2033,10 @@ def _run_processes(self) -> None: env_fun = CloudpickleWrapper(env_fun) # Create a policy on the right device + policy_factory = self.policy_factory + if policy_factory is not None: + policy_factory = CloudpickleWrapper(policy_factory) + policy_device = self.policy_device[i] storing_device = self.storing_device[i] env_device = self.env_device[i] @@ -2020,13 +2045,14 @@ def _run_processes(self) -> None: # This makes sure that a given set of shared weights for a given device are # shared for all policies that rely on that device. policy = self.policy - policy_weights = self._policy_weights_dict[policy_device] + policy_weights = self._policy_weights_dict.get(policy_device) if policy is not None and policy_weights is not None: cm = policy_weights.to_module(policy) else: cm = contextlib.nullcontext() with cm: kwargs = { + "policy_factory": policy_factory, "pipe_parent": pipe_parent, "pipe_child": pipe_child, "queue_out": queue_out, @@ -3107,6 +3133,7 @@ def _main_async_collector( compile_policy: bool = False, cudagraph_policy: bool = False, no_cuda_sync: bool = False, + policy_factory: Callable | None = None, ) -> None: pipe_parent.close() # init variables that will be cleared when closing @@ -3116,6 +3143,7 @@ def _main_async_collector( create_env_fn, create_env_kwargs=create_env_kwargs, policy=policy, + policy_factory=policy_factory, total_frames=-1, max_frames_per_traj=max_frames_per_traj, frames_per_batch=frames_per_batch, @@ -3278,7 +3306,7 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR): continue elif msg == "update": - inner_collector.update_policy_weights_() + inner_collector.update_policy_weights_(policy_weights=data_in) pipe_child.send((j, "updated")) has_timed_out = False continue diff --git a/torchrl/collectors/weight_update.py b/torchrl/collectors/weight_update.py index 9911c3228af..46732057ad6 100644 --- a/torchrl/collectors/weight_update.py +++ b/torchrl/collectors/weight_update.py @@ -2,9 +2,12 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import abc +import weakref from abc import abstractmethod -from typing import Callable, Dict, List, TypeVar +from typing import Any, Callable, TypeVar import torch from tensordict import TensorDictBase @@ -42,6 +45,25 @@ class LocalWeightUpdaterBase(metaclass=abc.ABCMeta): """ + _collector_wr: Any = None + + def register_collector(self, collector: DataCollectorBase): # noqa + """Register a collector in the updater. + + Once registered, the updater will not accept another collector. + + Args: + collector (DataCollectorBase): The collector to register. + + """ + if self._collector_wr is not None: + raise RuntimeError("Cannot register collector twice.") + self._collector_wr = weakref.ref(collector) + + @property + def collector(self) -> torchrl.collectors.DataCollectorBase: # noqa + return self._collector_wr() if self._collector_wr is not None else None + @abstractmethod def _get_server_weights(self) -> TensorDictBase: ... 
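The registration hook introduced above stores only a weak reference to the owning collector, so the collector can own its updater without creating a reference cycle. Below is a minimal, standalone sketch of that pattern in plain Python (illustrative class names, not the TorchRL classes themselves):

.. code-block:: python

    import weakref


    class Updater:
        _collector_wr = None

        def register_collector(self, collector):
            if self._collector_wr is not None:
                raise RuntimeError("Cannot register collector twice.")
            # A weak reference: the updater borrows a handle back to the
            # collector without keeping it alive.
            self._collector_wr = weakref.ref(collector)

        @property
        def collector(self):
            return self._collector_wr() if self._collector_wr is not None else None


    class Collector:
        def __init__(self, updater):
            # Mirrors the property setters in the patch: assigning an
            # updater registers the collector with it.
            updater.register_collector(self)
            self.updater = updater


    c = Collector(Updater())
    assert c.updater.collector is c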
@@ -102,12 +124,33 @@ class RemoteWeightUpdaterBase(metaclass=abc.ABCMeta):
 
     Methods:
         update_weights: Updates the weights on specified or all remote workers.
+        register_collector: Registers a collector. This should be called automatically by the collector
+            upon registration of the updater.
 
     .. seealso:: :class:`~torchrl.collectors.LocalWeightsUpdaterBase` and
         :meth:`~torchrl.collectors.DataCollectorBase.update_policy_weights_`.
 
     """
 
+    _collector_wr: Any = None
+
+    def register_collector(self, collector: DataCollectorBase):  # noqa
+        """Register a collector in the updater.
+
+        Once registered, the updater will not accept another collector.
+
+        Args:
+            collector (DataCollectorBase): The collector to register.
+
+        """
+        if self._collector_wr is not None:
+            raise RuntimeError("Cannot register collector twice.")
+        self._collector_wr = weakref.ref(collector)
+
+    @property
+    def collector(self) -> torchrl.collectors.DataCollectorBase:  # noqa
+        return self._collector_wr() if self._collector_wr is not None else None
+
     @abstractmethod
     def _sync_weights_with_worker(
         self, worker_id: int | torch.device, server_weights: TensorDictBase
@@ -123,7 +166,7 @@ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
         ...
 
     @abstractmethod
-    def all_worker_ids(self) -> list[int] | List[torch.device]:
+    def all_worker_ids(self) -> list[int] | list[torch.device]:
         ...
 
     def _skip_update(self, worker_id: int | torch.device) -> bool:
@@ -132,14 +175,14 @@ def _skip_update(self, worker_id: int | torch.device) -> bool:
     def __call__(
         self,
         weights: TensorDictBase | None = None,
-        worker_ids: torch.device | int | List[int] | List[torch.device] | None = None,
+        worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
     ):
         return self.update_weights(weights=weights, worker_ids=worker_ids)
 
     def update_weights(
         self,
         weights: TensorDictBase | None = None,
-        worker_ids: torch.device | int | List[int] | List[torch.device] | None = None,
+        worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
     ):
         if weights is None:
             # Get the weights on server (local)
@@ -257,12 +300,12 @@ class MultiProcessedRemoteWeightUpdate(RemoteWeightUpdaterBase):
     def __init__(
         self,
         get_server_weights: Callable[[], TensorDictBase] | None,
-        policy_weights: Dict[torch.device, TensorDictBase],
+        policy_weights: dict[torch.device, TensorDictBase],
     ):
         self.weights_getter = get_server_weights
         self._policy_weights = policy_weights
 
-    def all_worker_ids(self) -> list[int] | List[torch.device]:
+    def all_worker_ids(self) -> list[int] | list[torch.device]:
         return list(self._policy_weights)
 
     def _sync_weights_with_worker(
@@ -321,7 +364,7 @@ class RayRemoteWeightUpdater(RemoteWeightUpdaterBase):
     def __init__(
         self,
         policy_weights: TensorDictBase,
-        remote_collectors: List,
+        remote_collectors: list,
         max_interval: int = 0,
     ):
         self.policy_weights = policy_weights
@@ -329,7 +372,7 @@ def __init__(
         self.max_interval = max(0, max_interval)
         self._batches_since_weight_update = [0] * len(self.remote_collectors)
 
-    def all_worker_ids(self) -> list[int] | List[torch.device]:
+    def all_worker_ids(self) -> list[int] | list[torch.device]:
         return list(range(len(self.remote_collectors)))
 
     def _get_server_weights(self) -> TensorDictBase:

diff --git a/torchrl/data/utils.py b/torchrl/data/utils.py
index 1d3777eb48d..30521d5d00b 100644
--- a/torchrl/data/utils.py
+++ b/torchrl/data/utils.py
@@ -222,7 +222,15 @@ def contains_lazy_spec(spec: TensorSpec) -> bool:
     return False
 
 
-class CloudpickleWrapper:
+class 
_CloudpickleWrapperMeta(type): + def __call__(cls, obj): + if isinstance(obj, cls): + return obj + else: + return super().__call__(obj) + + +class CloudpickleWrapper(metaclass=_CloudpickleWrapperMeta): """A wrapper for functions that allow for serialization in multiprocessed settings.""" def __init__(self, fn: Callable, **kwargs): diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index cdfbe5c19e3..4398cb42ee5 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -174,7 +174,10 @@ def metadata_from_env(env) -> EnvMetaData: specs = env.specs.to("cpu") batch_size = env.batch_size - env_str = str(env) + try: + env_str = str(env) + except Exception: + env_str = f"{env.__class__.__name__}()" device = env.device specs = specs.to("cpu") batch_locked = env.batch_locked diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py index 63f9db8e2d6..2745b9b0851 100644 --- a/torchrl/envs/custom/llm.py +++ b/torchrl/envs/custom/llm.py @@ -34,18 +34,16 @@ class LLMEnv(EnvBase): - """A text generation environment. + """A text generation environment for language models. This environment is designed to work with language models, where the observation is a string or a tensor of - integers representing a sequence of tokens. - The action is also a string or a tensor of integers, which is concatenated to the previous observation to form the - new observation. + integers representing a sequence of tokens. The action is also a string or a tensor of integers, which is + concatenated to the previous observation to form the new observation. By default, this environment is meant to track history for a prompt. Users can append transforms to tailor this to their use case, such as Chain of Thought (CoT) reasoning or other custom processing. Users must append a transform to set the "done" condition, which would trigger the loading of the next prompt. - Prompts to the language model can be loaded when the environment is ``reset`` if the environment is created via :meth:`~from_dataloader`. @@ -57,7 +55,7 @@ class LLMEnv(EnvBase): attention_key (NestedKey, optional): The key in the tensordict where the attention mask is stored. Defaults to ``"attention_mask"``. action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to - ``tokens_response`` or ``"text_response"``. + ``"tokens_response"`` or ``"text_response"``. reward_key (NestedKey, optional): The key in the tensordict where the reward is stored if `assign_reward=True`. Defaults to ``"reward"``. str2str (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``False``. @@ -66,22 +64,21 @@ class LLMEnv(EnvBase): unbounded vocabulary. Defaults to ``None``. no_stack (bool, optional): If ``False`` (default), the environment should stack the action with the past observation, each action being a new, unseen part of a conversation. Otherwise, the action is assumed - to be the plain output of the LLM, including the input tokens / strings. - has_attention (bool, optional): if ``True``, an attention mask is to be used under the key indicated by + to be the plain output of the LLM, including the input tokens/strings. + has_attention (bool, optional): If ``True``, an attention mask is to be used under the key indicated by :attr:`attention_key`. Defaults to ``True``. 
- assign_reward (bool, optional): if ``True``, a zero-valued reward of shape equal to to the action shape + assign_reward (bool, optional): If ``True``, a zero-valued reward of shape equal to the action shape is written during calls to `step()`. Defaults to ``False``. - assign_done (bool, optional): if ``True``, a zero-valued done and terminated state of shape equal to to the + assign_done (bool, optional): If ``True``, a zero-valued done and terminated state of shape equal to the action shape is written during calls to `step()`. Defaults to ``False``. - - .. note:: regardless of the value assigned to `assign_done`, a done state will be written at the root + .. note:: Regardless of the value assigned to `assign_done`, a done state will be written at the root as it is a requirement for all TorchRL environments. - batch_size (int or torch.Size, optional): Batch size of the environment. If left empty, the environment is batchless (or batch-unlocked), meaning that it can accept tensordicts of any batch size. Defaults to ``None`` (batch-unlocked). - as_llm_data (bool, optional): If ``True``, the data will be of type :class:`~torchrl.data.LLMData`. - Defaults to ``False``. + + .. note:: When using a :class:`~torchrl.envs.DataLoadingPrimer` transform, the batch-size of the env + and the transform should match. .. seealso:: :class:`~torchrl.envs.DataLoadingPrimer` for examples. @@ -112,6 +109,7 @@ def __init__( assign_done: bool = False, batch_size: int | torch.Size | None = None, has_attention: bool = True, + # Experimental as_llm_data: bool = False, ) -> None: self.as_llm_data = as_llm_data @@ -255,6 +253,7 @@ def from_dataloader( device: torch.device | None = None, vocab_size: int | None = None, no_stack: bool = False, + # Experimental as_llm_data: bool = False, batch_size: int | torch.Size | None = None, has_attention: bool = True, @@ -316,6 +315,10 @@ def from_dataloader( batch_size (int or torch.Size, optional): Batch size of the environment. If left empty, the environment is batchless (or batch-unlocked), meaning that it can accept tensordicts of any batch size. Defaults to ``None`` (batch-unlocked). + + .. note:: When using a :class:`~torchrl.envs.DataLoadingPrimer` transform, the batch-size of the env + and the transform should match. + primers (Composite | None, optional): The primers to use for each key in the dataloader. Defaults to ``None``. data_keys (list[NestedKey] | None, optional): The keys to use for each item in the dataloader. If not passed ``observation_key`` will be populated with the data. @@ -328,8 +331,6 @@ def from_dataloader( repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo samples (rather than an advantage module). - as_llm_data (bool, optional): If ``True``, the data will be of type :class:`~torchrl.data.LLMData`. - Defaults to ``False``. Returns: LLMEnv: The created LLMEnv instance. 
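To make the documented interplay between `batch_size` and `repeats` concrete, here is a hedged usage sketch of `LLMEnv.from_dataloader`. The toy dataset, its "text" field and the surrounding setup are illustrative assumptions, not part of this patch; only keyword arguments listed in the docstring above are used:

.. code-block:: python

    from torch.utils.data import DataLoader
    from torchrl.envs import LLMEnv

    # Hypothetical dataset: each item carries a single text prompt.
    dataset = [{"text": f"What is {i} + {i}?"} for i in range(8)]
    dataloader = DataLoader(dataset, batch_size=4)

    # Serve each prompt several times in a row (GRPO-style Monte-Carlo
    # sampling); batch_size here is the env's batch size, not the DL's.
    env = LLMEnv.from_dataloader(
        dataloader,
        batch_size=4,
        repeats=4,
    )
    td = env.reset()  # pulls the next prompts from the dataloader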
@@ -410,7 +411,7 @@ def from_dataloader(
             no_stack=no_stack,
             assign_reward=assign_reward,
             assign_done=assign_done,
-            batch_size=batch_size,
+            batch_size=batch_size if batch_size is not None else primer.batch_size,
             has_attention=has_attention,
             as_llm_data=as_llm_data,
         )

diff --git a/torchrl/envs/transforms/llm.py b/torchrl/envs/transforms/llm.py
index c757052551c..7f572836704 100644
--- a/torchrl/envs/transforms/llm.py
+++ b/torchrl/envs/transforms/llm.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
+import warnings
 from collections import deque
 from collections.abc import Mapping
 from copy import copy, deepcopy
@@ -20,6 +21,7 @@
 from torch import nn
 
 from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded
+from torchrl.envs.common import EnvBase
 from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform
 from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param
 from torchrl.envs.utils import make_composite_from_td
@@ -87,6 +89,15 @@ class DataLoadingPrimer(TensorDictPrimer):
 
     Args:
         dataloader (Iterable[Any]): The dataloader to load data from.
+            During collection, we will attempt to convert it into a tensordict using
+            :func:`~tensordict.from_dict` or a similar function.
+            If the dataloader has a `batch_size` attribute, it is assumed that the output will have a batch-size (i.e.,
+            that `TensorDict` can figure out how many samples are present through `auto_batch_size=True`). If that is
+            the case, the data collected from the dataloader will be put in a queue and delivered progressively such that
+            the number of samples matches the `batch_size` argument of the Primer (see :attr:`batch_size` argument
+            below).
+            If the dataloader does not have a batch_size attribute (or `dataloader.batch_size=0`), we assume that each
+            sample is a single item.
 
     Keyword Args:
         primers (Composite | None, optional): The primers to use for each key in the dataloader. Defaults to None.
@@ -99,12 +110,18 @@ class DataLoadingPrimer(TensorDictPrimer):
             ensures that `next()` is called on the dataloader only when necessary, and that elements of the dataset
             are loaded in order.
             Defaults to ``True`` whenever the batch-size of the dataloader is greater than 1.
-        auto_batch_size (bool, optional): If ``True`` (default if `dataloader.batch_size > 0`), the batch size of the
-            tensordict returned by the transform will be automatically determined assuming that there is a single batch
-            dimension.
         repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in
             situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo
            samples (rather than an advantage module).
+        batch_size (int, torch.Size or None): the batch-size of the data delivered by the transform.
+            This is somewhat unrelated to the batch-size of the dataloader, in the sense that this number may or may
+            not match the DL's batch size.
+            If left empty or 0, the transform will output as many samples as the input tensordict asks for (e.g.,
+            passing a `TensorDict(batch_size=(3,))` to the :meth:`~.reset` method will give 3 samples out of the
+            dataloader).
+
+            .. note:: The batch-size of the Primer must match the batch-size of the parent environment (typically a
+                wrapper around :class:`~torchrl.envs.LLMEnv`).
 
     Attributes:
         dataloader (Iterable[Any]): The dataloader to load data from.
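The delivery mechanism described in this docstring (an endless iterator feeding a queue that is drained to the requested batch-size) can be summarized in isolation. A simplified, standalone sketch, not the class code itself:

.. code-block:: python

    from collections import deque


    def endless(dataloader):
        # Counterpart of the `_endless_iter` helper: restart the dataloader
        # whenever it is exhausted, so batches can always be completed.
        while True:
            yield from dataloader


    def serve(it, queue, requested):
        # Unbind incoming dataloader batches into the queue until we can
        # deliver exactly `requested` samples, whatever the DL's batch size.
        while len(queue) < requested:
            queue.extend(next(it))
        return [queue.popleft() for _ in range(requested)]


    dl = [["a", "b", "c"]]  # a stand-in "dataloader" with one batch of 3
    it, q = endless(dl), deque()
    print(serve(it, q, 4))  # ['a', 'b', 'c', 'a']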
@@ -360,26 +377,70 @@ def __init__(
         stack_method: Callable[[Any], Any]
         | Literal["as_nested_tensor", "as_padded_tensor"] = None,
         use_buffer: bool | None = None,
-        auto_batch_size: bool = True,
+        batch_size: int | torch.Size | None = None,
         repeats: int | None = None,
         device: torch.device | None = None,
+        group_repeats: bool = False,
     ):
         self.dataloader = dataloader
         if repeats is None:
             repeats = 0
         self.repeats = repeats
-        batch_size = getattr(dataloader, "batch_size", 0)
-        if (batch_size > 1 and use_buffer is None) or repeats > 0:
+
+        # Determine batch-size
+        # We must distinguish the batch-size of the DL and the batch-size of the transform.
+        # We may want more or fewer elements than the DL provides, and the logic is slightly different, so we
+        # allow batches to be recomposed on the fly. If the DL has a batch-size, every element will be
+        # unbound and stored in a queue. Otherwise, we get as many elements from the DL as needed to fulfill
+        # the required batch-size.
+        #
+        # If the batch-size is passed, we will stack as many elements as necessary to fulfill this.
+        # If not, we try to get it from the dataloader. Contrary to the dataloader, we will always
+        # deliver the same batch-size (we create an infinite dataloader and reset when it's done),
+        # whereas DLs with drop_last=False may return batches of different sizes.
+        #
+        # If the batch-size passed to the transform is empty (torch.Size(())) or 0, we will consider that
+        # the batch-size is determined on-the-fly.
+        #
+        # A batch-size of 0 in the dataloader means no batch-size.
+        #
+        # If needed, the various repeats can be grouped in a single batch through group_repeats.
+        #
+        # If auto_batch_size is on, we pass auto_batch_size=True to TensorDict.from_dict:
+        # that way we get a tensordict of the right batch-size.
+        # If the dataloader has no batch-size, we're not sure that we can determine the batch-size
+        # automatically, so we will consider that each element in the DL has a batch-size of 0 (i.e.,
+        # a single non-batched element is returned at a time).
+
+        if batch_size is None:
+            batch_size = getattr(dataloader, "batch_size", 0)
+            auto_batch_size = batch_size > 0
+        else:
+            auto_batch_size = hasattr(dataloader, "batch_size")
+
+        if not isinstance(batch_size, int):
+            if not isinstance(batch_size, (list, tuple)) or len(batch_size) > 1:
+                raise ValueError(
+                    "batch_size must be an int, or a list / tuple of length <=1."
+                )
+            if batch_size:
+                batch_size = batch_size[0]
+            else:
+                batch_size = 0
+
+        if (batch_size >= 1 and use_buffer is None) or repeats:
             use_buffer = True
-        if repeats:
+
+        # We deliver all the repeats in the same batch
+        if repeats and group_repeats:
             batch_size = batch_size * repeats
 
         self.use_buffer = use_buffer
         if self.use_buffer:
             self._queue = deque()
 
-        # No auto_batch_size if we know we have a single element
-        self.auto_batch_size = auto_batch_size and (batch_size > 0)
+        self.auto_batch_size = auto_batch_size
+        self.batch_size = torch.Size((batch_size,)) if batch_size > 0 else None
         self.endless_dataloader = self._endless_iter(self.dataloader)
 
         if stack_method is None:
@@ -435,10 +496,6 @@ def _endless_iter(self, obj):
         while True:
             yield from obj
 
-    # def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
-    #     td = super()._reset_env_preprocess(tensordict)
-    #     return lazy_stack(list(td.unbind(0)))
-    #
     def _load_from_dataloader(self, reset: torch.Tensor | None = None):
         """Loads a single element from the dataloader, or alternatively from the buffer.
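Stepping out of the diff for a moment: the batch-size resolution implemented in the `__init__` above (explicit argument wins, then the dataloader's attribute, with 0 meaning "decided at reset time") can be distilled into a standalone helper. This paraphrases the constructor logic for illustration rather than reproducing it:

.. code-block:: python

    import torch


    def resolve_batch_size(batch_size, dl_batch_size, repeats=0, group_repeats=False):
        # Explicit argument wins; otherwise fall back on the dataloader.
        if batch_size is None:
            batch_size = dl_batch_size or 0
        # Accept an int, or an empty / length-1 sequence.
        if not isinstance(batch_size, int):
            if len(batch_size) > 1:
                raise ValueError("batch_size must be an int or a length-<=1 sequence.")
            batch_size = batch_size[0] if batch_size else 0
        # Optionally deliver all repeats of a prompt in one batch (GRPO-style).
        if repeats and group_repeats:
            batch_size = batch_size * repeats
        return torch.Size((batch_size,)) if batch_size > 0 else None


    assert resolve_batch_size(None, 4, repeats=2, group_repeats=True) == torch.Size((8,))
    assert resolve_batch_size(0, 4) is None  # 0 means "as many as reset() asks for"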
@@ -448,7 +505,7 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None): if not reset.any(): raise RuntimeError("reset must have at least one True value.") if reset.ndim > 0: - loaded = [self._load_from_dataloader() for i in range(reset.sum())] + loaded = [self._load_from_dataloader() for _ in range(reset.sum())] return self.stack_method(loaded) primers = getattr(self, "primers", None) @@ -502,11 +559,25 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None): if not out.ndim: out = out.unsqueeze(0) self._queue.extend( - [d for d in out.unbind(0) for _ in range(max(1, self.repeats))] + [d for _ in range(max(1, self.repeats)) for d in out.unbind(0)] ) return self._queue.popleft() return out + def set_container(self, container: Transform | EnvBase) -> None: + result = super().set_container(container) + # Check batch size + parent = getattr(self, "parent", None) + if ( + self.batch_size is not None + and parent is not None + and parent.batch_size != self.batch_size + ): + warnings.warn( + f"The parent env has a different batch size than the {type(self).__name__} transform." + ) + return result + def __repr__(self) -> str: class_name = self.__class__.__name__ return f"{class_name}(primers={self.primers}, dataloader={self.dataloader})" diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index ecbdd2485c0..f2c7e8517a3 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -6294,7 +6294,7 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite: f"observation_spec was expected to be of type Composite. Got {type(observation_spec)} instead." ) - if self.primers.shape != observation_spec.shape: + if self.primers.shape[: observation_spec.ndim] != observation_spec.shape: if self.expand_specs: self.primers = self._expand_shape(self.primers) elif self.expand_specs is None: @@ -6432,8 +6432,6 @@ def _reset_func( resets = self.default_value(reset=_reset) tensordict_reset.update(resets) - print('self.primers', self.primers) - print('tensordict_reset', tensordict_reset) for key, spec in self.primers.items(True, True): if not self._validated: self._validate_value_tensor(tensordict_reset.get(key), spec) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index eb236e56c4b..0af7937445d 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -508,6 +508,7 @@ def _set_single_key( if isinstance(key, str): key = (key,) for k in key: + # TODO: we can do better than try/except by leveraging the as_list / as_nested_tensor feature try: val = source._get_str(k, None) if is_tensor_collection(val): @@ -528,7 +529,7 @@ def _set_single_key( # This is a temporary solution to understand if a key is heterogeneous # while not having performance impact when the exception is not raised except RuntimeError as err: - if re.match(r"Found more than one unique shape in the tensors", str(err)): + if re.match(r"Failed to stack tensors within a tensordict", str(err)): # this is a het key for s_td, d_td in zip(source.tensordicts, dest.tensordicts): _set_single_key(s_td, d_td, k, clone=clone, device=device) @@ -541,6 +542,7 @@ def _set(source, dest, key, total_key, excluded): total_key = total_key + (key,) non_empty = False if unravel_key(total_key) not in excluded: + # TODO: we can do better than try/except by leveraging the as_list / as_nested_tensor feature try: val = source.get(key) if is_tensor_collection(val) and not isinstance( @@ -571,7 +573,7 @@ def _set(source, dest, key, 
total_key, excluded): # This is a temporary solution to understand if a key is heterogeneous # while not having performance impact when the exception is not raised except RuntimeError as err: - if re.match(r"Found more than one unique shape in the tensors", str(err)): + if re.match(r"Failed to stack tensors within a tensordict", str(err)): # this is a het key non_empty_local = False for s_td, d_td in zip(source.tensordicts, dest.tensordicts): diff --git a/torchrl/modules/llm/common.py b/torchrl/modules/llm/common.py index a168c16cec8..517b391a6c5 100644 --- a/torchrl/modules/llm/common.py +++ b/torchrl/modules/llm/common.py @@ -2,7 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import List +from __future__ import annotations from tensordict import NestedKey, TensorDictBase from tensordict.nn import ( @@ -40,9 +40,9 @@ def log_prob_keys(self): log_prob_key = ProbabilisticTensorDictModule.log_prob_key @property - def dist_params_keys(self) -> List[NestedKey]: + def dist_params_keys(self) -> list[NestedKey]: raise NotImplementedError @property - def dist_sample_keys(self) -> List[NestedKey]: + def dist_sample_keys(self) -> list[NestedKey]: return ["tokens_response"] diff --git a/torchrl/modules/llm/vllm_policy.py b/torchrl/modules/llm/vllm_policy.py index cac19a57dea..eabe1116ac8 100644 --- a/torchrl/modules/llm/vllm_policy.py +++ b/torchrl/modules/llm/vllm_policy.py @@ -6,28 +6,24 @@ import collections import importlib.util +from typing import Literal import torch from tensordict import ( from_dataclass, lazy_stack, LazyStackedTensorDict, - maybe_dense_stack, NestedKey, NonTensorData, NonTensorStack, TensorClass, TensorDict, ) -from tensordict.nn import ( - TensorDictModule as Mod, - TensorDictModuleBase, - TensorDictSequential as Seq, - WrapModule, -) +from tensordict.nn import TensorDictModule as Mod, TensorDictModuleBase, WrapModule from tensordict.utils import _zip_strict, expand_as_right from torchrl.data import LLMData +from torchrl.modules.llm.common import CategoricalSequential _has_vllm = importlib.util.find_spec("vllm") @@ -61,25 +57,41 @@ def from_vllm( generate_kwargs: dict | None = None, tokenizer_kwargs: dict | None = None, pad_output: bool = True, + inplace: Literal[True, False, "empty"] | None = True, ) -> TensorDictModuleBase: """Creates a TensorDictModule from a vLLM model. - This function provides a consistent interface across various LLM engines. - - It supports text generation and log probability computation, similar to the Hugging Face Transformers interface. + This function provides a consistent interface across various LLM engines, allowing for text generation and + log probability computation, similar to the Hugging Face Transformers interface. Args: - model (LLM): The vLLM model to wrap. - return_log_probs (bool, optional): Whether to return log probabilities. Defaults to `False`. - tokenizer (transformers.tokenization_utils.PreTrainedTokenizer, optional): The tokenizer to use. Defaults to `None`. - from_text (bool, optional): Whether the input is text. Defaults to `False`. - device (torch.device, optional): The device to use for computation. Defaults to `None`. - generate (bool, optional): Whether to generate text. Defaults to `True`. - generate_kwargs (dict, optional): Additional arguments for the model's generate method. Defaults to `None`. - tokenizer_kwargs (dict, optional): Additional arguments for the tokenizer. Defaults to `None`. 
+ model (vllm.LLM): The vLLM model to wrap. + return_log_probs (bool, optional): Whether to return log probabilities of the generated tokens. + Defaults to `False`. + tokenizer (transformers.tokenization_utils.PreTrainedTokenizer, optional): The tokenizer to use for encoding + and decoding text. If `None`, the tokenizer associated with the model will be used. Defaults to `None`. + from_text (bool, optional): Indicates whether the input is in text format. If `True`, the input is expected to + be text that will be tokenized. If `False`, the input is expected to be token sequences. Defaults to `False`. + device (torch.device, optional): The device to use for computation. If `None`, the default device will be used. + Defaults to `None`. + generate (bool, optional): Whether to enable text generation. If `True`, the model will generate text based on + the input. If `False`, only log probabilities will be computed. Defaults to `True`. + generate_kwargs (dict, optional): Additional arguments to pass to the model's generate method. These + arguments can control aspects of the generation process, such as temperature and top-k sampling. + Defaults to `None`. + tokenizer_kwargs (dict, optional): Additional arguments to pass to the tokenizer. These arguments can control + aspects of the tokenization process, such as padding and truncation. Defaults to `None`. + pad_output (bool, optional): Whether to pad the output sequences to a uniform length. If `True`, the output + sequences will be padded. If `False`, lists of tokens will be used without padding. Defaults to `True`. + inplace (Literal[True, False, "empty"], optional): Determines how the module should handle in-place + operations. If `True`, operations will be performed in-place. If `False`, a new TensorDict instance will be + created. + If `"empty"`, the output data structure will be initialized with `input.empty()` (i.e., it will conserve + type, batch-size and device). Defaults to `True`. Returns: - TensorDictModuleBase: A configured TensorDictModule for the specified model. + TensorDictModuleBase: A configured TensorDictModule for the specified model, capable of handling text or + token inputs and producing generated text or log probabilities. Input Keys: @@ -95,21 +107,21 @@ def from_vllm( Output Keys: - "tokens_response": The generated token sequences. - - "log_probs": The log probabilities of the generated tokens (if `return_log_probs` is True). - - "text_response": The generated text (if `from_text` is True and `generate` is True). + - "log_probs": The log probabilities of the generated tokens (if `return_log_probs` is `True`). + - "text_response": The generated text (if `from_text` is `True` and `generate` is `True`). Example: >>> from vllm import LLM >>> from transformers import AutoTokenizer - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> model = LLM(model="facebook/opt-125m") + >>> model = LLM("gpt2") + >>> tokenizer = model.get_tokenizer() >>> module = from_vllm( ... model, ... tokenizer=tokenizer, ... from_text=True, ... generate=True ... 
)
-        >>> input_data = LLMData(text=NonTensorStack("Hello, world!"), batch_size=1)
+        >>> input_data = LLMData(text=NonTensorStack("Hello, world!", "This is another text"), batch_size=2)
         >>> output_data = module(input_data)
         >>> print(output_data.text_response)
 
@@ -121,8 +133,13 @@ def from_vllm(
     if tokenizer is None:
         tokenizer = model.get_tokenizer()
 
-    # retrieve the padding value - we use this to make the log-probs of pad token = 1
-    padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0]
+    if pad_output:
+        # retrieve the padding value - we use this to make the log-probs of pad token = 1
+        padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0]
+    elif not from_text:
+        raise TypeError("passing tokens without padding isn't supported at the moment.")
+    else:
+        padding_value = None
 
     if from_text:
         if generate:
@@ -144,7 +161,7 @@ def from_vllm(
         return_log_probs=return_log_probs,
         pad_output=pad_output,
     )
-    return Seq(module_dict, inplace=True)
+    return CategoricalSequential(module_dict, inplace=inplace)
 
 
 def to_list(tokens, attention_mask):
@@ -172,6 +189,11 @@ def to_list(tokens, attention_mask):
     return NonTensorStack(*tokens)
 
 
+def _prepare(td):
+    out = TensorDict(batch_size=td.batch_size, device=td.device)
+    return out.update(td)
+
+
 def _from_vllm_generate_text(
     *,
     tokenizer,
@@ -195,52 +217,63 @@ def _from_vllm_generate_text(
     module_dict = {}
     if device:
         module_dict["clear_device"] = _maybe_clear_device
-    if not tokenizer_kwargs:
-        tokenizer_kwargs = {}
-    if not tokenizer_kwargs.setdefault("return_attention_mask", True):
-        raise RuntimeError
-    if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt":
-        raise RuntimeError
-    if tokenizer_kwargs.setdefault("padding", True) not in (True,):
-        raise RuntimeError
-    if tokenizer_kwargs.setdefault("padding_side", "left") != "left":
-        raise RuntimeError
 
-    def tokenize(td):
-        out = TensorDict(batch_size=td.batch_size, device=td.device)
-        tokens_in = TensorDict.from_dict(
-            tokenizer(td.get(text_key), **tokenizer_kwargs)
+    if pad_output:
+        if not tokenizer_kwargs:
+            tokenizer_kwargs = {}
+        if not tokenizer_kwargs.setdefault("return_attention_mask", True):
+            raise RuntimeError
+        if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt":
+            raise RuntimeError
+        if tokenizer_kwargs.setdefault("padding", True) not in (True,):
+            raise RuntimeError
+        if tokenizer_kwargs.setdefault("padding_side", "left") != "left":
+            raise RuntimeError
+
+        def tokenize(td):
+            out = TensorDict(batch_size=td.batch_size, device=td.device)
+            text = td.get(text_key)
+            if not isinstance(text, (list, str)):
+                text = text.tolist()
+            tokens_in = TensorDict.from_dict(tokenizer(text, **tokenizer_kwargs))
+            out.set("tokens_in", tokens_in)
+            return out
+
+        module_dict["encode"] = WrapModule(
+            tokenize,
+            in_keys=[text_key],
+            out_keys=["tokens_in"],
         )
-        out.set("tokens_in", tokens_in)
-        return out
-
-    module_dict["encode"] = WrapModule(
-        tokenize,
-        in_keys=[text_key],
-        out_keys=["tokens_in"],
-    )
-
-    module_dict["to_list"] = Mod(
-        to_list,
-        in_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")],
-        out_keys=[("tokens_in", "input_ids_list")],
-        strict=False,
-    )
+        module_dict["to_list"] = Mod(
+            to_list,
+            in_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")],
+            out_keys=[("tokens_in", "input_ids_list")],
+            strict=False,
+        )
+    else:
+        module_dict["prepare"] = WrapModule(
+            _prepare,
+        )
 
     if generate_kwargs is None:
         generate_kwargs = {}
-    generate_kwargs.setdefault("detokenize", False)
+    generate_kwargs.setdefault("detokenize", not 
pad_output) generate_kwargs.setdefault("prompt_logprobs", False) generate_kwargs.setdefault("logprobs", return_log_probs) sampling_params = SamplingParams(**generate_kwargs) + if pad_output: + in_keys = { + "prompt_token_ids": ("tokens_in", "input_ids_list"), + } + else: + in_keys = [text_key] + module_dict["generate"] = Mod( model, method="generate", method_kwargs={"sampling_params": sampling_params}, - in_keys={ - "prompt_token_ids": ("tokens_in", "input_ids_list"), - }, + in_keys=in_keys, out_keys=["tokens_out"], out_to_in_map=True, strict=True, @@ -248,34 +281,45 @@ def tokenize(td): def get_output_tokens_and_log_probs(td, padding_value=padding_value): td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"]) - if pad_output and td.ndim and not isinstance(td, LazyStackedTensorDict): + if td.ndim and not isinstance(td, LazyStackedTensorDict): td = lazy_stack(list(td.unbind(0))) # When not generate, we don't want to overwrite this tokens_response_td = td["tokens_out"].outputs._tensordict.select( - "token_ids", "logprobs", strict=False + "text", "token_ids", "logprobs", strict=False ) if pad_output: tokens_response_td = tokens_response_td.densify( layout=torch.strided ).to_padded_tensor(padding=padding_value) tokens_response_td.rename_key_("token_ids", "tokens_response") + tokens_response_td.rename_key_("text", "text_response") + if not pad_output: + # Then we can safely move the input tokens, but otherwise they + # may need padding + tokens_response_td.update( + td["tokens_out"].select("prompt_token_ids") + ).rename_key_("prompt_token_ids", token_key) + if return_log_probs: - padded_values = tokens_response_td["tokens_response"] == padding_value tokens_response_td.rename_key_("logprobs", "log_probs") - if padded_values.any(): - lps = tokens_response_td["log_probs"] - lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0) - tokens_response_td["log_probs"] = lps + if pad_output: + padded_values = tokens_response_td["tokens_response"] == padding_value + if padded_values.any(): + lps = tokens_response_td["log_probs"] + lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0) + tokens_response_td["log_probs"] = lps + td.update(tokens_response_td) return td module_dict["get_output_tokens_and_log_probs"] = get_output_tokens_and_log_probs - module_dict["decode"] = Mod( - tokenizer.batch_decode, - in_keys=["tokens_response"], - out_keys=["text_response"], - ) + if pad_output: + module_dict["decode"] = Mod( + tokenizer.batch_decode, + in_keys=["tokens_response"], + out_keys=["text_response"], + ) if device: module_dict["to_source_device"] = _maybe_set_device @@ -283,23 +327,27 @@ def get_output_tokens_and_log_probs(td, padding_value=padding_value): in_keys = [ "log_probs", "tokens_response", - ("tokens_in", "input_ids"), - ("tokens_in", "attention_mask"), "text_response", + token_key, + "tokens_in", ] out_keys = [ "log_probs", "tokens_response", + "text_response", token_key, attention_mask_key, - "text_response", ] def format_td(td): td = td.select(*in_keys, strict=False) - td.rename_key_(("tokens_in", "input_ids"), token_key) - td.rename_key_(("tokens_in", "attention_mask"), attention_mask_key) - del td["tokens_in"] + # We might already have the tokens + if ("tokens_in", "input_ids") in td: + td.rename_key_(("tokens_in", "input_ids"), token_key) + if "tokens_in" in td: + if ("tokens_in", "attention_mask") in td: + td.rename_key_(("tokens_in", "attention_mask"), attention_mask_key) + del td["tokens_in"] return td module_dict["format"] = WrapModule( @@ -387,12 
+435,13 @@ def get_output_tokens_and_log_probs(td, padding_value=padding_value):
             ).to_padded_tensor(padding=padding_value)
         tokens_response_td.rename_key_("token_ids", "tokens_response")
         if return_log_probs:
-            padded_values = tokens_response_td["tokens_response"] == padding_value
             tokens_response_td.rename_key_("logprobs", "log_probs")
-            if padded_values.any():
-                lps = tokens_response_td["log_probs"]
-                lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0)
-                tokens_response_td["log_probs"] = lps
+            if pad_output:
+                padded_values = tokens_response_td["tokens_response"] == padding_value
+                if padded_values.any():
+                    lps = tokens_response_td["log_probs"]
+                    lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0)
+                    tokens_response_td["log_probs"] = lps
         td.update(tokens_response_td)
         return td
 
@@ -457,27 +506,66 @@ def _from_vllm_logprobs_text(
         tokenizer_kwargs = {}
     if not tokenizer_kwargs.setdefault("return_attention_mask", True):
         raise RuntimeError
-    if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt":
-        raise RuntimeError
-    if tokenizer_kwargs.setdefault("padding", True) not in (True,):
+    if pad_output:
+        if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt":
+            raise RuntimeError
+    if tokenizer_kwargs.setdefault("padding", pad_output) not in (pad_output,):
         raise RuntimeError
     if tokenizer_kwargs.setdefault("padding_side", "left") != "left":
         raise RuntimeError
 
+    # Contrary to the case with generate, we always need the tokenizer here to understand what length the response is.
+    # To do this, we tokenize the prompt+response as well as the prompt, and then recover the response by taking
+    # the last slice of the tokenized prompt+response (i.e., removing the tokens of the prompt).
+    # We need to do this rather than tokenizing the response because we want to ensure that there are no
+    # additional tokens, but there is definitely room for improvement.
def tokenize(td): out = TensorDict(batch_size=td.batch_size, device=td.device) text_prompt = td.get(text_key) + if not isinstance(text_prompt, list): + text_prompt = text_prompt.tolist() text_response = td.get("text_response") - tokens_in = tokenizer( - [_x + _y for _x, _y in zip(text_prompt, text_response)], **tokenizer_kwargs - ) + if not isinstance(text_response, list): + text_response = text_response.tolist() + text = [_x + _y for _x, _y in zip(text_prompt, text_response)] + tokens_in = tokenizer(text, **tokenizer_kwargs) tokens_prompt = tokenizer(text_prompt, **tokenizer_kwargs) - tokens_in = TensorDict.from_dict(tokens_in) + if not pad_output: + tokens_in = TensorDict( + input_ids=NonTensorStack(*tokens_in["input_ids"]), + attention_mask=NonTensorStack(*tokens_in["attention_mask"]), + batch_size=td.batch_size, + ) + prompt_input_ids = tokens_prompt["input_ids"] + prompt_attention_mask = tokens_prompt["attention_mask"] + response_input_ids = [] + for token_total, token_prompt in zip( + tokens_in["input_ids"], prompt_input_ids + ): + response_input_ids.append(token_total[len(token_prompt) :]) + response_input_ids = NonTensorStack(*response_input_ids) + response_attention_mask = [] + for mask, mask_prompt in zip( + tokens_in["attention_mask"], prompt_attention_mask + ): + response_attention_mask.append(mask[len(mask_prompt) :]) + response_attention_mask = NonTensorStack(*response_attention_mask) + tokens_response = TensorDict( + input_ids=response_input_ids, + attention_mask=response_attention_mask, + batch_size=td.batch_size, + ) + else: + tokens_in = TensorDict.from_dict(tokens_in) + tokens_prompt = TensorDict.from_dict(tokens_prompt) + tokens_response = tokens_in.apply( + lambda total_tokens, input_tokens: total_tokens[ + :, input_tokens.shape[1] : + ], + tokens_prompt, + ) + out["tokens_in"] = tokens_in - tokens_response = tokens_in.apply( - lambda total_tokens, input_tokens: total_tokens[:, input_tokens.shape[1] :], - TensorDict.from_dict(tokens_prompt), - ) out["tokens_response"] = tokens_response return out @@ -495,9 +583,17 @@ def tokenize(td): strict=False, ) + if tokenizer is not None: + in_keys = { + "prompt_token_ids": ("tokens_in", "input_ids_list"), + } + else: + in_keys = [text_key] + if generate_kwargs is None: generate_kwargs = {} - generate_kwargs.setdefault("detokenize", False) + # We use the tokens when we pad + generate_kwargs.setdefault("detokenize", not pad_output) generate_kwargs.setdefault("prompt_logprobs", True) generate_kwargs.setdefault("logprobs", return_log_probs) generate_kwargs["max_tokens"] = 1 @@ -507,9 +603,7 @@ def tokenize(td): model, method="generate", method_kwargs={"sampling_params": sampling_params}, - in_keys={ - "prompt_token_ids": ("tokens_in", "input_ids_list"), - }, + in_keys=in_keys, out_keys=["tokens_out"], out_to_in_map=True, strict=True, @@ -517,35 +611,61 @@ def tokenize(td): def get_output_tokens_and_log_probs(td, padding_value=padding_value): td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"]) - if pad_output and td.ndim and not isinstance(td, LazyStackedTensorDict): + if td.ndim and not isinstance(td, LazyStackedTensorDict): td = lazy_stack(list(td.unbind(0))) - td["prompt_logprobs"] = td["tokens_out"].prompt_logprobs + td.update( + td["tokens_out"].select("prompt_token_ids", "prompt_logprobs", strict=False) + ) + del td["tokens_out"] return td module_dict["get_output_tokens_and_log_probs"] = get_output_tokens_and_log_probs - def translate_lps(tokens_response, x): + def translate_lps(tokens_response, lps): # 
we disregard the tokens from the prompt to focus on those of the response - padded = tokens_response == padding_value - lps = x[..., -tokens_response.shape[-1] :] - lps = torch.where(~padded, lps, 0.0) + if isinstance(lps, torch.Tensor): + lps = lps[..., -tokens_response.shape[-1] :] + else: + # We use a nested tensor as it will be unbound during writing + lps = torch.nested.nested_tensor( + [lp[..., -len(tr) :] for lp, tr in zip(lps, tokens_response)] + ) + if pad_output: + padded = tokens_response == padding_value + lps = torch.where(~padded, lps, 0.0) return lps module_dict["translate_lps"] = Mod( translate_lps, in_keys=[("tokens_response", "input_ids"), "prompt_logprobs"], out_keys=["log_probs"], + get_kwargs={ + "as_list": not pad_output, + "as_padded_tensor": pad_output, + "padding_side": "left", + }, ) if device: module_dict["to_source_device"] = _maybe_set_device - module_dict["format"] = Mod( - lambda *x: x, - in_keys=["log_probs", ("tokens_response", "input_ids")], - out_keys=["log_probs", "tokens_response"], - strict=False, - inplace="empty", + in_keys = ["log_probs", ("tokens_response", "input_ids")] + out_keys = ["log_probs", "tokens_response"] + + def format_td(td): + td = td.select(*in_keys, strict=False) + td.rename_key_(("tokens_response", "input_ids"), "tokens_response") + if not pad_output: + # Turn the list of tokens in a tensor + td["tokens_response"] = torch.nested.nested_tensor( + [torch.tensor(val) for val in td["tokens_response"]] + ) + return td + + module_dict["format"] = WrapModule( + format_td, + in_keys=in_keys, + out_keys=out_keys, ) return module_dict @@ -636,8 +756,8 @@ def get_output_tokens_and_log_probs(td, padding_value=padding_value): def translate_lps(tokens_response, lps): # we disregard the tokens from the prompt to focus on those of the response - padded = tokens_response == padding_value lps = lps[..., -tokens_response.shape[-1] :] + padded = tokens_response == padding_value lps = torch.where(~padded, lps, 0.0) return lps @@ -718,7 +838,7 @@ def get_logprob(output): @classmethod def from_request_output(cls, requests): - out = maybe_dense_stack( + out = lazy_stack( [ cls( request_id=request.request_id, diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 81d96fb7cec..ff5899bea0c 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -120,6 +120,9 @@ class SafeModule(TensorDictModule): If this value is out of bounds, it is projected back onto the desired space using the :obj:`TensorSpec.project` method. Default is ``False``. + inplace (bool or str, optional): if `True`, the input tensordict is modified in-place. If `False`, a new empty + :class:`~tensordict.TensorDict` instance is created. If `"empty"`, `input.empty()` is used instead (ie, the + output preserves type, device and batch-size). Defaults to `True`. Embedding a neural network in a TensorDictModule only requires to specify the input and output keys. The domain spec can be passed along if needed. TensorDictModule support functional and regular :obj:`nn.Module` objects. 
In the functional @@ -200,8 +203,9 @@ def __init__( out_keys: Iterable[str], spec: TensorSpec | None = None, safe: bool = False, + inplace: bool | str = True, ): - super().__init__(module, in_keys, out_keys) + super().__init__(module, in_keys, out_keys, inplace=inplace) self.register_spec(safe=safe, spec=spec) def register_spec(self, safe, spec): diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py index 8157869af27..6669ffdcbd9 100644 --- a/torchrl/modules/tensordict_module/sequence.py +++ b/torchrl/modules/tensordict_module/sequence.py @@ -28,6 +28,9 @@ class SafeSequential(TensorDictSequential, SafeModule): Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is ``True`` AND if the stack does not have the required keys, then SafeSequential will scan through the sub-tensordicts looking for those that have the required keys, if any. + inplace (bool or str, optional): if `True`, the input tensordict is modified in-place. If `False`, a new empty + :class:`~tensordict.TensorDict` instance is created. If `"empty"`, `input.empty()` is used instead (ie, the + output preserves type, device and batch-size). Defaults to `None` (relies on sub-modules). TensorDictSequence supports functional, modular and vmap coding: Examples: @@ -107,6 +110,7 @@ def __init__( self, *modules: TensorDictModule, partial_tolerant: bool = False, + inplace: bool | str | None = None, ): self.partial_tolerant = partial_tolerant @@ -124,4 +128,5 @@ def __init__( module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys, + inplace=inplace, ) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 748b6ac3c2b..8de86b35b7f 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -559,7 +559,6 @@ def _get_cur_log_prob(self, tensordict): f"tensordict stored {self.tensor_keys.action} requires grad." ) log_prob = dist.log_prob(action) - assert log_prob.requires_grad else: raise NotImplementedError( "Only probabilistic modules from tensordict.nn are currently supported. "
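As a closing illustration, the three `inplace` modes documented for `SafeModule` and `SafeSequential` above differ only in which container receives the module's outputs. A minimal standalone sketch of that dispatch (the helper name is ours, not TorchRL's; only `TensorDict.empty()` from tensordict is assumed):

.. code-block:: python

    import torch
    from tensordict import TensorDict


    def output_container(td, inplace):
        # inplace=True: write outputs straight into the input tensordict.
        if inplace is True:
            return td
        # inplace="empty": same type, batch-size and device, but no content.
        if inplace == "empty":
            return td.empty()
        # inplace=False: a fresh TensorDict instance.
        return TensorDict(batch_size=td.batch_size, device=td.device)


    td = TensorDict({"obs": torch.zeros(3)}, batch_size=[3])
    assert output_container(td, True) is td
    assert output_container(td, "empty").batch_size == td.batch_size
    assert "obs" not in output_container(td, False).keys()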