2 changes: 2 additions & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -229,6 +229,8 @@ policy:
use_deep_gemm: False
num_last_layers_in_bf16: 0
num_first_layers_in_bf16: 0
enable_vllm_metrics_logger: true # Enable the vLLM internal metrics logger; turn off for better performance
Contributor:
Do we know how much of a performance hit we take with metrics enabled?

Contributor Author:
It wasn't noticeable to me, but theoretically it can add some overhead. I will run some iso-config runs in our perf tracker, which reports exposed generation time, and see how much perf impact there is.

Contributor Author:
I checked the perf difference with the Qwen3 30B model; there was no noticeable difference.

vllm_metrics_logger_interval: 0.5 # Interval in seconds to collect vLLM logger metrics
vllm_kwargs: {}
colocated:
# true: generation shares training GPUs
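For context on what these two settings control: each vLLM worker runs an interval-based sampler that snapshots engine stats (inflight batch size, pending samples) every `vllm_metrics_logger_interval` seconds. Below is a minimal, hypothetical sketch of that pattern using only the standard library — the class and method names are illustrative, not the actual NeMo RL implementation:

```python
import threading
from typing import Callable


class IntervalMetricsLogger:
    """Sample a stats callable every `interval` seconds and buffer the results."""

    def __init__(self, get_stat: Callable[[], int], interval: float = 0.5):
        self._get_stat = get_stat  # e.g. returns the current inflight batch size
        self._interval = interval  # maps to vllm_metrics_logger_interval
        self._samples: list[int] = []
        self._lock = threading.Lock()
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self) -> None:
        # Event.wait(timeout) returns False on timeout, so this ticks every interval.
        while not self._stop.wait(self._interval):
            with self._lock:
                self._samples.append(self._get_stat())

    def start(self) -> None:
        self._thread.start()

    def get_and_clear(self) -> list[int]:
        """Return buffered samples and start a new logging cycle."""
        with self._lock:
            samples, self._samples = self._samples, []
        return samples
```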
19 changes: 19 additions & 0 deletions nemo_rl/algorithms/grpo.py
@@ -1073,6 +1073,8 @@ def grpo_train(

dynamic_sampling_num_gen_batches += 1
with timer.time("generation"):
# Clear vLLM logger metrics for each generation step
policy_generation.clear_vllm_logger_metrics()
# Use penguin rollouts if enabled. We cascade penguin first since penguin requires async rollouts.
if _should_use_penguin(master_config):
generation_config = master_config["policy"]["generation"]
@@ -1122,6 +1124,9 @@
greedy=False,
)
policy_generation.finish_generation()
# Collect vLLM logger metrics for performance reporting after each generation step.
# Inflight batch sizes and num pending samples are collected from each vLLM worker.
vllm_logger_metrics = policy_generation.get_vllm_logger_metrics()

repeated_batch = scale_rewards(
repeated_batch, master_config["grpo"]["reward_scaling"]
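Taken together, the sync-path changes bracket every generation step with a clear/collect pair. A compressed sketch of that lifecycle, assuming only the two methods this PR adds (`generate_batch` is a hypothetical stand-in for the actual rollout call):

```python
from typing import Any


def run_generation_step(policy_generation: Any, prompts: Any) -> dict[str, Any]:
    """One sync GRPO generation step with the metrics bracketing added above."""
    policy_generation.clear_vllm_logger_metrics()      # fresh logging cycle per step
    batch = policy_generation.generate_batch(prompts)  # hypothetical rollout call
    policy_generation.finish_generation()
    # Per-worker inflight batch sizes / pending samples accumulated since the clear.
    return {
        "batch": batch,
        "vllm_logger_metrics": policy_generation.get_vllm_logger_metrics(),
    }
```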
@@ -1340,6 +1345,7 @@ def grpo_train(
metrics[k] = np.sum(v).item()

metrics.update(rollout_metrics)
metrics["vllm_logger_metrics"] = vllm_logger_metrics
total_valid_tokens += metrics["global_valid_toks"]

## Checkpointing
@@ -1907,6 +1913,9 @@ def async_grpo_train(

print("✅ All setup complete, starting buffer wait...")

# Clear vLLM logger metrics at the start of training
policy_generation.clear_vllm_logger_metrics()

# Wait for initial buffer fill
print(
f"⏳ Waiting for replay buffer to have sufficient trajectories ({min_trajectories_needed} trajectories)..."
@@ -2145,12 +2154,17 @@ def async_grpo_train(
train_results = policy.train(train_data, loss_fn)

print("🔄 Synchronizing policy weights to trajectory collector…")
vllm_logger_metrics = None
if NEED_REFIT:
# Measure pending-generation wait as exposed_generation time
print("🔄 Coordinating with trajectory collector before refit...")
with timer.time("exposed_generation"):
ray.get(trajectory_collector.prepare_for_refit.remote())

# Collect vLLM logger metrics for performance reporting.
# Inflight batch sizes and num pending samples are collected from each vLLM worker.
vllm_logger_metrics = policy_generation.get_vllm_logger_metrics()

# Only the actual refit/weight transfer should be counted as weight_sync
print("🔄 Performing policy generation refit...")
with timer.time("weight_sync"):
@@ -2164,6 +2178,9 @@
trajectory_collector.set_weight_version.remote(weight_version)
trajectory_collector.resume_after_refit.remote()

# Clear vLLM logger metrics after each refit (weight sync), starting a new logging cycle
policy_generation.clear_vllm_logger_metrics()

# Validation
val_metrics, validation_timings = None, None
is_last_step = step + 1 == master_config["grpo"]["max_num_steps"]
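In the async path the logging cycle instead spans one refit-to-refit window: metrics are harvested just before the refit and cleared right after it. A hedged sketch of that ordering (`do_refit` is a placeholder for the prepare/weight-sync/resume sequence, not a real function in this repo):

```python
from typing import Any, Callable


def run_refit_window(
    policy_generation: Any, do_refit: Callable[[], None]
) -> dict[str, Any]:
    """Harvest-then-clear ordering used around each async refit."""
    # Everything sampled since the previous clear covers this whole window.
    vllm_logger_metrics = policy_generation.get_vllm_logger_metrics()
    do_refit()  # placeholder for prepare_for_refit / weight sync / resume
    # Start a new logging cycle aligned with the freshly synced weights.
    policy_generation.clear_vllm_logger_metrics()
    return vllm_logger_metrics
```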
@@ -2241,6 +2258,8 @@ def async_grpo_train(
else:
metrics[k] = np.sum(v).item()
metrics.update(rollout_metrics)
if vllm_logger_metrics is not None:
metrics["vllm_logger_metrics"] = vllm_logger_metrics
total_valid_tokens += metrics["global_valid_toks"]

# Checkpointing (same as sync version)
116 changes: 111 additions & 5 deletions nemo_rl/algorithms/utils.py
@@ -16,7 +16,7 @@
import random
import warnings
from functools import partial, wraps
-from typing import Optional
+from typing import Any, Optional

import numpy as np
import torch
@@ -384,7 +384,7 @@ def maybe_pad_last_batch(batch: dict, dp_size: int, mbs: int) -> dict:

def print_performance_metrics(
train_results: dict[str, float],
-metrics: dict[str, float],
+metrics: dict[str, Any],
timing_metrics: dict[str, float],
master_config: dict,
) -> dict[str, float]:
@@ -400,13 +400,14 @@ def visualize_per_worker_load(per_worker_token_counts: dict[int, int]) -> float:
per_worker_load_ratio = [
v / max(per_worker_token_counts_list) for v in per_worker_token_counts_list
]
-max_rows_to_print = 100
+max_rows_to_print = 1000
bar_length = 20
print(" • Visualizing Token Imbalance per Generation Worker:")
for i in range(min(len(per_worker_token_counts_list), max_rows_to_print)):
print(
f" - Generated Tokens from Worker {i:3.0f}:"
f"{'■' * int(per_worker_load_ratio[i] * 10)}"
f"{'□' * (10 - int(per_worker_load_ratio[i] * 10))}"
f"{'■' * int(per_worker_load_ratio[i] * bar_length)}"
f"{'□' * (bar_length - int(per_worker_load_ratio[i] * bar_length))}"
f" Count: {per_worker_token_counts_list[i] / 1000:.1f}K"
)
estimated_idle_ratio = 1 - sum(per_worker_load_ratio) / len(
@@ -441,6 +442,111 @@ def visualize_per_worker_load(per_worker_token_counts: dict[int, int]) -> float:
f" • Mean Total Tokens per Sample: {metrics['mean_total_tokens_per_sample']:.2f}"
)

# =====================================================
# vLLM Logger Metrics (inflight batch sizes, num pending samples, etc.)
# =====================================================
def resize_timeline(data, new_size):
# Linearly resample a timeline to `new_size` points so long traces fit on one row.
old_size = len(data)
x_old = np.linspace(0, 1, old_size)
x_new = np.linspace(0, 1, new_size)
return np.interp(x_new, x_old, data)
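Since `resize_timeline` is just linear reinterpolation onto a fixed grid, a quick sanity check of its behavior (the helper is copied from the diff so the snippet stands alone):

```python
import numpy as np


def resize_timeline(data, new_size):
    old_size = len(data)
    x_old = np.linspace(0, 1, old_size)
    x_new = np.linspace(0, 1, new_size)
    return np.interp(x_new, x_old, data)


# Upsampling a 4-point trace to 7 points fills in the linear midpoints...
print(resize_timeline([0, 2, 4, 6], 7))      # [0. 1. 2. 3. 4. 5. 6.]
# ...and downsampling a 100-point trace to 5 points keeps the overall shape.
print(resize_timeline(list(range(100)), 5))  # [ 0.   24.75 49.5  74.25 99.  ]
```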

def visualize_per_worker_timeline(
metric_dict: dict[int, list[int]],
metric_name: str,
timeline_interval: float | None,
) -> None:
dp_ranks = list(metric_dict.keys())
max_rows_to_print = 1000
max_timeline_length = 50
marker = {0: "▃", 1: "▅", 2: "▆", 3: "▉"}
zero_marker = "▁"

max_value = max((max(v) if v else 0) for v in metric_dict.values())
bin_width = (max_value + 1) / len(marker)

print(f" - {metric_name}:")
print(f" - Max value: {max_value}")
# Legend bounds use the same bin_width as the bucketing below; zero has its own marker.
print(
f" - Timeline (0: {zero_marker}, {', '.join(f'{max(1, k * bin_width):.1f}-{(k + 1) * bin_width:.1f}: {marker[k]}' for k in marker)}):"
)
for dp_idx, metric_values in metric_dict.items():
if dp_idx >= max_rows_to_print:
break
timeline = []
length = len(metric_values)
if timeline_interval is not None:
count_zeros = lambda x: sum(v == 0 for v in x)
idle = count_zeros(metric_values) * timeline_interval
active = length * timeline_interval - idle
if length > max_timeline_length:
resized_metric_values = resize_timeline(
metric_values, max_timeline_length
)
else:
resized_metric_values = metric_values

for i, value in enumerate(resized_metric_values):
m = (
zero_marker
if value == 0
else marker[min(int(value // bin_width), len(marker) - 1)]
)
timeline.append(m)
if timeline_interval is not None:
print(
f" - Generation Worker {dp_idx:3.0f}: {''.join(timeline)} (Active: {active:.2f} s, Idle: {idle:.2f} s)"
)
else:
print(f" - Generation Worker {dp_idx:3.0f}: {''.join(timeline)}")

is_vllm_metrics_logger_enabled = (
master_config["policy"]["generation"]
.get("vllm_cfg", {})
.get("enable_vllm_metrics_logger", False)
)
if is_vllm_metrics_logger_enabled:
vllm_logger_metrics = metrics["vllm_logger_metrics"]
# vllm_logger_metrics: dict[str (metric_name), dict[int (dp_idx), list[int] (metric_values)]]
# metric_name: "inflight_batch_sizes" or "num_pending_samples"

assert "inflight_batch_sizes" in vllm_logger_metrics, (
"inflight_batch_sizes not found in vllm_logger_metrics"
)
assert "num_pending_samples" in vllm_logger_metrics, (
"num_pending_samples not found in vllm_logger_metrics"
)
assert isinstance(vllm_logger_metrics["inflight_batch_sizes"], dict), (
"inflight_batch_sizes must be a dictionary"
)
assert isinstance(vllm_logger_metrics["num_pending_samples"], dict), (
"num_pending_samples must be a dictionary"
)

vllm_metrics_logger_interval = master_config["policy"]["generation"][
"vllm_cfg"
]["vllm_metrics_logger_interval"]
print(" • vLLM Logger Metrics:")
# Visualize the inflight batch sizes timeline
if len(vllm_logger_metrics["inflight_batch_sizes"].values()) > 0:
visualize_per_worker_timeline(
vllm_logger_metrics["inflight_batch_sizes"],
"Inflight Batch Sizes",
vllm_metrics_logger_interval,
)
if len(vllm_logger_metrics["num_pending_samples"].values()) > 0:
max_num_pending_samples = max(
(max(v) if v else 0)
for v in vllm_logger_metrics["num_pending_samples"].values()
)
# If there is at least one pending sample, visualize the timeline
if max_num_pending_samples > 0:
visualize_per_worker_timeline(
vllm_logger_metrics["num_pending_samples"],
"Num Pending Samples",
None,
)

# =====================================================
# Throughputs
# =====================================================
45 changes: 45 additions & 0 deletions nemo_rl/models/generation/vllm/vllm_generation.py
@@ -819,6 +819,51 @@ def stop_gpu_profiling(self) -> None:
futures = self.worker_group.run_all_workers_single_data("stop_gpu_profiling")
ray.get(futures)

def get_vllm_logger_metrics(self) -> dict[str, Any]:
"""Collect vLLM logger metrics from vLLM workers (model-owner actors only)."""
if not self.cfg["vllm_cfg"].get("enable_vllm_metrics_logger", False):
return {}

futures: list[ray.ObjectRef] = []
dp_indices: list[int] = []
for dp_idx in range(self.worker_group.dp_size):
worker_idx = self.worker_group.get_dp_leader_worker_idx(dp_idx)
future = self.worker_group.run_single_worker_single_data(
"get_vllm_logger_metrics",
worker_idx=worker_idx,
)
futures.append(future)
dp_indices.append(dp_idx)

results = ray.get(futures)
vllm_logger_metrics: dict[str, dict[int, list[int]]] = {
"inflight_batch_sizes": {}, # dp_idx -> list[int]
"num_pending_samples": {}, # dp_idx -> list[int]
}

for dp_idx, stats in zip(dp_indices, results):
if not stats:
continue
inflight_batch_sizes = stats.get("inflight_batch_sizes")
if inflight_batch_sizes:
vllm_logger_metrics["inflight_batch_sizes"][dp_idx] = (
inflight_batch_sizes
)
num_pending_samples = stats.get("num_pending_samples")
if num_pending_samples:
vllm_logger_metrics["num_pending_samples"][dp_idx] = num_pending_samples

return vllm_logger_metrics
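The returned structure is metric name → DP rank → list of samples, so a caller can compute per-worker summaries directly. A small hedged example (the input literal is fabricated for illustration):

```python
# Shape returned by get_vllm_logger_metrics: metric name -> DP rank -> samples.
vllm_logger_metrics = {
    "inflight_batch_sizes": {0: [4, 8, 8, 2, 0], 1: [6, 6, 7, 0, 0]},
    "num_pending_samples": {0: [0, 0, 1, 0, 0], 1: [0, 2, 2, 1, 0]},
}

for name, per_rank in vllm_logger_metrics.items():
    for dp_idx, samples in per_rank.items():
        mean = sum(samples) / len(samples)
        idle_fraction = sum(s == 0 for s in samples) / len(samples)
        print(f"{name} dp={dp_idx}: mean={mean:.1f}, idle={idle_fraction:.0%}")
```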

def clear_vllm_logger_metrics(self) -> None:
"""Clear buffered vLLM logger metrics on all workers, starting a new logging cycle."""
if not self.cfg["vllm_cfg"].get("enable_vllm_metrics_logger", False):
return
futures = self.worker_group.run_all_workers_single_data(
"clear_vllm_logger_metrics",
run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
)
ray.get(futures)

def __del__(self) -> None:
"""Shuts down the worker groups when the object is deleted or is garbage collected.