Skip to content

Commit 6f8575b

Browse files
fix lint error in vllm_worker_async.py
Removed unused import statements in vllm_worker_async.py
Signed-off-by: Youngeun Kwon <[email protected]>

fix lint error
Signed-off-by: Youngeun Kwon <[email protected]>

Remove unnecessary blank line in vllm_worker_async.py
Signed-off-by: Youngeun Kwon <[email protected]>

ci pipe clean
Signed-off-by: Youngeun Kwon <[email protected]>

add lock for thread safety
Signed-off-by: Youngeun Kwon <[email protected]>

fix the ci error
Signed-off-by: Youngeun Kwon <[email protected]>
1 parent a081752 commit 6f8575b

File tree

4 files changed

+32
-17
lines changed

4 files changed

+32
-17
lines changed

nemo_rl/algorithms/grpo.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2154,6 +2154,7 @@ def async_grpo_train(
21542154
train_results = policy.train(train_data, loss_fn)
21552155

21562156
print("🔄 Synchronizing policy weights to trajectory collector…")
2157+
vllm_logger_metrics = None
21572158
if NEED_REFIT:
21582159
# Measure pending-generation wait as exposed_generation time
21592160
print("🔄 Coordinating with trajectory collector before refit...")
@@ -2257,7 +2258,8 @@ def async_grpo_train(
22572258
else:
22582259
metrics[k] = np.sum(v).item()
22592260
metrics.update(rollout_metrics)
2260-
metrics["vllm_logger_metrics"] = vllm_logger_metrics
2261+
if vllm_logger_metrics is not None:
2262+
metrics["vllm_logger_metrics"] = vllm_logger_metrics
22612263
total_valid_tokens += metrics["global_valid_toks"]
22622264

22632265
# Checkpointing (same as sync version)

nemo_rl/algorithms/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -498,11 +498,13 @@ def visualize_per_worker_timeline(
498498
else:
499499
print(f" - Generation Worker {dp_idx:3.0f}: {''.join(timeline)}")
500500

501-
vllm_logger_metrics = metrics["vllm_logger_metrics"]
502-
is_vllm_metrics_logger_enabled = master_config["policy"]["generation"][
503-
"vllm_cfg"
504-
].get("enable_vllm_metrics_logger", False)
501+
is_vllm_metrics_logger_enabled = (
502+
master_config["policy"]["generation"]
503+
.get("vllm_cfg", {})
504+
.get("enable_vllm_metrics_logger", False)
505+
)
505506
if is_vllm_metrics_logger_enabled:
507+
vllm_logger_metrics = metrics["vllm_logger_metrics"]
506508
# vllm_logger_metrics: dict[str (metric_name), dict[int (dp_idx), list[int] (metric_values)]]
507509
# metric_name: "inflight_batch_sizes" or "num_pending_samples"
508510
vllm_metrics_logger_interval = master_config["policy"]["generation"][

nemo_rl/models/generation/vllm/vllm_worker.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,9 @@ def _start_vllm_metrics_logger(self) -> None:
366366
stop_event = threading.Event()
367367
self._vllm_metrics_logger_stop_event = stop_event
368368

369+
# Thread synchronization for metrics access
370+
self._vllm_metrics_lock = threading.Lock()
371+
369372
self.inflight_batch_sizes: list[int] = []
370373
self.num_pending_samples: list[int] = []
371374

@@ -389,10 +392,12 @@ def _logger_loop():
389392
if isinstance(m, Gauge):
390393
# Log the vllm inflight batch sizes
391394
if m.name == "vllm:num_requests_running":
392-
self.inflight_batch_sizes.append(int(m.value))
395+
with self._vllm_metrics_lock:
396+
self.inflight_batch_sizes.append(int(m.value))
393397
# Log the vllm pending number of requests in the queue
394398
elif m.name == "vllm:num_requests_waiting":
395-
self.num_pending_samples.append(int(m.value))
399+
with self._vllm_metrics_lock:
400+
self.num_pending_samples.append(int(m.value))
396401
except Exception:
397402
print(
398403
"⚠️[vLLM Metric Logger]⚠️ Exception in vLLM metrics logger",
@@ -426,16 +431,20 @@ def get_vllm_logger_metrics(self) -> dict[str, Any]:
426431
if not self.cfg["vllm_cfg"].get("enable_vllm_metrics_logger", False):
427432
return {}
428433

429-
return {
430-
"inflight_batch_sizes": copy.deepcopy(self.inflight_batch_sizes),
431-
"num_pending_samples": copy.deepcopy(self.num_pending_samples),
432-
}
434+
with self._vllm_metrics_lock:
435+
metric = {
436+
"inflight_batch_sizes": copy.deepcopy(self.inflight_batch_sizes),
437+
"num_pending_samples": copy.deepcopy(self.num_pending_samples),
438+
}
439+
return metric
433440

434441
def clear_vllm_logger_metrics(self) -> None:
435442
if not self.cfg["vllm_cfg"].get("enable_vllm_metrics_logger", False):
436443
return
437-
self.inflight_batch_sizes = []
438-
self.num_pending_samples = []
444+
445+
with self._vllm_metrics_lock:
446+
self.inflight_batch_sizes = []
447+
self.num_pending_samples = []
439448

440449
def llm(self):
441450
return self.llm

nemo_rl/models/generation/vllm/vllm_worker_async.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,12 @@ def _setup_vllm_openai_api_server(self, app: FastAPI) -> FastAPI:
165165
from logging import LogRecord
166166
from typing import List, Optional, Union
167167

168-
from fastapi import Request
169-
from fastapi.responses import JSONResponse, StreamingResponse
170-
171-
from vllm.entrypoints.openai.api_server import (
168+
from fastapi import Request # pyright: ignore[reportMissingImports]
169+
from fastapi.responses import ( # pyright: ignore[reportMissingImports]
170+
JSONResponse,
171+
StreamingResponse,
172+
)
173+
from vllm.entrypoints.openai.api_server import ( # pyright: ignore[reportMissingImports]
172174
BaseModelPath,
173175
OpenAIServingChat,
174176
OpenAIServingModels,

0 commit comments

Comments
 (0)