From 54e45e95066874cde1002baaef695f67f3d72e37 Mon Sep 17 00:00:00 2001
From: Youngeun Kwon
Date: Fri, 7 Nov 2025 11:27:33 -0800
Subject: [PATCH] draft

Signed-off-by: Youngeun Kwon
---
 docs/guides/async-grpo.md          | 7 ++++++-
 examples/configs/grpo_math_1B.yaml | 5 +++++
 nemo_rl/algorithms/async_utils.py  | 9 +++++++--
 nemo_rl/algorithms/grpo.py         | 6 ++++++
 4 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/docs/guides/async-grpo.md b/docs/guides/async-grpo.md
index 0beac8204..79a05ba69 100644
--- a/docs/guides/async-grpo.md
+++ b/docs/guides/async-grpo.md
@@ -43,6 +43,7 @@ grpo:
     max_trajectory_age_steps: 1 # Maximum age, in training steps, for trajectories
     in_flight_weight_updates: false # Enable for faster weight synchronization
     recompute_kv_cache_after_weight_updates: false # Invalidates kv cache after in-flight-weight-updates
+    max_num_in_flight_batches_in_generation: ${grpo.async_grpo.max_trajectory_age_steps} # Controls the maximum number of in-flight prompts to regulate average off-policyness.
 ```
 
 ### Complete Example Config
@@ -69,6 +70,7 @@ grpo:
     max_trajectory_age_steps: 1
     in_flight_weight_updates: false # Enable for faster weight synchronization
     recompute_kv_cache_after_weight_updates: false # Invalidates kv cache after in-flight-weight-updates
+    max_num_in_flight_batches_in_generation: ${grpo.async_grpo.max_trajectory_age_steps} # Controls the maximum number of in-flight prompts to regulate average off-policyness.
 
 cluster:
   num_nodes: 2
@@ -165,7 +167,10 @@ sequenceDiagram
 4. **In-Flight Weight Updates**: Enable `in_flight_weight_updates: true` when using `async_engine: true` for updating the weights of vLLM engine during generation. This prevents stalling training pipeline until longest generation finishes and provides significant performance benefits.
 
 5. **Recompute KV Cache After Weight Updates**: While using in-flight weight update, user can choose whether to recompute
-KV caches after weight udpate by configuring `recompute_kv_cache_after_weight_update` configuration.
+KV caches after weight update by configuring `recompute_kv_cache_after_weight_update` configuration.
+
+6. **Control Max Number of In-Flight Batches**: Use `max_num_in_flight_batches_in_generation` (1 ≤ value ≤ `max_trajectory_age_steps`) to cap the number of concurrent prompt batches and thereby control average trajectory age;
+the number of effective in-flight prompts is value × `num_prompts_per_step`. Keep it equal to `max_trajectory_age_steps` for maximum throughput; lower it to reduce off-policyness.
 
 ## Why Importance Sampling Correction Is Required for Async
 
diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml
index 50124bb71..a890c02bb 100644
--- a/examples/configs/grpo_math_1B.yaml
+++ b/examples/configs/grpo_math_1B.yaml
@@ -34,6 +34,11 @@ grpo:
     max_trajectory_age_steps: 1
     in_flight_weight_updates: false # Set to true to enable in-flight weight updates
     recompute_kv_cache_after_weight_updates: false # Set to true to recompute kv cache after in-flight-weight-updates
+    max_num_in_flight_batches_in_generation: ${grpo.async_grpo.max_trajectory_age_steps}
+    # This argument enables pipeline-rl style async-grpo training.
+    # Allowed values are 1 <= max_num_in_flight_batches_in_generation <= max_trajectory_age_steps.
+    # The maximum number of in-flight prompts will be max_num_in_flight_batches_in_generation * num_prompts_per_step.
+    # Lowering max_num_in_flight_batches_in_generation reduces the average trajectory age, but may also reduce inference throughput.
 
 loss_fn:
   reference_policy_kl_penalty: 0.01
diff --git a/nemo_rl/algorithms/async_utils.py b/nemo_rl/algorithms/async_utils.py
index c1ce9ab76..e12c21ca2 100644
--- a/nemo_rl/algorithms/async_utils.py
+++ b/nemo_rl/algorithms/async_utils.py
@@ -278,11 +278,16 @@ def __init__(
         self._inflight_threads: set[_threading.Thread] = set()
         self._threads_lock: _threading.Lock = _threading.Lock()
 
-        # Limit in-flight generator requests to num_prompts_per_step * max_trajectory_age_steps
+        # Limit in-flight generator requests to num_prompts_per_step * max_num_in_flight_batches_in_generation
+        # By default, max_num_in_flight_batches_in_generation is set to max_trajectory_age_steps.
         # This value limits the parallelism of the generation requests.
         max_inflight = (
             int(self.master_config["grpo"]["num_prompts_per_step"])
-            * int(self.master_config["grpo"]["async_grpo"]["max_trajectory_age_steps"])
+            * int(
+                self.master_config["grpo"]["async_grpo"][
+                    "max_num_in_flight_batches_in_generation"
+                ]
+            )
         ) or 1
         self._inflight_sema = _threading.Semaphore(max_inflight)
 
diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py
index fdfa5e3ed..f65180b35 100644
--- a/nemo_rl/algorithms/grpo.py
+++ b/nemo_rl/algorithms/grpo.py
@@ -108,6 +108,12 @@ class AsyncGRPOConfig(TypedDict):
     in_flight_weight_updates: NotRequired[bool]
     # Recomputes the KV cache after the in-flight weight updates.
     recompute_kv_cache_after_weight_updates: NotRequired[bool]
+    # Maximum number of in-flight prompt batches in generation.
+    # Enables pipeline-rl style async-grpo training.
+    # Allowed values are 1 <= max_num_in_flight_batches_in_generation <= max_trajectory_age_steps.
+    # The maximum number of in-flight prompts will be max_num_in_flight_batches_in_generation * num_prompts_per_step.
+    # Lowering max_num_in_flight_batches_in_generation reduces the average trajectory age, but may also reduce inference throughput.
+    max_num_in_flight_batches_in_generation: NotRequired[int]
 
 
 class GRPOConfig(TypedDict):
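Note for reviewers (not part of the patch): the sketch below is a minimal, self-contained illustration of the semaphore cap that the async_utils.py hunk implements. The numeric values are made up for the example; in the real config they come from `grpo.num_prompts_per_step` and `grpo.async_grpo.max_num_in_flight_batches_in_generation`, and the generation call is a stand-in.

```python
import threading
import time

# Illustrative values only (not from the patch).
num_prompts_per_step = 4
max_num_in_flight_batches_in_generation = 2

# Mirrors the patched computation; `or 1` guards against a zero product.
max_inflight = (num_prompts_per_step * max_num_in_flight_batches_in_generation) or 1
inflight_sema = threading.Semaphore(max_inflight)

def generate(prompt_id: int) -> None:
    """Stand-in for dispatching one prompt to the generation engine."""
    try:
        time.sleep(0.1)  # simulate generation latency
        print(f"finished prompt {prompt_id}")
    finally:
        inflight_sema.release()  # free a slot for the next queued prompt

threads = []
for prompt_id in range(3 * num_prompts_per_step):  # three "batches" of prompts
    # acquire() blocks once max_inflight prompts are running, so at most
    # max_num_in_flight_batches_in_generation batches are ever in flight.
    inflight_sema.acquire()
    t = threading.Thread(target=generate, args=(prompt_id,))
    t.start()
    threads.append(t)

for t in threads:
    t.join()
```

With these toy numbers at most 8 prompts (2 batches × 4 prompts) are in flight at any moment; dropping max_num_in_flight_batches_in_generation to 1 would halve that, trading generation throughput for fresher trajectories, as the added YAML comments describe.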