7 changes: 6 additions & 1 deletion docs/guides/async-grpo.md
@@ -43,6 +43,7 @@ grpo:
max_trajectory_age_steps: 1 # Maximum age, in training steps, for trajectories
in_flight_weight_updates: false # Enable for faster weight synchronization
recompute_kv_cache_after_weight_updates: false # Invalidates kv cache after in-flight-weight-updates
max_num_in_flight_batches_in_generation: ${grpo.async_grpo.max_trajectory_age_steps} # Controls the maximum number of in-flight prompts to regulate average off-policyness.
```

### Complete Example Config
@@ -69,6 +70,7 @@ grpo:
max_trajectory_age_steps: 1
in_flight_weight_updates: false # Enable for faster weight synchronization
recompute_kv_cache_after_weight_updates: false # Invalidates kv cache after in-flight-weight-updates
max_num_in_flight_batches_in_generation: ${grpo.async_grpo.max_trajectory_age_steps} # Controls the maximum number of in-flight prompts to regulate average off-policyness.

cluster:
num_nodes: 2
@@ -165,7 +167,10 @@ sequenceDiagram
4. **In-Flight Weight Updates**: Enable `in_flight_weight_updates: true` when using `async_engine: true` to update the weights of the vLLM engine during generation. This prevents the training pipeline from stalling until the longest generation finishes and provides significant performance benefits.

5. **Recompute KV Cache After Weight Updates**: When using in-flight weight updates, the user can choose whether to recompute
KV caches after a weight update via the `recompute_kv_cache_after_weight_updates` setting.

6. **Control Max Number of In-Flight Batches**: Use `max_num_in_flight_batches_in_generation` (1 ≤ value ≤ `max_trajectory_age_steps`) to cap the number of concurrent prompt batches and thereby control the average trajectory age;
the effective number of in-flight prompts is value × `num_prompts_per_step`. Keep it equal to `max_trajectory_age_steps` for maximum throughput; lower it to reduce off-policyness.
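
The cap described in item 6 can be sketched as a small calculation. This helper is hypothetical (not part of the repo) and only illustrates the documented bounds and arithmetic:

```python
def effective_in_flight_prompts(
    num_prompts_per_step: int,
    max_num_in_flight_batches_in_generation: int,
    max_trajectory_age_steps: int,
) -> int:
    """Compute the maximum number of concurrently in-flight prompts.

    The batch count must satisfy
    1 <= max_num_in_flight_batches_in_generation <= max_trajectory_age_steps.
    """
    if not (1 <= max_num_in_flight_batches_in_generation <= max_trajectory_age_steps):
        raise ValueError(
            "max_num_in_flight_batches_in_generation must be in "
            f"[1, {max_trajectory_age_steps}]"
        )
    # Cap = batches in flight * prompts per batch
    return num_prompts_per_step * max_num_in_flight_batches_in_generation


# e.g. 32 prompts per step with 2 in-flight batches -> at most 64 prompts in flight
print(effective_in_flight_prompts(32, 2, 8))
```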

## Why Importance Sampling Correction Is Required for Async

5 changes: 5 additions & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -34,6 +34,11 @@ grpo:
max_trajectory_age_steps: 1
in_flight_weight_updates: false # Set to true to enable in-flight weight updates
recompute_kv_cache_after_weight_updates: false # Set to true to recompute kv cache after in-flight-weight-updates
max_num_in_flight_batches_in_generation: ${grpo.async_grpo.max_trajectory_age_steps}
# This argument enables pipeline-RL-style async-GRPO training.
# Allowed values: 1 <= max_num_in_flight_batches_in_generation <= max_trajectory_age_steps
# The maximum number of in-flight prompts is max_num_in_flight_batches_in_generation * num_prompts_per_step
# A lower max_num_in_flight_batches_in_generation reduces the average trajectory age, but may also reduce inference throughput.

loss_fn:
reference_policy_kl_penalty: 0.01
9 changes: 7 additions & 2 deletions nemo_rl/algorithms/async_utils.py
@@ -278,11 +278,16 @@ def __init__(
self._inflight_threads: set[_threading.Thread] = set()
self._threads_lock: _threading.Lock = _threading.Lock()

# Limit in-flight generator requests to num_prompts_per_step * max_num_in_flight_batches_in_generation
# By default, max_num_in_flight_batches_in_generation is set to max_trajectory_age_steps
# This value limits the parallelism of the generation requests.
max_inflight = (
    int(self.master_config["grpo"]["num_prompts_per_step"])
    * int(
        self.master_config["grpo"]["async_grpo"][
            "max_num_in_flight_batches_in_generation"
        ]
    )
) or 1
self._inflight_sema = _threading.Semaphore(max_inflight)
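
The semaphore cap above follows a standard bounded-concurrency pattern. A minimal standalone sketch, with a hypothetical class name rather than the repo's actual implementation:

```python
import threading


class InFlightLimiter:
    """Caps the number of concurrently running generation requests."""

    def __init__(self, max_inflight: int) -> None:
        # Semaphore slots bound how many requests may run at once.
        self._sema = threading.Semaphore(max(1, max_inflight))

    def submit(self, fn, *args) -> threading.Thread:
        """Run fn on a worker thread; blocks if the cap is reached."""
        self._sema.acquire()  # blocks when max_inflight requests are active

        def _run():
            try:
                fn(*args)
            finally:
                self._sema.release()  # free a slot once the request finishes

        t = threading.Thread(target=_run)
        t.start()
        return t
```

Acquiring before spawning the thread means the caller itself back-pressures: it cannot enqueue more work than the cap allows.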

6 changes: 6 additions & 0 deletions nemo_rl/algorithms/grpo.py
@@ -108,6 +108,12 @@ class AsyncGRPOConfig(TypedDict):
in_flight_weight_updates: NotRequired[bool]
# Recomputes the KV cache after the in-flight weight updates.
recompute_kv_cache_after_weight_updates: NotRequired[bool]
# Maximum number of in-flight prompt batches in generation.
# Enables pipeline-RL-style async-GRPO training.
# Allowed values: 1 <= max_num_in_flight_batches_in_generation <= max_trajectory_age_steps
# The maximum number of in-flight prompts is max_num_in_flight_batches_in_generation * num_prompts_per_step
# A lower max_num_in_flight_batches_in_generation reduces the average trajectory age, but may also reduce inference throughput.
max_num_in_flight_batches_in_generation: NotRequired[int]
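
The constraints in the comments above could be enforced with a small helper. This is a sketch, not part of the PR; the function name is hypothetical, and the default mirrors the `${grpo.async_grpo.max_trajectory_age_steps}` interpolation in the YAML:

```python
def resolve_in_flight_batches(async_cfg: dict) -> int:
    """Return a validated max_num_in_flight_batches_in_generation.

    Defaults to max_trajectory_age_steps when unset, matching the
    YAML interpolation default.
    """
    max_age = int(async_cfg.get("max_trajectory_age_steps", 1))
    batches = int(
        async_cfg.get("max_num_in_flight_batches_in_generation", max_age)
    )
    if not (1 <= batches <= max_age):
        raise ValueError(
            f"max_num_in_flight_batches_in_generation={batches} "
            f"must be in [1, {max_age}]"
        )
    return batches
```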
Contributor:
I am not 100% convinced about the advantage of this flag for our async implementation. We would always want to aim for max-throughput rollouts so as to fully utilize the rollout GPUs. If the user indeed intends to control the off-policy factor, they should use the max trajectory age.

Contributor Author:
This PR is to test the benefit of the Pipeline-RL-style implementation.

A potential target case might look like this:

  • A user wants at most 2-off for most of the samples, but a few very long generations cause a straggler effect that requires 8-off to hide the latency.

In the above case, in the current ToT, if we choose 2-off, it will cause long-exposed generation; if we choose 8-off, it will make everything 8-off.

With this PR, setting max_trajectory_age_steps to a large number (e.g., 8) while setting max_num_in_flight_batches_in_generation=2 might achieve both goals.

But in general, I agree with your comment. Controlling the average age and throughput this way is too complex for users. I am okay with holding this PR until we get strong proof that this feature gives any benefit.
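
The trade-off in this scenario can be made concrete with the cap arithmetic. The numbers are purely illustrative; num_prompts_per_step = 32 is an assumption, not a value from the PR:

```python
# Illustrative comparison of the three configurations discussed above.
num_prompts_per_step = 32  # assumed for illustration

def cap(in_flight_batches: int) -> int:
    # Effective in-flight prompts = batches in flight * prompts per step
    return num_prompts_per_step * in_flight_batches

# Strict 2-off: low age bound, but stragglers stall the pipeline.
strict = {"max_age": 2, "in_flight_cap": cap(2)}
# Strict 8-off: hides straggler latency, but every sample may be 8-off.
loose = {"max_age": 8, "in_flight_cap": cap(8)}
# Hybrid (this PR): age budget of 8 for stragglers, only 2 batches in flight,
# so the average sample stays closer to 2-off.
hybrid = {"max_age": 8, "in_flight_cap": cap(2)}

assert strict["in_flight_cap"] == hybrid["in_flight_cap"] == 64
assert loose["in_flight_cap"] == 256
```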

Contributor:
I agree with the analysis, but we can fix it in a better way than introducing more flags. Our current implementation is very strict, i.e., a trajectory, while it is enqueued, knows exactly when in the future it will be used. We need to relax this condition so that the decision of what goes into a training batch is FIFO whenever the trajectory age fits within the budget.

Contributor Author:
I agree that the FIFO restriction is limiting both performance and the ability to minimize the average trajectory age. We may need more discussion on how to improve this.



class GRPOConfig(TypedDict):