From 327dd7da4eef332b224e18d725d25c10e5c57cfc Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy Date: Fri, 7 Nov 2025 10:41:30 -0800 Subject: [PATCH 01/10] Fixed static inference --- examples/configs/grpo_math_1B_megatron.yaml | 2 +- nemo_rl/algorithms/grpo.py | 3 +- .../models/policy/megatron_policy_worker.py | 12 ++++- tests/functional/grpo_megatron_generation.sh | 46 +++++++++++++++++++ ....2-1b-instruct-1n8g-megatron_generation.sh | 42 +++++++++++++++++ tests/test_suites/nightly.txt | 1 + 6 files changed, 103 insertions(+), 3 deletions(-) create mode 100644 tests/functional/grpo_megatron_generation.sh create mode 100644 tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh diff --git a/examples/configs/grpo_math_1B_megatron.yaml b/examples/configs/grpo_math_1B_megatron.yaml index 58ecbc778..7d75e2b82 100644 --- a/examples/configs/grpo_math_1B_megatron.yaml +++ b/examples/configs/grpo_math_1B_megatron.yaml @@ -137,7 +137,7 @@ policy: data_parallel_sharding_strategy: "optim_grads_params" generation: - backend: "vllm" + backend: "megatron" max_new_tokens: ${policy.max_total_sequence_length} temperature: 1.0 top_p: 1.0 diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index a9d9a1eb6..9041bfe27 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -511,7 +511,8 @@ def setup( # prepare refit info state_dict_info = policy.prepare_refit_info() - policy_generation.prepare_refit_info(state_dict_info) + if policy_generation is not None: + policy_generation.prepare_refit_info(state_dict_info) loss_fn = ClippedPGLossFn(loss_config) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 2047c05ea..548aee3d7 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1750,6 +1750,10 @@ def generate( """ no_grad = torch.no_grad() no_grad.__enter__() + + if self.should_disable_forward_pre_hook: + self.model = self.move_model(self.model, "cuda", move_params=True, move_grads=False) + self.model.config.flash_decode = True # Verify input is right padded assert isinstance(data, BatchedDataDict), ( @@ -1801,11 +1805,17 @@ def generate( # self.tokenizer.decode(prompt) # for prompt in data.get("input_ids") # ] + + input_ids = data["input_ids"] + tokens_to_generate = self.cfg["generation"]["max_new_tokens"] - input_ids.size(1) + + padding = torch.full((input_ids.shape[0],tokens_to_generate), self.megatron_tokenizer.eod_id, dtype = input_ids.dtype, device= input_ids.device) + prompt_tokens_tensor = torch.cat([input_ids, padding], dim=1) # apply chat template out = run_mcore_engine( engine=inference_engine, # prompts = detokenized_prompts, - prompt_tokens_tensor=data["input_ids"], + prompt_tokens_tensor=prompt_tokens_tensor, prompt_lengths_tensor=data["input_lengths"], tokens_to_generate=self.cfg["generation"]["max_new_tokens"] # type: ignore - data["input_ids"].size(1), diff --git a/tests/functional/grpo_megatron_generation.sh b/tests/functional/grpo_megatron_generation.sh new file mode 100644 index 000000000..984be7e1a --- /dev/null +++ b/tests/functional/grpo_megatron_generation.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) +# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +# Using Qwen2.5-0.5B instead of Qwen3-0.6B because the latter is not supported by Megatron yet +cd $PROJECT_ROOT +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_grpo_math.py \ + --config $PROJECT_ROOT/examples/configs/grpo_math_1B_megatron.yaml \ + policy.model_name=Qwen/Qwen2.5-0.5B \ + grpo.num_prompts_per_step=2 \ + grpo.num_generations_per_prompt=4 \ + policy.train_global_batch_size=4 \ + policy.logprob_batch_size=4 \ + policy.train_micro_batch_size=1 \ + policy.generation.backend=megatron \ + cluster.gpus_per_node=2 \ + grpo.max_num_steps=2 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=true \ + checkpointing.enabled=false \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'max(data["train/token_mult_prob_error"]) < 1.05' + diff --git a/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh new file mode 100644 index 000000000..08f57cb5a --- /dev/null +++ b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh @@ -0,0 +1,42 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=500 +MAX_STEPS=500 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=180 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo_math.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + policy.generation.backend=megatron \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/token_mult_prob_error"]) < 1.1' \ + 'data["train/token_mult_prob_error"]["500"] < 1.1' \ + 'data["train/reward"]["500"] > 0.1' \ + 'mean(data["timing/train/total_step_time"], -6, -1) < 10.5' +fi diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 91c24aada..24677186e 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -12,6 +12,7 @@ tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh # Megatron tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.sh +tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh # Functional 32b run tests/test_suites/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8-actckpt.v3.sh From e4cb2c96370a15f57a021b16799c980f5536a86b Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy Date: Fri, 7 Nov 2025 14:04:04 -0800 Subject: [PATCH 02/10] Updated to directly use static engine --- .../models/policy/megatron_policy_worker.py | 47 +++++++++++++------ 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 548aee3d7..9f8466997 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1785,6 +1785,8 @@ def generate( from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( GPTInferenceWrapper, ) + from megatron.core.inference.sampling_params import SamplingParams + from megatron.core.inference.inference_request import InferenceRequest inference_context = StaticInferenceContext.from_config(inference_wrapper_config) @@ -1800,27 +1802,42 @@ def generate( max_batch_size=self.cfg["generation_batch_size"], ) - # detokenize the prompts - # detokenized_prompts = [ - # self.tokenizer.decode(prompt) - # for prompt in data.get("input_ids") - # ] - input_ids = data["input_ids"] tokens_to_generate = self.cfg["generation"]["max_new_tokens"] - input_ids.size(1) padding = torch.full((input_ids.shape[0],tokens_to_generate), self.megatron_tokenizer.eod_id, dtype = input_ids.dtype, device= input_ids.device) prompt_tokens_tensor = torch.cat([input_ids, padding], dim=1) - # apply chat template - out = run_mcore_engine( - engine=inference_engine, - # prompts = detokenized_prompts, - prompt_tokens_tensor=prompt_tokens_tensor, - prompt_lengths_tensor=data["input_lengths"], - tokens_to_generate=self.cfg["generation"]["max_new_tokens"] # type: ignore - - data["input_ids"].size(1), + prompt_lengths_tensor = data["input_lengths"] + + sampling_params = SamplingParams( + temperature=1.0, + top_k=0, + top_p=0.0, + return_segments=False, + return_log_probs=True, + num_tokens_to_generate=tokens_to_generate, + top_n_logprobs=0, + return_prompt_top_n_logprobs=False, ) - # print(out) + requests = [] + for p, l in zip(prompt_tokens_tensor, prompt_lengths_tensor): + tokenized_prompt = p[:l].cpu().numpy().tolist() + detokenized_prompt = self.tokenizer.decode(tokenized_prompt) + req = InferenceRequest( + prompt=detokenized_prompt, + prompt_tokens=tokenized_prompt, + sampling_params=sampling_params, + request_id=inference_engine.get_new_request_id(), + ) + requests.append(req) + + result = inference_engine.generate(inference_requests=requests) + + out = { + "text": [x.prompt + x.generated_text for x in result], + "tokens": [x.prompt_tokens + x.generated_tokens.tolist() for x in result], + "logprobs" : [x.prompt_log_probs + x.generated_log_probs for x in result] + } input_lengths = data["input_lengths"] # pad the out "tokens" and "logprobs" and make them into tensors from lists From 3bbee4b2aa850ca38003886ae1b87064d7e04d28 Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy Date: Fri, 7 Nov 2025 14:14:20 -0800 Subject: [PATCH 03/10] Fixed error --- examples/configs/grpo_math_1B_megatron.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/configs/grpo_math_1B_megatron.yaml b/examples/configs/grpo_math_1B_megatron.yaml index 7d75e2b82..58ecbc778 100644 --- a/examples/configs/grpo_math_1B_megatron.yaml +++ b/examples/configs/grpo_math_1B_megatron.yaml @@ -137,7 +137,7 @@ policy: data_parallel_sharding_strategy: "optim_grads_params" generation: - backend: "megatron" + backend: "vllm" max_new_tokens: ${policy.max_total_sequence_length} temperature: 1.0 top_p: 1.0 From a52d9d09a34d8d0cd7c2fecaf2b75b43e43ec50b Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy Date: Fri, 7 Nov 2025 16:51:56 -0800 Subject: [PATCH 04/10] Fixed error --- .../models/policy/megatron_policy_worker.py | 34 +++++++++++-------- tests/functional/L1_Functional_Tests_GPU.sh | 1 + .../models/policy/test_megatron_worker.py | 1 - 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index a1d1f98d1..7c44da27d 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -75,9 +75,6 @@ from megatron.core.inference.text_generation_controllers.text_generation_controller import ( TextGenerationController, ) -from megatron.core.inference.text_generation_server.run_mcore_engine import ( - run_mcore_engine, -) from megatron.core.models.gpt import GPTModel from megatron.core.optimizer import ChainedOptimizer from megatron.core.parallel_state import ( @@ -1755,9 +1752,11 @@ def generate( """ no_grad = torch.no_grad() no_grad.__enter__() - + if self.should_disable_forward_pre_hook: - self.model = self.move_model(self.model, "cuda", move_params=True, move_grads=False) + self.model = self.move_model( + self.model, "cuda", move_params=True, move_grads=False + ) self.model.config.flash_decode = True # Verify input is right padded @@ -1787,11 +1786,11 @@ def generate( ) from megatron.core.inference.contexts import StaticInferenceContext + from megatron.core.inference.inference_request import InferenceRequest from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( GPTInferenceWrapper, ) from megatron.core.inference.sampling_params import SamplingParams - from megatron.core.inference.inference_request import InferenceRequest inference_context = StaticInferenceContext.from_config(inference_wrapper_config) @@ -1808,16 +1807,23 @@ def generate( ) input_ids = data["input_ids"] - tokens_to_generate = self.cfg["generation"]["max_new_tokens"] - input_ids.size(1) - - padding = torch.full((input_ids.shape[0],tokens_to_generate), self.megatron_tokenizer.eod_id, dtype = input_ids.dtype, device= input_ids.device) + tokens_to_generate = self.cfg["generation"]["max_new_tokens"] - input_ids.size( + 1 + ) + + padding = torch.full( + (input_ids.shape[0], tokens_to_generate), + self.megatron_tokenizer.eod_id, + dtype=input_ids.dtype, + device=input_ids.device, + ) prompt_tokens_tensor = torch.cat([input_ids, padding], dim=1) prompt_lengths_tensor = data["input_lengths"] sampling_params = SamplingParams( temperature=1.0, - top_k=0, - top_p=0.0, + top_k=self.cfg["generation"]["top_k"], + top_p=self.cfg["generation"]["top_p"], return_segments=False, return_log_probs=True, num_tokens_to_generate=tokens_to_generate, @@ -1826,7 +1832,7 @@ def generate( ) requests = [] for p, l in zip(prompt_tokens_tensor, prompt_lengths_tensor): - tokenized_prompt = p[:l].cpu().numpy().tolist() + tokenized_prompt = p[:l].cpu().numpy().tolist() detokenized_prompt = self.tokenizer.decode(tokenized_prompt) req = InferenceRequest( prompt=detokenized_prompt, @@ -1835,13 +1841,13 @@ def generate( request_id=inference_engine.get_new_request_id(), ) requests.append(req) - + result = inference_engine.generate(inference_requests=requests) out = { "text": [x.prompt + x.generated_text for x in result], "tokens": [x.prompt_tokens + x.generated_tokens.tolist() for x in result], - "logprobs" : [x.prompt_log_probs + x.generated_log_probs for x in result] + "logprobs": [x.prompt_log_probs + x.generated_log_probs for x in result], } input_lengths = data["input_lengths"] diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh index 9de07d28b..aa78afc35 100644 --- a/tests/functional/L1_Functional_Tests_GPU.sh +++ b/tests/functional/L1_Functional_Tests_GPU.sh @@ -23,6 +23,7 @@ time uv run --no-sync bash ./tests/functional/sft.sh time uv run --no-sync bash ./tests/functional/grpo.sh time uv run --no-sync bash ./tests/functional/grpo_async.sh time uv run --no-sync bash ./tests/functional/grpo_megatron.sh +time uv run --no-sync bash ./tests/functional/grpo_megatron_generation.sh time uv run --no-sync bash ./tests/functional/grpo_multiturn.sh time uv run --no-sync bash ./tests/functional/grpo_non_colocated.sh time uv run --no-sync bash ./tests/functional/dpo.sh diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index 33fe4f35a..c855c5d1a 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -538,7 +538,6 @@ def generation_setup(request, tiny_llama_model_path): cluster.shutdown() -@pytest.mark.skip(reason="Skipping megatron generation tests for now") @pytest.mark.timeout(240) @pytest.mark.parametrize( "generation_setup", From ef3b11fc60ce2bdd3663e8ea4cdfc6d3f725ae26 Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy Date: Fri, 7 Nov 2025 16:53:25 -0800 Subject: [PATCH 05/10] Fixed error --- nemo_rl/models/policy/megatron_policy_worker.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 7c44da27d..42ac914e2 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1831,8 +1831,10 @@ def generate( return_prompt_top_n_logprobs=False, ) requests = [] - for p, l in zip(prompt_tokens_tensor, prompt_lengths_tensor): - tokenized_prompt = p[:l].cpu().numpy().tolist() + for p, prompt_len in zip( + prompt_tokens_tensor, prompt_lengths_tensor, strict=True + ): + tokenized_prompt = p[:prompt_len].cpu().numpy().tolist() detokenized_prompt = self.tokenizer.decode(tokenized_prompt) req = InferenceRequest( prompt=detokenized_prompt, From eefecfeda9b1f40e11544a487eacbfdf3ccde17b Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy Date: Sun, 9 Nov 2025 10:07:24 -0800 Subject: [PATCH 06/10] Fixed error --- ...-1b-instruct-1n8g-megatron_generation.yaml | 36 +++++++++++++++++++ ....2-1b-instruct-1n8g-megatron_generation.sh | 0 2 files changed, 36 insertions(+) create mode 100644 examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.yaml mode change 100644 => 100755 tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.yaml new file mode 100644 index 000000000..bb641388d --- /dev/null +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.yaml @@ -0,0 +1,36 @@ +defaults: ../../grpo_math_1B.yaml +grpo: + max_num_steps: 500 +checkpointing: + enabled: false + checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-megatron_generation + save_period: 100 +policy: + model_name: meta-llama/Llama-3.2-1B-Instruct + tokenizer: + name: meta-llama/Llama-3.2-1B-Instruct + optimizer: null + megatron_cfg: + enabled: true + scheduler: + lr_warmup_iters: 50 + dtensor_cfg: + enabled: false + make_sequence_length_divisible_by: 1 + generation: + backend: megatron + max_new_tokens: 512 + vllm_cfg: + max_model_len: 512 +data: + max_input_seq_length: 512 +logger: + log_dir: logs/grpo-llama3.2-1b-instruct-1n8g-megatron_generation + wandb_enabled: true + tensorboard_enabled: true + wandb: + project: nemo-rl + name: grpo-llama3.2-1b-instruct-1n8g-megatron_generation +cluster: + gpus_per_node: 8 + diff --git a/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh old mode 100644 new mode 100755 From b1d55e52e4d2ade67b8bcfa920c2ee4523aef8fd Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy Date: Sun, 9 Nov 2025 18:02:00 -0800 Subject: [PATCH 07/10] Fixed error --- nemo_rl/models/policy/megatron_policy_worker.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 42ac914e2..d1ca9d40e 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1820,10 +1820,20 @@ def generate( prompt_tokens_tensor = torch.cat([input_ids, padding], dim=1) prompt_lengths_tensor = data["input_lengths"] + # Handle None values for top_k - convert to integer as required by Megatron + top_k_cfg = self.cfg["generation"]["top_k"] + top_k_val = 1 if greedy else (int(top_k_cfg) if top_k_cfg is not None else 0) + + # Use temperature 0.0 for greedy, 1.0 otherwise + temperature = 0.0 if greedy else 1.0 + + top_p_cfg = self.cfg["generation"]["top_p"] + top_p_val = float(top_p_cfg) if top_p_cfg is not None else 0.0 + sampling_params = SamplingParams( - temperature=1.0, - top_k=self.cfg["generation"]["top_k"], - top_p=self.cfg["generation"]["top_p"], + temperature=temperature, + top_k=top_k_val, + top_p=top_p_val, return_segments=False, return_log_probs=True, num_tokens_to_generate=tokens_to_generate, From 287e57dbce9743a335e78e8befe4a0d4f7e13b98 Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy Date: Sun, 9 Nov 2025 20:15:45 -0800 Subject: [PATCH 08/10] Fixed error --- nemo_rl/models/policy/megatron_policy_worker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index d1ca9d40e..426d780ae 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1828,7 +1828,9 @@ def generate( temperature = 0.0 if greedy else 1.0 top_p_cfg = self.cfg["generation"]["top_p"] - top_p_val = float(top_p_cfg) if top_p_cfg is not None else 0.0 + top_p_val = ( + 0.0 if greedy else (float(top_p_cfg) if top_p_cfg is not None else 0.0) + ) sampling_params = SamplingParams( temperature=temperature, From 89466fb27fbbb28c767c8ae6caa30d26e7051008 Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy Date: Sun, 9 Nov 2025 22:09:49 -0800 Subject: [PATCH 09/10] Fixed error --- tests/unit/models/policy/test_megatron_worker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index c855c5d1a..b27dfb706 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -472,6 +472,7 @@ def generation_setup(request, tiny_llama_model_path): tiny_llama_model_path, tp=tp, pp=pp, + precision="bfloat16", # FlashAttention requires fp16 or bf16 generation_backend=generation_backend, ) From d26cc21a2f879ebf110c3beb38a144db7e7db6eb Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy Date: Mon, 10 Nov 2025 16:00:34 -0800 Subject: [PATCH 10/10] Removing unwanted padding in static batching --- nemo_rl/models/policy/megatron_policy_worker.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 426d780ae..cd67cbdf2 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1811,13 +1811,7 @@ def generate( 1 ) - padding = torch.full( - (input_ids.shape[0], tokens_to_generate), - self.megatron_tokenizer.eod_id, - dtype=input_ids.dtype, - device=input_ids.device, - ) - prompt_tokens_tensor = torch.cat([input_ids, padding], dim=1) + prompt_tokens_tensor = input_ids prompt_lengths_tensor = data["input_lengths"] # Handle None values for top_k - convert to integer as required by Megatron