diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.yaml
new file mode 100644
index 000000000..bb641388d
--- /dev/null
+++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.yaml
@@ -0,0 +1,36 @@
+defaults: ../../grpo_math_1B.yaml
+grpo:
+  max_num_steps: 500
+checkpointing:
+  enabled: false
+  checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-megatron_generation
+  save_period: 100
+policy:
+  model_name: meta-llama/Llama-3.2-1B-Instruct
+  tokenizer:
+    name: meta-llama/Llama-3.2-1B-Instruct
+  optimizer: null
+  megatron_cfg:
+    enabled: true
+    scheduler:
+      lr_warmup_iters: 50
+  dtensor_cfg:
+    enabled: false
+  make_sequence_length_divisible_by: 1
+  generation:
+    backend: megatron
+    max_new_tokens: 512
+    vllm_cfg:
+      max_model_len: 512
+data:
+  max_input_seq_length: 512
+logger:
+  log_dir: logs/grpo-llama3.2-1b-instruct-1n8g-megatron_generation
+  wandb_enabled: true
+  tensorboard_enabled: true
+  wandb:
+    project: nemo-rl
+    name: grpo-llama3.2-1b-instruct-1n8g-megatron_generation
+cluster:
+  gpus_per_node: 8
+
diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py
index a9d9a1eb6..9041bfe27 100644
--- a/nemo_rl/algorithms/grpo.py
+++ b/nemo_rl/algorithms/grpo.py
@@ -511,7 +511,8 @@ def setup(
 
     # prepare refit info
    state_dict_info = policy.prepare_refit_info()
-    policy_generation.prepare_refit_info(state_dict_info)
+    if policy_generation is not None:
+        policy_generation.prepare_refit_info(state_dict_info)
 
     loss_fn = ClippedPGLossFn(loss_config)
 
diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index 3bf211be5..cd67cbdf2 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -75,9 +75,6 @@
 from megatron.core.inference.text_generation_controllers.text_generation_controller import (
     TextGenerationController,
 )
-from megatron.core.inference.text_generation_server.run_mcore_engine import (
-    run_mcore_engine,
-)
 from megatron.core.models.gpt import GPTModel
 from megatron.core.optimizer import ChainedOptimizer
 from megatron.core.parallel_state import (
@@ -1755,6 +1752,12 @@ def generate(
         """
         no_grad = torch.no_grad()
         no_grad.__enter__()
+
+        if self.should_disable_forward_pre_hook:
+            self.model = self.move_model(
+                self.model, "cuda", move_params=True, move_grads=False
+            )
+
         self.model.config.flash_decode = True
         # Verify input is right padded
         assert isinstance(data, BatchedDataDict), (
@@ -1783,9 +1786,11 @@
         )
 
         from megatron.core.inference.contexts import StaticInferenceContext
+        from megatron.core.inference.inference_request import InferenceRequest
         from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
             GPTInferenceWrapper,
         )
+        from megatron.core.inference.sampling_params import SamplingParams
 
         inference_context = StaticInferenceContext.from_config(inference_wrapper_config)
@@ -1801,21 +1806,57 @@
             max_batch_size=self.cfg["generation_batch_size"],
         )
 
-        # detokenize the prompts
-        # detokenized_prompts = [
-        #     self.tokenizer.decode(prompt)
-        #     for prompt in data.get("input_ids")
-        # ]
-        # apply chat template
-        out = run_mcore_engine(
-            engine=inference_engine,
-            # prompts = detokenized_prompts,
-            prompt_tokens_tensor=data["input_ids"],
-            prompt_lengths_tensor=data["input_lengths"],
-            tokens_to_generate=self.cfg["generation"]["max_new_tokens"]  # type: ignore
-            - data["input_ids"].size(1),
+        input_ids = data["input_ids"]
+        tokens_to_generate = self.cfg["generation"]["max_new_tokens"] - input_ids.size(
+            1
         )
-        # print(out)
+
+        prompt_tokens_tensor = input_ids
+        prompt_lengths_tensor = data["input_lengths"]
+
+        # Handle None values for top_k - convert to integer as required by Megatron
+        top_k_cfg = self.cfg["generation"]["top_k"]
+        top_k_val = 1 if greedy else (int(top_k_cfg) if top_k_cfg is not None else 0)
+
+        # Use temperature 0.0 for greedy, 1.0 otherwise
+        temperature = 0.0 if greedy else 1.0
+
+        top_p_cfg = self.cfg["generation"]["top_p"]
+        top_p_val = (
+            0.0 if greedy else (float(top_p_cfg) if top_p_cfg is not None else 0.0)
+        )
+
+        sampling_params = SamplingParams(
+            temperature=temperature,
+            top_k=top_k_val,
+            top_p=top_p_val,
+            return_segments=False,
+            return_log_probs=True,
+            num_tokens_to_generate=tokens_to_generate,
+            top_n_logprobs=0,
+            return_prompt_top_n_logprobs=False,
+        )
+        requests = []
+        for p, prompt_len in zip(
+            prompt_tokens_tensor, prompt_lengths_tensor, strict=True
+        ):
+            tokenized_prompt = p[:prompt_len].cpu().numpy().tolist()
+            detokenized_prompt = self.tokenizer.decode(tokenized_prompt)
+            req = InferenceRequest(
+                prompt=detokenized_prompt,
+                prompt_tokens=tokenized_prompt,
+                sampling_params=sampling_params,
+                request_id=inference_engine.get_new_request_id(),
+            )
+            requests.append(req)
+
+        result = inference_engine.generate(inference_requests=requests)
+
+        out = {
+            "text": [x.prompt + x.generated_text for x in result],
+            "tokens": [x.prompt_tokens + x.generated_tokens.tolist() for x in result],
+            "logprobs": [x.prompt_log_probs + x.generated_log_probs for x in result],
+        }
 
         input_lengths = data["input_lengths"]
         # pad the out "tokens" and "logprobs" and make them into tensors from lists
diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh
index 9de07d28b..aa78afc35 100644
--- a/tests/functional/L1_Functional_Tests_GPU.sh
+++ b/tests/functional/L1_Functional_Tests_GPU.sh
@@ -23,6 +23,7 @@ time uv run --no-sync bash ./tests/functional/sft.sh
 time uv run --no-sync bash ./tests/functional/grpo.sh
 time uv run --no-sync bash ./tests/functional/grpo_async.sh
 time uv run --no-sync bash ./tests/functional/grpo_megatron.sh
+time uv run --no-sync bash ./tests/functional/grpo_megatron_generation.sh
 time uv run --no-sync bash ./tests/functional/grpo_multiturn.sh
 time uv run --no-sync bash ./tests/functional/grpo_non_colocated.sh
 time uv run --no-sync bash ./tests/functional/dpo.sh
diff --git a/tests/functional/grpo_megatron_generation.sh b/tests/functional/grpo_megatron_generation.sh
new file mode 100644
index 000000000..984be7e1a
--- /dev/null
+++ b/tests/functional/grpo_megatron_generation.sh
@@ -0,0 +1,46 @@
+#!/bin/bash
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
+# Mark the current repo as safe, since wandb fetches metadata about the repo
+git config --global --add safe.directory $PROJECT_ROOT
+
+set -eou pipefail
+
+EXP_NAME=$(basename $0 .sh)
+EXP_DIR=$SCRIPT_DIR/$EXP_NAME
+LOG_DIR=$EXP_DIR/logs
+JSON_METRICS=$EXP_DIR/metrics.json
+RUN_LOG=$EXP_DIR/run.log
+export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}
+
+rm -rf $EXP_DIR $LOG_DIR
+mkdir -p $EXP_DIR $LOG_DIR
+
+# Using Qwen2.5-0.5B instead of Qwen3-0.6B because the latter is not supported by Megatron yet
+cd $PROJECT_ROOT
+uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
+    $PROJECT_ROOT/examples/run_grpo_math.py \
+    --config $PROJECT_ROOT/examples/configs/grpo_math_1B_megatron.yaml \
+    policy.model_name=Qwen/Qwen2.5-0.5B \
+    grpo.num_prompts_per_step=2 \
+    grpo.num_generations_per_prompt=4 \
+    policy.train_global_batch_size=4 \
+    policy.logprob_batch_size=4 \
+    policy.train_micro_batch_size=1 \
+    policy.generation.backend=megatron \
+    cluster.gpus_per_node=2 \
+    grpo.max_num_steps=2 \
+    logger.tensorboard_enabled=true \
+    logger.log_dir=$LOG_DIR \
+    logger.wandb_enabled=false \
+    logger.monitor_gpus=true \
+    checkpointing.enabled=false \
+    $@ \
+    2>&1 | tee $RUN_LOG
+
+uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
+
+uv run tests/check_metrics.py $JSON_METRICS \
+    'max(data["train/token_mult_prob_error"]) < 1.05'
+
diff --git a/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh
new file mode 100755
index 000000000..08f57cb5a
--- /dev/null
+++ b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+source $SCRIPT_DIR/common.env
+
+# ===== BEGIN CONFIG =====
+NUM_NODES=1
+STEPS_PER_RUN=500
+MAX_STEPS=500
+NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # Round up
+NUM_MINUTES=180
+# ===== END CONFIG =====
+
+exit_if_max_steps_reached
+
+# Run the experiment
+cd $PROJECT_ROOT
+uv run examples/run_grpo_math.py \
+    --config $CONFIG_PATH \
+    grpo.max_num_steps=$MAX_STEPS \
+    policy.generation.backend=megatron \
+    logger.log_dir=$LOG_DIR \
+    logger.wandb_enabled=True \
+    logger.wandb.project=nemo-rl \
+    logger.wandb.name=$EXP_NAME \
+    logger.monitor_gpus=True \
+    logger.tensorboard_enabled=True \
+    checkpointing.enabled=True \
+    checkpointing.checkpoint_dir=$CKPT_DIR \
+    $@ \
+    2>&1 | tee $RUN_LOG
+
+# Convert tensorboard logs to json
+uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
+
+# Only run metrics if the target step is reached
+if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
+    uv run tests/check_metrics.py $JSON_METRICS \
+        'mean(data["train/token_mult_prob_error"]) < 1.1' \
+        'data["train/token_mult_prob_error"]["500"] < 1.1' \
+        'data["train/reward"]["500"] > 0.1' \
+        'mean(data["timing/train/total_step_time"], -6, -1) < 10.5'
+fi
diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt
index 91c24aada..24677186e 100644
--- a/tests/test_suites/nightly.txt
+++ b/tests/test_suites/nightly.txt
@@ -12,6 +12,7 @@ tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh
 
 # Megatron
 tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.sh
+tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh
 
 # Functional 32b run
 tests/test_suites/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8-actckpt.v3.sh
diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py
index 33fe4f35a..b27dfb706 100644
--- a/tests/unit/models/policy/test_megatron_worker.py
+++ b/tests/unit/models/policy/test_megatron_worker.py
@@ -472,6 +472,7 @@ def generation_setup(request, tiny_llama_model_path):
             tiny_llama_model_path,
             tp=tp,
             pp=pp,
+            precision="bfloat16",  # FlashAttention requires fp16 or bf16
             generation_backend=generation_backend,
         )
 
@@ -538,7 +539,6 @@ def generation_setup(request, tiny_llama_model_path):
         cluster.shutdown()
 
 
-@pytest.mark.skip(reason="Skipping megatron generation tests for now")
 @pytest.mark.timeout(240)
 @pytest.mark.parametrize(
     "generation_setup",
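
For reference, below is a minimal standalone sketch (not part of the patch) of the generation flow that the rewritten MegatronPolicyWorker.generate uses above: one SamplingParams object, one InferenceRequest per prompt, then engine.generate. The helper name generate_with_megatron and the engine/tokenizer/prompt_token_lists arguments are illustrative assumptions; the Megatron imports and calls are only those that appear in the diff.

# Illustrative sketch only: mirrors the request flow in MegatronPolicyWorker.generate.
# `engine` is assumed to be an already-constructed Megatron static inference engine
# (as built inside the worker) and `tokenizer` a Hugging Face tokenizer.
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.sampling_params import SamplingParams


def generate_with_megatron(
    engine, tokenizer, prompt_token_lists, max_new_tokens, greedy=False
):
    # Greedy decoding maps to temperature 0.0 / top_k 1; otherwise top_k=0 and
    # top_p=0.0 leave both filters disabled, matching the None -> 0 mapping above.
    sampling_params = SamplingParams(
        temperature=0.0 if greedy else 1.0,
        top_k=1 if greedy else 0,
        top_p=0.0,
        return_segments=False,
        return_log_probs=True,
        num_tokens_to_generate=max_new_tokens,
        top_n_logprobs=0,
        return_prompt_top_n_logprobs=False,
    )

    # One InferenceRequest per prompt; Megatron takes both the detokenized prompt
    # string and the raw prompt token ids.
    requests = [
        InferenceRequest(
            prompt=tokenizer.decode(tokens),
            prompt_tokens=tokens,
            sampling_params=sampling_params,
            request_id=engine.get_new_request_id(),
        )
        for tokens in prompt_token_lists
    ]

    results = engine.generate(inference_requests=requests)

    # Concatenate prompt and generated pieces, matching the dict the worker builds
    # before padding everything into tensors.
    return {
        "text": [r.prompt + r.generated_text for r in results],
        "tokens": [r.prompt_tokens + r.generated_tokens.tolist() for r in results],
        "logprobs": [r.prompt_log_probs + r.generated_log_probs for r in results],
    }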