From c25000186b79bdadd42a587195b74cc14711222d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 3 Mar 2025 15:03:07 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- test/test_env.py | 98 +++++++++++++++++++++++++++ torchrl/envs/custom/llm.py | 5 ++ torchrl/envs/transforms/rlhf.py | 29 ++++++-- torchrl/envs/transforms/transforms.py | 6 +- 4 files changed, 129 insertions(+), 9 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index f4759f9a119..b2c26914ca6 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -4700,6 +4700,104 @@ def policy(td): r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[])) assert r.ndim == 1 + @pytest.mark.parametrize( + "str2str,stack_method", + [ + [True, None], + [False, "as_padded_tensor"], + # TODO: a bit experimental, fails with check_env_specs + # [False, "as_nested_tensor"], + [False, None], + ], + ) + @pytest.mark.parametrize("batched", [True, False]) + @pytest.mark.parametrize("device", [None, "cpu"]) + @pytest.mark.parametrize("batch_size", [0, 4]) + @pytest.mark.parametrize("repeats", [3]) + def test_llm_from_dataloader_repeats( + self, str2str, batched, stack_method, device, batch_size, repeats + ): + if str2str: + kwargs = { + "dataloader": self.DummyDataLoader(batch_size=batch_size), + "data_keys": ["observation"], + "example_data": "a string!", + "repeats": repeats, + } + else: + if stack_method is None: + stack_method = as_padded_tensor + kwargs = { + "dataloader": self.DummyTensorDataLoader( + padding=True, batch_size=batch_size + ), + "data_keys": ["observation"], + "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)], + "stack_method": stack_method, + "repeats": repeats, + } + kwargs.update({"str2str": str2str, "device": device}) + env = LLMEnv.from_dataloader(**kwargs) + assert env.transform.repeats == repeats + + max_steps = 3 + env.append_transform(StepCounter(max_steps=max_steps)) + + def policy(td): + if str2str: + if not td.shape: + td["action"] = "" + else: + td["action"] = NonTensorStack( + *["" for _ in range(td.shape[0])] + ) + else: + td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64) + return td + + if batched: + r = env.rollout( + 100, + policy, + tensordict=TensorDict(batch_size=[3]), + break_when_any_done=False, + ) + else: + r = env.rollout(100, policy, break_when_any_done=False) + # check that r at reset is always the same + r_reset = r[..., ::max_steps] + if not batched: + if str2str: + assert r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"] + assert r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"] + assert r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"] + else: + assert ( + r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"] + ).all() + assert ( + r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"] + ).all() + assert ( + r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"] + ).any() + else: + # When batched, each block contains the 3 reset packs + if str2str: + assert r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"] + assert r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"] + assert r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"] + else: + assert ( + r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"] + ).all() + assert ( + r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"] + ).all() + assert ( + r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"] + ).any() + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py index c62f2c015a2..fc94db38216 100644 --- a/torchrl/envs/custom/llm.py +++ b/torchrl/envs/custom/llm.py @@ -142,6 +142,7 @@ def from_dataloader( example_data: Any = None, stack_method: Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"] = None, + repeats: int | None = None, ) -> LLMEnv: """Creates an LLMEnv instance from a dataloader. @@ -165,6 +166,9 @@ def from_dataloader( example_data (Any, optional): Example data to use for initializing the primer. Defaults to ``None``. stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The method to use for stacking the data. Defaults to ``None``. + repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in + situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo + samples (rather than an advantage module). Returns: LLMEnv: The created LLMEnv instance. @@ -178,6 +182,7 @@ def from_dataloader( data_specs=data_specs, example_data=example_data, stack_method=stack_method, + repeats=repeats, ) env = LLMEnv( str2str=str2str, diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index 997b1af039a..b4c8712299e 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -103,6 +103,9 @@ class DataLoadingPrimer(TensorDictPrimer): auto_batch_size (bool, optional): If ``True`` (default if `dataloader.batch_size > 0`), the batch size of the tensordict returned by the transform will be automatically determined assuming that there is a single batch dimension. + repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in + situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo + samples (rather than an advantage module). Attributes: dataloader (Iterable[Any]): The dataloader to load data from. @@ -359,15 +362,21 @@ def __init__( | Literal["as_nested_tensor", "as_padded_tensor"] = None, use_buffer: bool | None = None, auto_batch_size: bool = True, + repeats: int | None = None, ): self.dataloader = dataloader - if getattr(dataloader, "batch_size", 1) > 1 and use_buffer is None: + if repeats is None: + repeats = 0 + self.repeats = repeats + if ( + getattr(dataloader, "batch_size", 1) > 1 and use_buffer is None + ) or repeats > 0: use_buffer = True self.use_buffer = use_buffer # No auto_batch_size if we know we have a single element self.auto_batch_size = auto_batch_size and ( - getattr(dataloader, "dataloader", 1) > 0 + getattr(dataloader, "batch_size", 1) > 0 ) self.endless_dataloader = self._endless_iter(self.dataloader) if primers is None: @@ -420,11 +429,13 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None): if not reset.any(): raise RuntimeError("reset must have at least one True value.") if reset.ndim > 0: - return self.stack_method( - [self._load_from_dataloader() for i in range(reset.sum())] - ) + loaded = [self._load_from_dataloader() for i in range(reset.sum())] + return self.stack_method(loaded) + if self.use_buffer and len(self._queue) > 0: - return self._queue.popleft() + result = self._queue.popleft() + return result + data = next(self.endless_dataloader) # Some heuristic here: # if data is a map, assume its keys match the keys in spec @@ -450,7 +461,11 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None): f"Unrecognized data type: {type(data)} with keys {self.data_keys}." ) if self.use_buffer: - self._queue.extend(out.unbind(0)) + if not out.ndim: + out = out.unsqueeze(0) + self._queue.extend( + [d for d in out.unbind(0) for _ in range(max(1, self.repeats))] + ) return self._queue.popleft() return out diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index eff19ef1b61..fa0ca17317c 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -7352,7 +7352,9 @@ def _reset( else: # It may be the case that reset did not provide a done state, in which case # we fall back on the spec - done = self.parent.output_spec["full_done_spec", entry_name].zero() + done = self.parent.output_spec_unbatched[ + "full_done_spec", entry_name + ].zero(tensordict_reset.shape) reset = torch.ones_like(done) step_count = tensordict.get(step_count_key, default=None) @@ -7362,7 +7364,7 @@ def _reset( step_count = step_count.to(reset.device, non_blocking=True) # zero the step count if reset is needed - step_count = torch.where(~expand_as_right(reset, step_count), step_count, 0) + step_count = torch.where(~reset, step_count.expand_as(reset), 0) tensordict_reset.set(step_count_key, step_count) if self.max_steps is not None: truncated = step_count >= self.max_steps