diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index f6f60b9e7..78c61365c 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -78,6 +78,8 @@ from ray.util.placement_group import PlacementGroup, placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from rich.pretty import pprint +from tensordict import tensorclass +from torchrl.data import RayReplayBuffer, ReplayBuffer from tqdm import tqdm from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, get_scheduler from transformers.integrations import HfDeepSpeedConfig @@ -487,6 +489,32 @@ def __post_init__(self): assert len(self.tools) == len(set(self.tools)), "Duplicate tools are not allowed" +@tensorclass +class PackedLogProbSequence: + query_response: torch.Tensor + """packed query and response (batch_size, pack_length)""" + attention_mask: torch.Tensor + """3D attention mask for packed sequences (batch_size, pack_length, pack_length); + it basically uses an intra-document mask for each query response pair; + see https://huggingface.co/blog/sirluk/llm-sequence-packing for more details + """ + response_mask: torch.Tensor + """response mask for packed sequences (batch_size, pack_length)""" + ref_logprob: Optional[torch.Tensor] + """packed reference logprobs (batch_size, pack_length)""" + old_logprob: Optional[torch.Tensor] + """packed old logprobs (batch_size, pack_length)""" + + tool_mask: Optional[torch.Tensor] = None + """tool mask for packed sequences (batch_size, pack_length)""" + position_id: Optional[torch.Tensor] = None + """packed position ids (batch_size, pack_length)""" + advantage: Optional[torch.Tensor] = None + """packed advantages (batch_size, pack_length) (to be filled in by the main process)""" + reward: Optional[torch.Tensor] = None + """packed rewards (batch_size, pack_length)""" + + def next_batch(dataset_indices: List[int], dataset: datasets.Dataset) -> Batch: """Extract next batch of data based on 
indices.""" data_next = dataset[dataset_indices] @@ -593,6 +621,7 @@ def from_pretrained( beaker_config: BeakerRuntimeConfig, wandb_url: str, tokenizer: PreTrainedTokenizer, + replay_buffer: ReplayBuffer, ): # ------------------------------------------------------------ # Monkey patch to load checkpoints with `weights_only=False` @@ -615,6 +644,7 @@ def load(self, path: str, map_location=None): self.model_config = model_config self.beaker_config = beaker_config self.wandb_url = wandb_url + self.replay_buffer = replay_buffer torch.cuda.set_device(self.local_rank) self.device = torch.device(self.local_rank) @@ -946,6 +976,7 @@ def train( # if we have multiple minibatches, we need to calculate the old logprobs for each minibatch # following gtrl scripts in just doing this on the current active policy, rather than use the logprobs # from the generator (note that async mode means these are a bit diff!) + replay_buffer_list = [] old_logprobs = [None for _ in range(len(collated_query_responses))] if num_mini_batches > 1: with Timer("Old logprobs Calculation", noop=self.rank != 0): @@ -973,6 +1004,20 @@ def train( old_logprobs[i] = old_logprob torch.cuda.empty_cache() + replay_buffer_list.append( + PackedLogProbSequence( + query_response=query_response, + attention_mask=attention_mask, + response_mask=response_mask, + tool_mask=tool_mask, + advantage=collated_advantages[i], + position_id=position_id, + old_logprob=old_logprob, + ref_logprob=ref_logprob, + ) + ) + self.replay_buffer.extend(replay_buffer_list) + local_step = 0 # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch with Timer("[Training Processes] Loss calculation", noop=self.rank != 0): @@ -1965,6 +2010,7 @@ def create_model_and_optimizer( inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, evaluation_inference_results_Q: ray_queue.Queue, + replay_buffer: ReplayBuffer, ) -> tuple[ModelGroup, list[vllm_utils3.LLMRayActor], dict, int, int]: 
"""Create the model, optimizer, and vLLM engines.""" # Create placement group @@ -1975,7 +2021,7 @@ def create_model_and_optimizer( policy_group = ModelGroup(pg, PolicyTrainerRayProcess, args.num_learners_per_node, args.single_gpu_mode) wandb_url = wandb.run.get_url() if args.with_tracking else None inits.extend( - model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer) + model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer, replay_buffer) for model in policy_group.models ) @@ -2109,7 +2155,11 @@ def split_and_insert_batch( def load_data_from_packing_thread( - packed_sequences_Q: Queue, num_total_tokens: int, stop_event: threading.Event, health_check_fn: Callable[[], None] + packed_sequences_Q: Queue, + num_total_tokens: int, + stop_event: threading.Event, + health_check_fn: Callable[[], None], + replay_buffer: ReplayBuffer, ): """Get the packed sequences with advantages from the packing thread.""" with Timer("[Main Thread] 📦 Getting packed sequences from thread") as timer: @@ -2595,6 +2645,7 @@ def run_training( generate_metrics_Q, weight_sync_metrics_Q, actor_manager: ActorManager, + replay_buffer: ReplayBuffer, checkpoint_state=None, ): if resume_training_step > 1: @@ -2808,6 +2859,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): param_prompt_Q = ray_queue.Queue(maxsize=queue_size) # We don't care if we ever hit the max, so we let the queue be unbounded. 
evaluation_inference_results_Q = ray_queue.Queue() + replay_buffer = RayReplayBuffer(batch_size=128) policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager = ( create_model_and_optimizer( @@ -2820,6 +2872,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): inference_results_Q, param_prompt_Q, evaluation_inference_results_Q, + replay_buffer, ) ) @@ -2886,6 +2939,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): generate_metrics_Q, weight_sync_metrics_Q, actor_manager, + replay_buffer, checkpoint_state, ) finally: diff --git a/pyproject.toml b/pyproject.toml index 5ec2c83b3..254375a4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "immutabledict==1.2.0", "flash-attn>=2.8.0.post2; platform_system != 'Darwin'", "liger-kernel>=0.5.4; platform_system != 'Darwin'", + "torchrl>=0.9.2", ] [build-system] diff --git a/uv.lock b/uv.lock index e6c858c79..039a2aaae 100644 --- a/uv.lock +++ b/uv.lock @@ -2206,6 +2206,7 @@ dependencies = [ { name = "tensorboard" }, { name = "torch", version = "2.7.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, { name = "torch", version = "2.7.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform != 'darwin'" }, + { name = "torchrl" }, { name = "transformers" }, { name = "vllm" }, { name = "wandb" }, @@ -2258,6 +2259,7 @@ requires-dist = [ { name = "tensorboard", specifier = ">=2.18.0" }, { name = "torch", marker = "sys_platform != 'darwin'", specifier = ">=2.7.0,<2.8", index = "https://download.pytorch.org/whl/cu128" }, { name = "torch", marker = "sys_platform == 'darwin'", specifier = ">=2.7.0,<2.8" }, + { name = "torchrl", specifier = ">=0.9.2" }, { name = "transformers", specifier = ">=4.52.4,<4.54.0" }, { name = "uvicorn", marker = "extra == 'code'", specifier = ">=0.20.0" }, { name = "vllm", specifier = "==0.9.1" }, @@ -2457,6 +2459,57 
@@ wheels = [ { url = "https://files.pythonhosted.org/packages/34/98/f5196ba0f4105a4790cec8c6671cf676c96dfa29bfedfe3c4f112bf4e6ad/opentelemetry_semantic_conventions_ai-0.4.9-py3-none-any.whl", hash = "sha256:71149e46a72554ae17de46bca6c11ba540c19c89904bd4cc3111aac6edf10315", size = 5617, upload-time = "2025-05-16T10:20:53.062Z" }, ] +[[package]] +name = "orjson" +version = "3.11.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/be/4d/8df5f83256a809c22c4d6792ce8d43bb503be0fb7a8e4da9025754b09658/orjson-3.11.3.tar.gz", hash = "sha256:1c0603b1d2ffcd43a411d64797a19556ef76958aef1c182f22dc30860152a98a", size = 5482394, upload-time = "2025-08-26T17:46:43.171Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/64/4a3cef001c6cd9c64256348d4c13a7b09b857e3e1cbb5185917df67d8ced/orjson-3.11.3-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:29cb1f1b008d936803e2da3d7cba726fc47232c45df531b29edf0b232dd737e7", size = 238600, upload-time = "2025-08-26T17:44:36.875Z" }, + { url = "https://files.pythonhosted.org/packages/10/ce/0c8c87f54f79d051485903dc46226c4d3220b691a151769156054df4562b/orjson-3.11.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:97dceed87ed9139884a55db8722428e27bd8452817fbf1869c58b49fecab1120", size = 123526, upload-time = "2025-08-26T17:44:39.574Z" }, + { url = "https://files.pythonhosted.org/packages/ef/d0/249497e861f2d438f45b3ab7b7b361484237414945169aa285608f9f7019/orjson-3.11.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:58533f9e8266cb0ac298e259ed7b4d42ed3fa0b78ce76860626164de49e0d467", size = 128075, upload-time = "2025-08-26T17:44:40.672Z" }, + { url = "https://files.pythonhosted.org/packages/e5/64/00485702f640a0fd56144042a1ea196469f4a3ae93681871564bf74fa996/orjson-3.11.3-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:0c212cfdd90512fe722fa9bd620de4d46cda691415be86b2e02243242ae81873", size = 130483, upload-time = "2025-08-26T17:44:41.788Z" }, + { url = "https://files.pythonhosted.org/packages/64/81/110d68dba3909171bf3f05619ad0cf187b430e64045ae4e0aa7ccfe25b15/orjson-3.11.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5ff835b5d3e67d9207343effb03760c00335f8b5285bfceefd4dc967b0e48f6a", size = 132539, upload-time = "2025-08-26T17:44:43.12Z" }, + { url = "https://files.pythonhosted.org/packages/79/92/dba25c22b0ddfafa1e6516a780a00abac28d49f49e7202eb433a53c3e94e/orjson-3.11.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f5aa4682912a450c2db89cbd92d356fef47e115dffba07992555542f344d301b", size = 135390, upload-time = "2025-08-26T17:44:44.199Z" }, + { url = "https://files.pythonhosted.org/packages/44/1d/ca2230fd55edbd87b58a43a19032d63a4b180389a97520cc62c535b726f9/orjson-3.11.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7d18dd34ea2e860553a579df02041845dee0af8985dff7f8661306f95504ddf", size = 132966, upload-time = "2025-08-26T17:44:45.719Z" }, + { url = "https://files.pythonhosted.org/packages/6e/b9/96bbc8ed3e47e52b487d504bd6861798977445fbc410da6e87e302dc632d/orjson-3.11.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d8b11701bc43be92ea42bd454910437b355dfb63696c06fe953ffb40b5f763b4", size = 131349, upload-time = "2025-08-26T17:44:46.862Z" }, + { url = "https://files.pythonhosted.org/packages/c4/3c/418fbd93d94b0df71cddf96b7fe5894d64a5d890b453ac365120daec30f7/orjson-3.11.3-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:90368277087d4af32d38bd55f9da2ff466d25325bf6167c8f382d8ee40cb2bbc", size = 404087, upload-time = "2025-08-26T17:44:48.079Z" }, + { url = "https://files.pythonhosted.org/packages/5b/a9/2bfd58817d736c2f63608dec0c34857339d423eeed30099b126562822191/orjson-3.11.3-cp310-cp310-musllinux_1_2_i686.whl", hash = 
"sha256:fd7ff459fb393358d3a155d25b275c60b07a2c83dcd7ea962b1923f5a1134569", size = 146067, upload-time = "2025-08-26T17:44:49.302Z" }, + { url = "https://files.pythonhosted.org/packages/33/ba/29023771f334096f564e48d82ed855a0ed3320389d6748a9c949e25be734/orjson-3.11.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f8d902867b699bcd09c176a280b1acdab57f924489033e53d0afe79817da37e6", size = 135506, upload-time = "2025-08-26T17:44:50.558Z" }, + { url = "https://files.pythonhosted.org/packages/39/62/b5a1eca83f54cb3aa11a9645b8a22f08d97dbd13f27f83aae7c6666a0a05/orjson-3.11.3-cp310-cp310-win32.whl", hash = "sha256:bb93562146120bb51e6b154962d3dadc678ed0fce96513fa6bc06599bb6f6edc", size = 136352, upload-time = "2025-08-26T17:44:51.698Z" }, + { url = "https://files.pythonhosted.org/packages/e3/c0/7ebfaa327d9a9ed982adc0d9420dbce9a3fec45b60ab32c6308f731333fa/orjson-3.11.3-cp310-cp310-win_amd64.whl", hash = "sha256:976c6f1975032cc327161c65d4194c549f2589d88b105a5e3499429a54479770", size = 131539, upload-time = "2025-08-26T17:44:52.974Z" }, + { url = "https://files.pythonhosted.org/packages/cd/8b/360674cd817faef32e49276187922a946468579fcaf37afdfb6c07046e92/orjson-3.11.3-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:9d2ae0cc6aeb669633e0124531f342a17d8e97ea999e42f12a5ad4adaa304c5f", size = 238238, upload-time = "2025-08-26T17:44:54.214Z" }, + { url = "https://files.pythonhosted.org/packages/05/3d/5fa9ea4b34c1a13be7d9046ba98d06e6feb1d8853718992954ab59d16625/orjson-3.11.3-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:ba21dbb2493e9c653eaffdc38819b004b7b1b246fb77bfc93dc016fe664eac91", size = 127713, upload-time = "2025-08-26T17:44:55.596Z" }, + { url = "https://files.pythonhosted.org/packages/e5/5f/e18367823925e00b1feec867ff5f040055892fc474bf5f7875649ecfa586/orjson-3.11.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00f1a271e56d511d1569937c0447d7dce5a99a33ea0dec76673706360a051904", size = 123241, 
upload-time = "2025-08-26T17:44:57.185Z" }, + { url = "https://files.pythonhosted.org/packages/0f/bd/3c66b91c4564759cf9f473251ac1650e446c7ba92a7c0f9f56ed54f9f0e6/orjson-3.11.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b67e71e47caa6680d1b6f075a396d04fa6ca8ca09aafb428731da9b3ea32a5a6", size = 127895, upload-time = "2025-08-26T17:44:58.349Z" }, + { url = "https://files.pythonhosted.org/packages/82/b5/dc8dcd609db4766e2967a85f63296c59d4722b39503e5b0bf7fd340d387f/orjson-3.11.3-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d7d012ebddffcce8c85734a6d9e5f08180cd3857c5f5a3ac70185b43775d043d", size = 130303, upload-time = "2025-08-26T17:44:59.491Z" }, + { url = "https://files.pythonhosted.org/packages/48/c2/d58ec5fd1270b2aa44c862171891adc2e1241bd7dab26c8f46eb97c6c6f1/orjson-3.11.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dd759f75d6b8d1b62012b7f5ef9461d03c804f94d539a5515b454ba3a6588038", size = 132366, upload-time = "2025-08-26T17:45:00.654Z" }, + { url = "https://files.pythonhosted.org/packages/73/87/0ef7e22eb8dd1ef940bfe3b9e441db519e692d62ed1aae365406a16d23d0/orjson-3.11.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6890ace0809627b0dff19cfad92d69d0fa3f089d3e359a2a532507bb6ba34efb", size = 135180, upload-time = "2025-08-26T17:45:02.424Z" }, + { url = "https://files.pythonhosted.org/packages/bb/6a/e5bf7b70883f374710ad74faf99bacfc4b5b5a7797c1d5e130350e0e28a3/orjson-3.11.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9d4a5e041ae435b815e568537755773d05dac031fee6a57b4ba70897a44d9d2", size = 132741, upload-time = "2025-08-26T17:45:03.663Z" }, + { url = "https://files.pythonhosted.org/packages/bd/0c/4577fd860b6386ffaa56440e792af01c7882b56d2766f55384b5b0e9d39b/orjson-3.11.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2d68bf97a771836687107abfca089743885fb664b90138d8761cce61d5625d55", size = 131104, upload-time = 
"2025-08-26T17:45:04.939Z" }, + { url = "https://files.pythonhosted.org/packages/66/4b/83e92b2d67e86d1c33f2ea9411742a714a26de63641b082bdbf3d8e481af/orjson-3.11.3-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:bfc27516ec46f4520b18ef645864cee168d2a027dbf32c5537cb1f3e3c22dac1", size = 403887, upload-time = "2025-08-26T17:45:06.228Z" }, + { url = "https://files.pythonhosted.org/packages/6d/e5/9eea6a14e9b5ceb4a271a1fd2e1dec5f2f686755c0fab6673dc6ff3433f4/orjson-3.11.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:f66b001332a017d7945e177e282a40b6997056394e3ed7ddb41fb1813b83e824", size = 145855, upload-time = "2025-08-26T17:45:08.338Z" }, + { url = "https://files.pythonhosted.org/packages/45/78/8d4f5ad0c80ba9bf8ac4d0fc71f93a7d0dc0844989e645e2074af376c307/orjson-3.11.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:212e67806525d2561efbfe9e799633b17eb668b8964abed6b5319b2f1cfbae1f", size = 135361, upload-time = "2025-08-26T17:45:09.625Z" }, + { url = "https://files.pythonhosted.org/packages/0b/5f/16386970370178d7a9b438517ea3d704efcf163d286422bae3b37b88dbb5/orjson-3.11.3-cp311-cp311-win32.whl", hash = "sha256:6e8e0c3b85575a32f2ffa59de455f85ce002b8bdc0662d6b9c2ed6d80ab5d204", size = 136190, upload-time = "2025-08-26T17:45:10.962Z" }, + { url = "https://files.pythonhosted.org/packages/09/60/db16c6f7a41dd8ac9fb651f66701ff2aeb499ad9ebc15853a26c7c152448/orjson-3.11.3-cp311-cp311-win_amd64.whl", hash = "sha256:6be2f1b5d3dc99a5ce5ce162fc741c22ba9f3443d3dd586e6a1211b7bc87bc7b", size = 131389, upload-time = "2025-08-26T17:45:12.285Z" }, + { url = "https://files.pythonhosted.org/packages/3e/2a/bb811ad336667041dea9b8565c7c9faf2f59b47eb5ab680315eea612ef2e/orjson-3.11.3-cp311-cp311-win_arm64.whl", hash = "sha256:fafb1a99d740523d964b15c8db4eabbfc86ff29f84898262bf6e3e4c9e97e43e", size = 126120, upload-time = "2025-08-26T17:45:13.515Z" }, + { url = 
"https://files.pythonhosted.org/packages/3d/b0/a7edab2a00cdcb2688e1c943401cb3236323e7bfd2839815c6131a3742f4/orjson-3.11.3-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:8c752089db84333e36d754c4baf19c0e1437012242048439c7e80eb0e6426e3b", size = 238259, upload-time = "2025-08-26T17:45:15.093Z" }, + { url = "https://files.pythonhosted.org/packages/e1/c6/ff4865a9cc398a07a83342713b5932e4dc3cb4bf4bc04e8f83dedfc0d736/orjson-3.11.3-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:9b8761b6cf04a856eb544acdd82fc594b978f12ac3602d6374a7edb9d86fd2c2", size = 127633, upload-time = "2025-08-26T17:45:16.417Z" }, + { url = "https://files.pythonhosted.org/packages/6e/e6/e00bea2d9472f44fe8794f523e548ce0ad51eb9693cf538a753a27b8bda4/orjson-3.11.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b13974dc8ac6ba22feaa867fc19135a3e01a134b4f7c9c28162fed4d615008a", size = 123061, upload-time = "2025-08-26T17:45:17.673Z" }, + { url = "https://files.pythonhosted.org/packages/54/31/9fbb78b8e1eb3ac605467cb846e1c08d0588506028b37f4ee21f978a51d4/orjson-3.11.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f83abab5bacb76d9c821fd5c07728ff224ed0e52d7a71b7b3de822f3df04e15c", size = 127956, upload-time = "2025-08-26T17:45:19.172Z" }, + { url = "https://files.pythonhosted.org/packages/36/88/b0604c22af1eed9f98d709a96302006915cfd724a7ebd27d6dd11c22d80b/orjson-3.11.3-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e6fbaf48a744b94091a56c62897b27c31ee2da93d826aa5b207131a1e13d4064", size = 130790, upload-time = "2025-08-26T17:45:20.586Z" }, + { url = "https://files.pythonhosted.org/packages/0e/9d/1c1238ae9fffbfed51ba1e507731b3faaf6b846126a47e9649222b0fd06f/orjson-3.11.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bc779b4f4bba2847d0d2940081a7b6f7b5877e05408ffbb74fa1faf4a136c424", size = 132385, upload-time = "2025-08-26T17:45:22.036Z" }, + { url = 
"https://files.pythonhosted.org/packages/a3/b5/c06f1b090a1c875f337e21dd71943bc9d84087f7cdf8c6e9086902c34e42/orjson-3.11.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bd4b909ce4c50faa2192da6bb684d9848d4510b736b0611b6ab4020ea6fd2d23", size = 135305, upload-time = "2025-08-26T17:45:23.4Z" }, + { url = "https://files.pythonhosted.org/packages/a0/26/5f028c7d81ad2ebbf84414ba6d6c9cac03f22f5cd0d01eb40fb2d6a06b07/orjson-3.11.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:524b765ad888dc5518bbce12c77c2e83dee1ed6b0992c1790cc5fb49bb4b6667", size = 132875, upload-time = "2025-08-26T17:45:25.182Z" }, + { url = "https://files.pythonhosted.org/packages/fe/d4/b8df70d9cfb56e385bf39b4e915298f9ae6c61454c8154a0f5fd7efcd42e/orjson-3.11.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:84fd82870b97ae3cdcea9d8746e592b6d40e1e4d4527835fc520c588d2ded04f", size = 130940, upload-time = "2025-08-26T17:45:27.209Z" }, + { url = "https://files.pythonhosted.org/packages/da/5e/afe6a052ebc1a4741c792dd96e9f65bf3939d2094e8b356503b68d48f9f5/orjson-3.11.3-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:fbecb9709111be913ae6879b07bafd4b0785b44c1eb5cac8ac76da048b3885a1", size = 403852, upload-time = "2025-08-26T17:45:28.478Z" }, + { url = "https://files.pythonhosted.org/packages/f8/90/7bbabafeb2ce65915e9247f14a56b29c9334003536009ef5b122783fe67e/orjson-3.11.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9dba358d55aee552bd868de348f4736ca5a4086d9a62e2bfbbeeb5629fe8b0cc", size = 146293, upload-time = "2025-08-26T17:45:29.86Z" }, + { url = "https://files.pythonhosted.org/packages/27/b3/2d703946447da8b093350570644a663df69448c9d9330e5f1d9cce997f20/orjson-3.11.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:eabcf2e84f1d7105f84580e03012270c7e97ecb1fb1618bda395061b2a84a049", size = 135470, upload-time = "2025-08-26T17:45:31.243Z" }, + { url = 
"https://files.pythonhosted.org/packages/38/70/b14dcfae7aff0e379b0119c8a812f8396678919c431efccc8e8a0263e4d9/orjson-3.11.3-cp312-cp312-win32.whl", hash = "sha256:3782d2c60b8116772aea8d9b7905221437fdf53e7277282e8d8b07c220f96cca", size = 136248, upload-time = "2025-08-26T17:45:32.567Z" }, + { url = "https://files.pythonhosted.org/packages/35/b8/9e3127d65de7fff243f7f3e53f59a531bf6bb295ebe5db024c2503cc0726/orjson-3.11.3-cp312-cp312-win_amd64.whl", hash = "sha256:79b44319268af2eaa3e315b92298de9a0067ade6e6003ddaef72f8e0bedb94f1", size = 131437, upload-time = "2025-08-26T17:45:34.949Z" }, + { url = "https://files.pythonhosted.org/packages/51/92/a946e737d4d8a7fd84a606aba96220043dcc7d6988b9e7551f7f6d5ba5ad/orjson-3.11.3-cp312-cp312-win_arm64.whl", hash = "sha256:0e92a4e83341ef79d835ca21b8bd13e27c859e4e9e4d7b63defc6e58462a3710", size = 125978, upload-time = "2025-08-26T17:45:36.422Z" }, +] + [[package]] name = "outlines" version = "0.1.11" @@ -3091,6 +3144,17 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" }, ] +[[package]] +name = "pyvers" +version = "0.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/7f/39/c5432f541e6ea1d616dfd6ef42ce02792f7eb42dd44f5ed4439dbe17a58b/pyvers-0.1.0-py3-none-any.whl", hash = "sha256:065249805ae537ddf9a2d1a8dffc6d0a12474a347d2eaa2f35ebdae92c0c8199", size = 10092, upload-time = "2025-06-08T23:46:46.219Z" }, +] + [[package]] name = "pywin32" version = "310" @@ -3742,6 +3806,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/73/c6/825dab04195756cf8ff2e12698f22513b3db2f64925bdd41671bfb33aaa5/tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = 
"sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530", size = 6590363, upload-time = "2023-10-23T21:23:35.583Z" }, ] +[[package]] +name = "tensordict" +version = "0.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cloudpickle" }, + { name = "importlib-metadata" }, + { name = "numpy" }, + { name = "orjson" }, + { name = "packaging" }, + { name = "pyvers" }, + { name = "torch", version = "2.7.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, + { name = "torch", version = "2.7.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform != 'darwin'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/94/5f55b11c879ec3aaf16cf69028c2126f9ebba4d195dc7e2ebaca5a400859/tensordict-0.9.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:070340b3ef65cd6017eb2279400967af1ec46c4e8515d3d5d0c037f9bccdc325", size = 737735, upload-time = "2025-07-14T12:52:04.459Z" }, + { url = "https://files.pythonhosted.org/packages/80/67/082f6f880509c4a1dd74fce517378c1384f372df4268ffa2036b18a4f100/tensordict-0.9.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:369557c2c643737e4476022fd1cdc9cea0448ea7b0d2ec38cd0130ed129f02e1", size = 422771, upload-time = "2025-07-14T12:52:06.97Z" }, + { url = "https://files.pythonhosted.org/packages/94/d8/20477c94c2229149f38b870c6b642d273f0ae776a378276c2f3a512b72d5/tensordict-0.9.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48fb63920ea9a6885fdf04f3782809274bc78618427a3731cc778ed7392bb8e6", size = 425771, upload-time = "2025-07-14T12:52:08.035Z" }, + { url = "https://files.pythonhosted.org/packages/e7/f9/e6fe6b3a79d373bd95a19e07eed04c734224f47b342ccc385c5d3f2456cb/tensordict-0.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:588fd33d731da08e962d5bdcd24fdc382f427d7446ece4f038232342138d0977", size = 470600, upload-time = "2025-07-14T12:52:09.406Z" }, + { url = 
"https://files.pythonhosted.org/packages/1c/05/413fcfacbe382d5994d448aaea5191e79904105a3295e737f6f3c1e2a473/tensordict-0.9.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:a24526cfee13d59565bf1539ec056937bfeb566a97edf96ecee71851b3f377b5", size = 739188, upload-time = "2025-07-14T12:52:10.431Z" }, + { url = "https://files.pythonhosted.org/packages/68/88/e370331135b9e42601b9c2aaa230491ed59c8771131aedf5ace0e56e0fc2/tensordict-0.9.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2ca3b65b22fe3d7c4e1d4376d443b7f5b412e150eb97d1cdf0a846d6f554858a", size = 425211, upload-time = "2025-07-14T12:52:11.483Z" }, + { url = "https://files.pythonhosted.org/packages/4f/9b/9214660b086c29f8232ee9c9c84d8543eefd1397ced87e15dee1b7b6af14/tensordict-0.9.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:0382a618d5b3a40a86a82b538551b4ca751a3308ea06b5768f3dfe36ba7bfb2a", size = 427529, upload-time = "2025-07-14T12:52:12.811Z" }, + { url = "https://files.pythonhosted.org/packages/01/b5/e05fe096a051cd21745b4249b7594e29a3b2af8c80584d78f7ae2012a257/tensordict-0.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:b626c1c2926cd5f7d91eaa548c960b4559f683d5f2315c0dbea743f54412a98c", size = 473014, upload-time = "2025-07-14T12:52:14.161Z" }, + { url = "https://files.pythonhosted.org/packages/cc/dd/2e1d044a5a7901fb2a94686299e7291482f6c5623149edb777be887a0658/tensordict-0.9.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:46f59e9db37af740a3cafb08ff98df6f252ffaa0598a9354d2546dcbb0e97d92", size = 739949, upload-time = "2025-07-14T12:52:15.177Z" }, + { url = "https://files.pythonhosted.org/packages/4c/24/dd6d1874cda92749cd8ddd9a77f82a758f2cd5ece0d376417f662543badf/tensordict-0.9.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:5a5b436784d2d19f94a7e6023da8430b93efc202f4daadcb247985569b384461", size = 424861, upload-time = "2025-07-14T12:52:16.239Z" }, + { url = 
"https://files.pythonhosted.org/packages/89/68/0d801f339ff43aaaecad26d87f74bacb20503d9bcb1bdf7f76486c1f0b86/tensordict-0.9.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ba30eacb5630dc1ee47159a66baa9dc6c5cec3f130b39b7589ff712f9e45777a", size = 430242, upload-time = "2025-07-14T12:52:17.643Z" }, + { url = "https://files.pythonhosted.org/packages/8a/07/3950b9c61be10dc1bf44c8f686c88dd715b803cb78a02b8f0c378c9f80b0/tensordict-0.9.1-cp312-cp312-win_amd64.whl", hash = "sha256:040c7d79f97dcc7e1f99ce7914a01b1fb74b4861d11291d998e4dc47cad4c106", size = 474256, upload-time = "2025-07-14T12:52:19.18Z" }, +] + [[package]] name = "tiktoken" version = "0.9.0" @@ -3926,6 +4019,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5e/23/b73163ac06e5a724375df61a5b6c853861a825fe98e64388f277514153dd/torchaudio-2.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:275931c8a38ff84b5692df990506b41f18d0a0706574d96bc8456ad9e5fa85c8", size = 2493451, upload-time = "2025-04-23T14:46:46.456Z" }, ] +[[package]] +name = "torchrl" +version = "0.9.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cloudpickle" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "tensordict" }, + { name = "torch", version = "2.7.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, + { name = "torch", version = "2.7.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform != 'darwin'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/fe/fe15e7ddbb050c94ae195f7ebf301fcd59731a2d396b6650502e194e6cba/torchrl-0.9.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a37d08ffac866832c77a71f7e8afdc4a354c8f729a471a2400ff2cea02af632", size = 1714742, upload-time = "2025-07-17T17:06:56.937Z" }, + { url = "https://files.pythonhosted.org/packages/97/df/b3e8de956945a2b58e9c80d868becd308b81c38ac79594991cfad376f762/torchrl-0.9.2-cp310-cp310-manylinux_2_28_aarch64.whl", 
hash = "sha256:d7969a9394d7a3d5d35bda18d23aeb10981f08b46eced0cdee976c145c6194b3", size = 1402892, upload-time = "2025-07-17T17:06:58.576Z" }, + { url = "https://files.pythonhosted.org/packages/fd/3a/23c52c98ecc40f1fbf6ffd48ed95f6027c933374886451fe7be421fbf43f/torchrl-0.9.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:08c772e113e085737fb956e2b69390ac1ea38924533e61b3c717e829024040cd", size = 1409679, upload-time = "2025-07-17T17:07:00.002Z" }, + { url = "https://files.pythonhosted.org/packages/e8/81/da6bdaf56aae725f6722d84f92e39b8aca86fc6c70c55d003e6891ea9130/torchrl-0.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:1aeb1a60736be00959644c2ce7c1a9d402c612bcef240574fd0905092173976f", size = 1370446, upload-time = "2025-07-17T17:07:01.174Z" }, + { url = "https://files.pythonhosted.org/packages/a9/21/583b8a4c55abd6e4945e4787f3b6a44214abc8ca1c50d82102e9473645da/torchrl-0.9.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:76b7ed6918cfdfeaf65a498df3ceca6b814960377860f77211216989ef30a31e", size = 1715997, upload-time = "2025-07-17T17:07:02.744Z" }, + { url = "https://files.pythonhosted.org/packages/c3/4d/ea265e1ad4875b42c8a797ebe8ba48987899b6b2298d0f28e3b080f68cd9/torchrl-0.9.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:a2e2f7392d8350f1eed048420c6b8d361d85d1ddd93974ca1ffd1a6def51516e", size = 1404579, upload-time = "2025-07-17T17:07:04.505Z" }, + { url = "https://files.pythonhosted.org/packages/c1/87/66b126cfe7dce7f31f5e638475a9ebc9e744d5f083d49ebc558178b9d57f/torchrl-0.9.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:18ada5d4676f45db66b99a8de37a5780a4b7be227124d42257eae1ee74ff7fa8", size = 1411513, upload-time = "2025-07-17T17:07:05.951Z" }, + { url = "https://files.pythonhosted.org/packages/df/05/38b1360d794663204a6622e0f9ba2fc5846821f47a9c6a1cb761717a0dd5/torchrl-0.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:93ca8d378244480b5f268725c08090e381f94a665129ed22e3c9a771d03de69f", size = 1371515, upload-time = "2025-07-17T17:07:07.092Z" }, 
+ { url = "https://files.pythonhosted.org/packages/05/0c/2c4bc554eccd8601b94eeddaf123d78f6c4e2cffc8af4bc2258be436faaf/torchrl-0.9.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:30dff9165d117b9f67bfff9466b8271d25bbfa14231eb41d9faed407db35a8d8", size = 1718618, upload-time = "2025-07-17T17:07:08.204Z" }, + { url = "https://files.pythonhosted.org/packages/ff/32/4aa545fb408d47f3ae8f1b44b93f1dcddfa534fb209f9f83a9168ffb22cf/torchrl-0.9.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:81a236ad797b82a12f8e64ca774ea5b0ea87e9083a6f38268625993ea0a3c122", size = 1404089, upload-time = "2025-07-17T17:07:09.688Z" }, + { url = "https://files.pythonhosted.org/packages/ba/9e/0657833e90aa23ecd7ec18ddf9942bbd298f91bf4fe1fe49988a647d614c/torchrl-0.9.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:3ac51734e3a75961e80167f3945848325c3186edbda119dd6e3cb403def1584d", size = 1410004, upload-time = "2025-07-17T17:07:10.777Z" }, + { url = "https://files.pythonhosted.org/packages/dc/a2/36bda622004c182f56ef736f386d362166e5776bdf4d1992fc7aa9cc007b/torchrl-0.9.2-cp312-cp312-win_amd64.whl", hash = "sha256:5da161795521fd1bc71dc09075fb9f185700018c2296f0889d514f0f12a4d21d", size = 1372046, upload-time = "2025-07-17T17:07:11.91Z" }, +] + [[package]] name = "torchvision" version = "0.22.0"