51 changes: 51 additions & 0 deletions examples/configs/grpo_adk_gemma.yaml
@@ -0,0 +1,51 @@
# GRPO configuration for unique numbers environment
defaults: "grpo_math_8B.yaml"

grpo:
  num_prompts_per_step: 32
  num_generations_per_prompt: 16
  max_rollout_turns: 20
  max_num_steps: 100
  val_at_start: false

data:
  add_system_prompt: false
  shuffle: false

checkpointing:
  enabled: false
  checkpoint_dir: "results/grpo-adk"
  metric_name: "val_reward"
  higher_is_better: true
  keep_top_k: 3
  save_period: 10

env:
  unique_numbers:
    cfg:
      max_turns: 15
      min_length: 5
      max_length: 10
      max_integer: 15

logger:
  wandb_enabled: True
  wandb:
    project: "grpo-simulated-adk"
    name: "gemma-4b-__NOW__"

policy:
  train_global_batch_size: 512
  logprob_batch_size: 1
  model_name: google/gemma-3-4b-it
  dynamic_batching:
    enabled: True
  sequence_packing:
    enabled: False
  tokenizer:
    name: google/gemma-3-4b-it
    chat_template: "{%- if add_bos_token|default(false) %}{{ bos_token }}{% endif %}{% for message in messages %}{% set role = 'model' if message['role'] == 'assistant' else message['role'] %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"

cluster:
  gpus_per_node: 8
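
The chat_template above is a plain Jinja2 string, so its effect can be checked outside the trainer. Below is a minimal sketch (hypothetical, not part of this PR; it assumes the jinja2 package and re-renders the same template source) showing what a single user turn becomes for Gemma:

    from jinja2 import Template

    # Same Jinja source as the chat_template above, split for readability.
    gemma_template = Template(
        "{%- if add_bos_token|default(false) %}{{ bos_token }}{% endif %}"
        "{% for message in messages %}"
        "{% set role = 'model' if message['role'] == 'assistant' else message['role'] %}"
        "{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}"
        "{% endfor %}"
        "{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
    )
    print(gemma_template.render(
        messages=[{"role": "user", "content": "How many unique numbers?"}],
        add_bos_token=True,
        bos_token="<bos>",
        add_generation_prompt=True,
    ))
    # <bos><start_of_turn>user
    # How many unique numbers?<end_of_turn>
    # <start_of_turn>model

This mirrors how generate_datum in the example script below calls tokenizer.apply_chat_template with add_bos_token=True and add_generation_prompt=True.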
45 changes: 45 additions & 0 deletions examples/configs/grpo_adk_llama8b.yaml
@@ -0,0 +1,45 @@
# GRPO configuration for unique numbers environment
defaults: "grpo_math_8B.yaml"

grpo:
  num_prompts_per_step: 32
  num_generations_per_prompt: 16
  max_rollout_turns: 20
  max_num_steps: 100
  val_at_start: false

data:
  add_system_prompt: false
  shuffle: false

checkpointing:
  enabled: false
  checkpoint_dir: "results/grpo-adk"
  metric_name: "val_reward"
  higher_is_better: true
  keep_top_k: 3
  save_period: 10

env:
  unique_numbers:
    cfg:
      max_turns: 15
      min_length: 5
      max_length: 10
      max_integer: 15

logger:
  wandb_enabled: True
  wandb:
    project: "grpo-simulated-adk"
    name: "llama-8b-__NOW__"

policy:
  train_global_batch_size: 512
  dynamic_batching:
    enabled: False
  tokenizer:
    chat_template: '{%- if add_bos_token|default(false) %}{{ bos_token }}{% endif %}{% for message in messages %}{{ "<|start_header_id|>" + message.role + "<|end_header_id|>\n\n" + message.content | trim + "<|eot_id|>" }}{% endfor %}{% if add_generation_prompt %}{{ "<|start_header_id|>assistant<|end_header_id|>\n\n" }}{% endif %}'

cluster:
  gpus_per_node: 8
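
The Llama variant keeps the same GRPO, data, and env settings; the notable differences are the chat template, dynamic_batching disabled, and no model_name override (it falls back to the grpo_math_8B.yaml default). Worked by hand from the Jinja string above (an illustration, not script output), a single user message with a generation prompt renders as:

    <|start_header_id|>user<|end_header_id|>

    How many unique numbers?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

with the tokenizer's BOS token prepended when add_bos_token is set.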
273 changes: 273 additions & 0 deletions examples/run_grpo_unique_numbers_w_adk.py
@@ -0,0 +1,273 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run GRPO with the Unique Numbers Simulator using ADK.

This script sets up and executes the Group Relative Policy Optimization (GRPO) algorithm
in a multi-turn conversational environment powered by the ADK framework.

### Task Overview
The objective is to train an agent to guess the number of unique integers in a list generated by a simulated user.
The interaction is structured as a turn-based dialogue:
- The user generates a list of integers.
- The agent queries specific positions in the list (by index).
- The user replies with the value at that index (if available).
- The agent continues the interaction until it makes a final guess at the number of unique integers.

### Environment Details
The environment is a simulated user that:
- Randomly generates a list of integers at setup.
- Responds to the agent's queries using an LLM via the ADK endpoint.
- Optionally evaluates the agent's final guess using an LLM-based grader (included for extensibility, though not essential for this task).

### Example Usage
uv run python examples/run_grpo_unique_numbers_w_adk.py

### Requirements
- A working ADK environment with access to a compatible LLM endpoint.
For the default Gemini endpoint, the following environment variables must be set:
- `GOOGLE_GENAI_USE_VERTEXAI=1`
- `GOOGLE_CLOUD_PROJECT="your-project-id"`
- `GOOGLE_CLOUD_LOCATION="your-location"`

- A properly configured GRPO YAML file.
By default, the script uses:
`examples/configs/grpo_adk_llama8b.yaml`
"""

import argparse
import itertools
import os
import pprint
import random
from datetime import datetime, timedelta
from typing import Iterator

from omegaconf import OmegaConf
from torch.utils.data import IterableDataset
from transformers import AutoTokenizer

from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup
from nemo_rl.algorithms.utils import get_tokenizer
from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType
from nemo_rl.distributed.ray_actor_environment_registry import (
    get_actor_python_env,
)
from nemo_rl.distributed.virtual_cluster import init_ray
from nemo_rl.environments.simulated_user.prompt import starting_user_prompt
from nemo_rl.environments.simulated_user.unique_numbers import (
    UniqueNumbersEnv,
    UniqueNumbersMetadata,
)
from nemo_rl.models.generation import configure_generation_config
from nemo_rl.utils.config import load_config, parse_hydra_overrides
from nemo_rl.utils.logger import get_next_experiment_dir

OmegaConf.register_new_resolver("mul", lambda a, b: a * b)

def parse_args():
    parser = argparse.ArgumentParser(
        description="Run GRPO with unique numbers simulator"
    )
    parser.add_argument(
        "--config", type=str, default=None, help="Path to YAML config file"
    )
    args, overrides = parser.parse_known_args()
    return args, overrides


def generate_datum(
    tokenizer: AutoTokenizer,
    env_cfg: dict,
    task_name: str,
    idx: int,
    add_system_prompt: bool,
) -> DatumSpec:
    # Prompt formatting is governed by the chat_template set in the YAML config.
    formatted_prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": starting_user_prompt}],
        tokenize=False,
        # NOTE: add_system_prompt is accepted but not currently forwarded to the template.
        # add_system_prompt=add_system_prompt,
        add_bos_token=True,
        add_generation_prompt=True,
        add_special_tokens=False,
    )
    token_ids = tokenizer(
        formatted_prompt, return_tensors="pt", add_special_tokens=False
    )["input_ids"][0]

    def _generate_numbers(
        min_length, max_length, max_integer, default_max_turns
    ) -> UniqueNumbersMetadata:
        # Draw a random list; the ground-truth answer is its count of unique values.
        length = random.randint(min_length, max_length)
        numbers = [random.randint(0, max_integer) for _ in range(length)]
        return UniqueNumbersMetadata(
            numbers=numbers,
            unique_count=len(set(numbers)),
            turn=0,
            max_turns=default_max_turns,
        )

    metadata = _generate_numbers(
        min_length=env_cfg["cfg"]["min_length"],
        max_length=env_cfg["cfg"]["max_length"],
        max_integer=env_cfg["cfg"]["max_integer"],
        default_max_turns=env_cfg["cfg"]["max_turns"],
    )

    message_log: LLMMessageLogType = [
        {"role": "user", "content": formatted_prompt, "token_ids": token_ids}
    ]
    return {
        "message_log": message_log,
        "length": len(token_ids),
        "extra_env_info": metadata,
        "loss_multiplier": 1.0,
        "idx": idx,
        "task_name": task_name,
    }


class IterableNumbersDataset(IterableDataset):
    def __init__(self, tokenizer, env_cfg, task_name, add_system_prompt, length):
        super().__init__()
        self.tokenizer = tokenizer
        self.env_cfg = env_cfg
        self.task_name = task_name
        self.add_system_prompt = add_system_prompt
        self.length = length

    def __iter__(self) -> Iterator[DatumSpec]:
        for i in itertools.count():
            yield generate_datum(
                tokenizer=self.tokenizer,
                env_cfg=self.env_cfg,
                task_name=self.task_name,
                idx=i,
                add_system_prompt=self.add_system_prompt,
            )

    def __len__(self):
        return self.length


def setup_data(tokenizer, env_cfg, task_name, length, val_length, add_system_prompt):
    env_config = env_cfg[task_name]
    env = UniqueNumbersEnv.options(  # type: ignore # it's wrapped with ray.remote
        num_gpus=0,
        runtime_env={
            "py_executable": get_actor_python_env(
                "nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv"
            ),
            "env_vars": dict(os.environ),  # Pass thru all user environment variables
        },
    ).remote(cfg=dict(env_config["cfg"]))

    task_to_env = {task_name: env}

    train_ds = IterableNumbersDataset(
        tokenizer=tokenizer,
        env_cfg=env_config,
        task_name=task_name,
        add_system_prompt=add_system_prompt,
        length=length,
    )
    val_ds = IterableNumbersDataset(
        tokenizer=tokenizer,
        env_cfg=env_config,
        task_name=task_name,
        add_system_prompt=add_system_prompt,
        length=val_length,
    )
    val_task_to_env = task_to_env
    return train_ds, val_ds, task_to_env, val_task_to_env


def main():
    args, overrides = parse_args()
    if not args.config:
        args.config = os.path.join(
            os.path.dirname(__file__), "configs", "grpo_adk_llama8b.yaml"
        )
    config = load_config(args.config)
    if overrides:
        config = parse_hydra_overrides(config, overrides)
    config: MasterConfig = OmegaConf.to_container(config, resolve=True)

    # Replace the __NOW__ placeholder in the W&B run name with the current
    # Pacific time (fixed UTC-7 offset).
    now_pst = datetime.utcnow() + timedelta(hours=-7)
    config["logger"]["wandb"]["name"] = config["logger"]["wandb"]["name"].replace(
        "__NOW__", now_pst.strftime("%m/%d-%H:%M")
    )

    config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"])
    if config["checkpointing"]["enabled"]:
        print(
            f"\U0001f4ca Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}"
        )

    pprint.pprint(config)

    init_ray()

    tokenizer = get_tokenizer(config["policy"]["tokenizer"])
    config["policy"]["generation"] = configure_generation_config(
        config["policy"]["generation"], tokenizer
    )

    # The dataset is an infinite generator; this is the length it reports via __len__.
    ds_length = (
        config["grpo"]["num_prompts_per_step"]
        * config["grpo"]["num_generations_per_prompt"]
        * config["grpo"]["max_num_steps"]
    )
    dataset, val_dataset, task_to_env, val_task_to_env = setup_data(
        tokenizer=tokenizer,
        env_cfg=config["env"],
        task_name="unique_numbers",
        length=ds_length,
        val_length=config["grpo"]["max_val_samples"],
        add_system_prompt=config["data"]["add_system_prompt"],
    )

    (
        policy,
        policy_generation,
        cluster,
        dataloader,
        val_dataloader,
        loss_fn,
        logger,
        checkpointer,
        grpo_state,
        master_config,
    ) = setup(config, tokenizer, dataset, val_dataset)

    grpo_train(
        policy,
        policy_generation,
        dataloader,
        val_dataloader,
        tokenizer,
        loss_fn,
        task_to_env,
        val_task_to_env,
        logger,
        checkpointer,
        grpo_state,
        master_config,
    )


if __name__ == "__main__":
    main()
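
For a quick, local sanity check of the data pipeline, the iterable dataset can be exercised on its own. A hypothetical snippet (not part of the PR; it assumes nemo_rl is installed, this script is importable, and you have access to a tokenizer with a chat template), with env_cfg mirroring the env block of the YAML configs above:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("google/gemma-3-4b-it")
    env_cfg = {"cfg": {"max_turns": 15, "min_length": 5, "max_length": 10, "max_integer": 15}}
    ds = IterableNumbersDataset(
        tokenizer=tok,
        env_cfg=env_cfg,
        task_name="unique_numbers",
        add_system_prompt=False,
        length=4,
    )
    datum = next(iter(ds))
    print(datum["length"], datum["extra_env_info"])  # prompt token count, hidden numbers and answer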
1 change: 1 addition & 0 deletions nemo_rl/distributed/ray_actor_environment_registry.py
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@
     "nemo_rl.environments.math_environment.MathEnvironment": PY_EXECUTABLES.SYSTEM,
     "nemo_rl.environments.code_environment.CodeEnvironment": PY_EXECUTABLES.SYSTEM,
     "nemo_rl.environments.games.sliding_puzzle.SlidingPuzzleEnv": PY_EXECUTABLES.SYSTEM,
+    "nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv": PY_EXECUTABLES.ADK,
     "nemo_rl.environments.tools.retriever.RAGEnvironment": PY_EXECUTABLES.SYSTEM,
 }
 
2 changes: 2 additions & 0 deletions nemo_rl/distributed/virtual_cluster.py
@@ -54,6 +54,8 @@ class PY_EXECUTABLES:
     # aren't installed. Simple workaround is to always run the mcore py_executable with --reinstall.
     MCORE = "uv run --reinstall --extra mcore"
 
+    ADK = "uv run --locked --extra adk"
+
 
 @ray.remote  # pragma: no cover
 def _get_node_ip_and_free_port() -> tuple[str, int]:
2 changes: 1 addition & 1 deletion nemo_rl/environments/interfaces.py
@@ -46,7 +46,7 @@ class EnvironmentReturn(NamedTuple, Generic[MetadataT]):
     next_stop_strings: list[list[str] | None] | list[None]
     rewards: Tensor
     terminateds: Tensor
-    answers: list[str | None] | None
+    answers: list[str | None] | None = None
 
 
 class EnvironmentInterface(abc.ABC, Generic[MetadataT]):
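
A minimal, self-contained sketch of what this one-line change buys (an illustrative stand-in class, not the real EnvironmentReturn): giving the final NamedTuple field a default lets an environment construct the return tuple without supplying answers, so environments with nothing to report there, like the simulated-user env in this PR, can simply omit it.

    from typing import NamedTuple

    class Ret(NamedTuple):  # stand-in for EnvironmentReturn
        rewards: list[float]
        answers: list[str | None] | None = None  # optional after this change

    r = Ret(rewards=[1.0])  # legal now that answers has a default
    assert r.answers is None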