Skip to content

Commit 45d9680

Browse files
polish the PR
Signed-off-by: Youngeun Kwon <[email protected]>
1 parent f0e0f85 commit 45d9680

File tree

5 files changed

+109
-59
lines changed

5 files changed

+109
-59
lines changed

examples/configs/grpo_math_1B.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ policy:
230230
num_last_layers_in_bf16: 0
231231
num_first_layers_in_bf16: 0
232232
enable_vllm_metrics_logger: false # Set to true to enable vLLM internal metrics logger, might impact performance
233+
vllm_metrics_logger_interval: 0.5 # Interval in seconds to collect vLLM logger metrics
233234
vllm_kwargs: {}
234235
colocated:
235236
# true: generation shares training GPUs

nemo_rl/algorithms/grpo.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,9 +1038,6 @@ def grpo_train(
10381038
maybe_gpu_profile_step(policy_generation, total_steps + 1)
10391039
val_metrics, validation_timings = None, None
10401040

1041-
# Clear vLLM logger metrics after each step
1042-
policy_generation.clear_vllm_logger_metrics()
1043-
10441041
with timer.time("total_step_time"):
10451042
# Prepare batch
10461043
print("▶ Preparing batch...", flush=True)
@@ -1076,6 +1073,8 @@ def grpo_train(
10761073

10771074
dynamic_sampling_num_gen_batches += 1
10781075
with timer.time("generation"):
1076+
# Clear vLLM logger metrics for each generation step
1077+
policy_generation.clear_vllm_logger_metrics()
10791078
# Use penguin rollouts if enabled. We cascade penguin first since penguin requires async rollouts.
10801079
if _should_use_penguin(master_config):
10811080
generation_config = master_config["policy"]["generation"]
@@ -1125,10 +1124,9 @@ def grpo_train(
11251124
greedy=False,
11261125
)
11271126
policy_generation.finish_generation()
1128-
1129-
# Collect vLLM logger metrics for performance reporting
1130-
# inflight batch sizes and num pending samples are collected from each vLLM worker
1131-
vllm_logger_metrics = policy_generation.get_vllm_logger_metrics()
1127+
# Collect vLLM logger metrics for performance reporting after each generation step
1128+
# inflight batch sizes and num pending samples are collected from each vLLM worker
1129+
vllm_logger_metrics = policy_generation.get_vllm_logger_metrics()
11321130

11331131
repeated_batch = scale_rewards(
11341132
repeated_batch, master_config["grpo"]["reward_scaling"]
@@ -1934,6 +1932,9 @@ def async_grpo_train(
19341932

19351933
print("✅ Buffer ready! Starting training loop...")
19361934

1935+
# Clear vLLM logger metrics at the start of training
1936+
policy_generation.clear_vllm_logger_metrics()
1937+
19371938
# Main training loop
19381939
try:
19391940
while step < master_config["grpo"]["max_num_steps"]:
@@ -1944,9 +1945,6 @@ def async_grpo_train(
19441945
if policy != policy_generation:
19451946
maybe_gpu_profile_step(policy_generation, step + 1)
19461947

1947-
# Clear vLLM logger metrics after each step
1948-
policy_generation.clear_vllm_logger_metrics()
1949-
19501948
with timer.time("total_step_time"):
19511949
# Sample trajectories from replay buffer
19521950
print("📦 Sampling from replay buffer...")
@@ -2179,6 +2177,9 @@ def async_grpo_train(
21792177
trajectory_collector.set_weight_version.remote(weight_version)
21802178
trajectory_collector.resume_after_refit.remote()
21812179

2180+
# Clear vLLM logger metrics after each refit (weight sync), starting a new logging cycle
2181+
policy_generation.clear_vllm_logger_metrics()
2182+
21822183
# Validation
21832184
val_metrics, validation_timings = None, None
21842185
is_last_step = step + 1 == master_config["grpo"]["max_num_steps"]

nemo_rl/algorithms/utils.py

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -401,12 +401,13 @@ def visualize_per_worker_load(per_worker_token_counts: dict[int, int]) -> float:
401401
v / max(per_worker_token_counts_list) for v in per_worker_token_counts_list
402402
]
403403
max_rows_to_print = 100
404+
bar_length = 20
404405
print(" • Visualizing Token Imbalance per Generation Worker:")
405406
for i in range(min(len(per_worker_token_counts_list), max_rows_to_print)):
406407
print(
407408
f" - Generated Tokens from Worker {i:3.0f}:"
408-
f"{'■' * int(per_worker_load_ratio[i] * 10)}"
409-
f"{'□' * (10 - int(per_worker_load_ratio[i] * 10))}"
409+
f"{'■' * int(per_worker_load_ratio[i] * bar_length)}"
410+
f"{'□' * (bar_length - int(per_worker_load_ratio[i] * bar_length))}"
410411
f" Count: {per_worker_token_counts_list[i] / 1000:.1f}K"
411412
)
412413
estimated_idle_ratio = 1 - sum(per_worker_load_ratio) / len(
@@ -442,25 +443,77 @@ def visualize_per_worker_load(per_worker_token_counts: dict[int, int]) -> float:
442443
)
443444

444445
# =====================================================
445-
# vLLM Logger Metrics (inflight batch sizes and pending samples)
446+
# vLLM Logger Metrics (inflight batch sizes, num pending samples, etc.)
446447
# =====================================================
448+
def resize_timeline(data, new_size):
449+
old_size = len(data)
450+
x_old = np.linspace(0, 1, old_size)
451+
x_new = np.linspace(0, 1, new_size)
452+
return np.interp(x_new, x_old, data)
453+
454+
def visualize_per_worker_timeline(
455+
metric_dict: dict[int, list[int]],
456+
metric_name: str,
457+
timeline_interval: float | None,
458+
) -> None:
459+
dp_ranks = list(metric_dict.keys())
460+
max_timeline_length = 50
461+
marker = {0: "□", 1: "⧅", 2: "⛝", 3: "■"}
462+
463+
max_value = max(max(v) for v in metric_dict.values())
464+
bin_width = (max_value + 1) / len(marker)
465+
466+
print(f" - {metric_name}:")
467+
print(f" - Max value: {max_value}")
468+
print(" - Timeline:")
469+
for dp_idx, metric_values in metric_dict.items():
470+
timeline = []
471+
length = len(metric_values)
472+
if timeline_interval is not None:
473+
count_zeros = lambda x: sum(v == 0 for v in x)
474+
idle = count_zeros(metric_values) * timeline_interval
475+
active = length * timeline_interval - idle
476+
if length > max_timeline_length:
477+
resized_metric_values = resize_timeline(
478+
metric_values, max_timeline_length
479+
)
480+
else:
481+
resized_metric_values = metric_values
482+
483+
for i, value in enumerate(resized_metric_values):
484+
timeline.append(marker[min(int(value // bin_width), len(marker) - 1)])
485+
if timeline_interval is not None:
486+
print(
487+
f" - Generation Worker {dp_idx:3.0f}: {' '.join(timeline)} (Active: {active:.2f} s, Idle: {idle:.2f} s)"
488+
)
489+
else:
490+
print(f" - Generation Worker {dp_idx:3.0f}: {' '.join(timeline)}")
491+
447492
if "vllm_logger_metrics" in metrics:
493+
# vllm_logger_metrics: dict[str (metric_name), dict[int (dp_idx), list[int] (metric_values)]]
494+
# metric_name: "inflight_batch_sizes" or "num_pending_samples"
448495
vllm_logger_metrics = metrics["vllm_logger_metrics"]
496+
449497
if vllm_logger_metrics is not None:
498+
vllm_metrics_logger_interval = master_config["policy"]["generation"][
499+
"vllm_cfg"
500+
]["vllm_metrics_logger_interval"]
450501
print(" • vLLM Logger Metrics:")
451-
for dp_idx, inflight_batch_sizes in vllm_logger_metrics[
452-
"inflight_batch_sizes"
453-
].items():
454-
print(
455-
f" - vLLM Inflight Batch Sizes for DP {dp_idx}: {inflight_batch_sizes}",
456-
flush=True,
457-
)
458-
for dp_idx, num_pending_samples in vllm_logger_metrics[
459-
"num_pending_samples"
460-
].items():
461-
print(
462-
f" - vLLM Num Pending Samples for DP {dp_idx}: {num_pending_samples}",
463-
flush=True,
502+
# Visualize the inflight batch sizes timeline
503+
visualize_per_worker_timeline(
504+
vllm_logger_metrics["inflight_batch_sizes"],
505+
"Inflight Batch Sizes",
506+
vllm_metrics_logger_interval,
507+
)
508+
max_num_pending_samples = max(
509+
max(v) for v in vllm_logger_metrics["num_pending_samples"].values()
510+
)
511+
# If there is at least one pending sample, visualize the timeline
512+
if max_num_pending_samples > 0:
513+
visualize_per_worker_timeline(
514+
vllm_logger_metrics["num_pending_samples"],
515+
"Num Pending Samples",
516+
None,
464517
)
465518

466519
# =====================================================

nemo_rl/models/generation/vllm/vllm_generation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -836,9 +836,9 @@ def get_vllm_logger_metrics(self) -> dict[str, Any]:
836836
dp_indices.append(dp_idx)
837837

838838
results = ray.get(futures)
839-
vllm_logger_metrics: dict[str, dict[int, dict[int, list[int]]]] = {
840-
"inflight_batch_sizes": {},
841-
"num_pending_samples": {},
839+
vllm_logger_metrics: dict[str, dict[int, list[int]]] = {
840+
"inflight_batch_sizes": {}, # dp_idx -> list[int]
841+
"num_pending_samples": {}, # dp_idx -> list[int]
842842
}
843843

844844
for dp_idx, stats in zip(dp_indices, results):

nemo_rl/models/generation/vllm/vllm_worker.py

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -330,41 +330,44 @@ def _patch_vllm_init_workers_ray():
330330
self._create_engine(llm_kwargs)
331331

332332
# Optionally start periodic vLLM metrics logging if the flag is set
333+
# NOTE: vLLM metrics logger is only supported with async engine enabled
333334
# Metrics logger is enabled per actor, and only on the model-owner actor
334-
if self.cfg["vllm_cfg"].get("enable_vllm_metrics_logger", False):
335-
self._maybe_start_vllm_metrics_logger()
335+
if self.cfg["vllm_cfg"].get("enable_vllm_metrics_logger", False) and self.cfg[
336+
"vllm_cfg"
337+
].get("async_engine", False):
338+
self._start_vllm_metrics_logger()
336339

337340
# will be initialized in post_init
338341
# used in update_weights_from_ipc_handles
339342
self.vllm_device_ids = None
340343

341-
def _maybe_start_vllm_metrics_logger(self) -> None:
342-
"""Start a background thread that periodically prints vLLM inflight/queued sizes.
344+
def _start_vllm_metrics_logger(self) -> None:
345+
"""Start a background thread that periodically collects vLLM logger metrics.
343346
344-
Controlled by env var NRL_VLLM_LOG_METRICS_INTERVAL_SEC. Set to a positive
345-
float (e.g. "10") to enable. Runs only on the model-owner actor.
347+
Controlled by vllm_metrics_logger_interval (default: 0.5) in vllm_cfg.
348+
Runs only on the model-owner actor.
346349
"""
350+
assert self.cfg["vllm_cfg"].get("async_engine", False), (
351+
"vLLM metrics logger is only supported with async engine enabled"
352+
)
347353
# Run only on the model-owner actor
348354
if not getattr(self, "is_model_owner", False):
349355
return
350356

351-
try:
352-
interval_s_str = os.environ.get("NRL_VLLM_LOG_METRICS_INTERVAL_SEC", "0.5")
353-
if not interval_s_str:
354-
return
355-
interval_s = float(interval_s_str)
356-
except Exception:
357-
return
358-
359-
if interval_s <= 0:
360-
return
357+
assert "vllm_metrics_logger_interval" in self.cfg["vllm_cfg"], (
358+
"vllm_metrics_logger_interval must be set in vllm_cfg if enable_vllm_metrics_logger is True"
359+
)
360+
interval_s = self.cfg["vllm_cfg"]["vllm_metrics_logger_interval"]
361+
assert interval_s > 0, (
362+
f"vllm_metrics_logger_interval must be a positive float, got {interval_s}"
363+
)
361364

362365
# Lazy import inside thread target to avoid import overhead if disabled
363366
stop_event = threading.Event()
364367
self._vllm_metrics_logger_stop_event = stop_event
365368

366-
self.inflight_batch_sizes: dict[int, list[int]] = {}
367-
self.num_pending_samples: dict[int, list[int]] = {}
369+
self.inflight_batch_sizes: list[int] = []
370+
self.num_pending_samples: list[int] = []
368371

369372
def _logger_loop():
370373
# Delay a little to let engine settle
@@ -386,25 +389,17 @@ def _logger_loop():
386389
if isinstance(m, Gauge):
387390
# Log the vllm inflight batch sizes
388391
if m.name == "vllm:num_requests_running":
389-
eng = int(m.labels.get("engine", "0"))
390-
if eng not in self.inflight_batch_sizes:
391-
self.inflight_batch_sizes[eng] = []
392-
self.inflight_batch_sizes[eng].append(int(m.value))
392+
self.inflight_batch_sizes.append(int(m.value))
393393
# Log the vllm pending number of requests in the queue
394394
elif m.name == "vllm:num_requests_waiting":
395-
eng = int(m.labels.get("engine", "0"))
396-
if eng not in self.num_pending_samples:
397-
self.num_pending_samples[eng] = []
398-
self.num_pending_samples[eng].append(int(m.value))
395+
self.num_pending_samples.append(int(m.value))
399396
except Exception:
400397
print(
401398
"⚠️[vLLM Metric Logger]⚠️ Exception in vLLM metrics logger",
402399
flush=True,
403400
)
404-
# tolerate bad metric entries
405401
pass
406402
except Exception:
407-
# Avoid crashing the worker on logging issues
408403
print(
409404
"⚠️[vLLM Metric Logger]⚠️ Exception in vLLM metrics logger",
410405
flush=True,
@@ -439,8 +434,8 @@ def get_vllm_logger_metrics(self) -> dict[str, Any]:
439434
def clear_vllm_logger_metrics(self) -> None:
440435
if not self.cfg["vllm_cfg"].get("enable_vllm_metrics_logger", False):
441436
return
442-
self.inflight_batch_sizes = {}
443-
self.num_pending_samples = {}
437+
self.inflight_batch_sizes = []
438+
self.num_pending_samples = []
444439

445440
def llm(self):
446441
return self.llm

0 commit comments

Comments
 (0)