Skip to content

Commit 09f00fb

Browse files
committed
Add LoRA training support with math agent example
1 parent 135eaed commit 09f00fb

5 files changed

Lines changed: 219 additions & 1 deletion

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
## ✈️ News
2020

21+
- 2026.3.30 LoRA training example is now online! See [tutorial/example_math_lora](tutorial/example_math_lora/) for an example.
2122
- 2026.3.26 Upgrade verl backend to 0.7.1 to support more models and increase training speed! All [benchmark](https://benchmark.agentjet.top/) verified.
2223
- 2026.3.19 Support for latest Qwen3.5 models is [in progress](https://github.com/modelscope/AgentJet/pull/16).
2324
- 2026.3.12 Tuning Original OpenClaw Agent without Editing Any Agent Code. [EN Blog](https://modelscope.github.io/AgentJet/en/example_openclaw/) / [ZH Blog](https://modelscope.github.io/AgentJet/en/example_openclaw.zh/).
@@ -163,7 +164,6 @@ AgentJet is a constantly evolving project. We are planning to add the following
163164

164165
| Category | Feature | Status |
165166
| :--- | :--- | :--- |
166-
| **Examples** | Add LoRA training examples | Todo |
167167
| **Infra** | Optimize configurations for long-context adaptation on smaller GPUs | In Progress |
168168
| **Capability** | Multi-modal training support | Todo |
169169
| **Capability** | MARL Credit assignment | Todo |
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Training a basic math agent
2+
3+
4+
Please refer to the document at [`docs/en/example_app_world.md`](docs/en/example_app_world.md)
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from agentscope.message import Msg
2+
from loguru import logger
3+
4+
from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask
5+
6+
7+
def extract_final_answer(result) -> str:
    """Extract the final answer text from an agent's reply.

    Checks, in order: ``result.metadata["result"]``, ``result.content["result"]``,
    ``result.content``, and finally falls back to ``str(result)``.

    Args:
        result: The agent's reply; typically a message-like object with
            ``metadata`` and/or ``content`` attributes (duck-typed).

    Returns:
        The extracted answer. Fix: always coerced via ``str()`` so the declared
        return type holds even when the stored payload is not a string.
    """
    try:
        metadata = getattr(result, "metadata", None)
        if isinstance(metadata, dict) and "result" in metadata:
            return str(metadata["result"])
        if hasattr(result, "content"):
            content = result.content
            if isinstance(content, dict) and "result" in content:
                return str(content["result"])
            return str(content)
        return str(result)
    except Exception as e:
        # Best-effort: never let answer extraction crash the workflow.
        logger.warning(f"Extract final answer error: {e}. Raw: {result}")
        return str(result)
24+
25+
26+
# System prompt for the math ReAct agent. The string is passed to the agent
# verbatim (never through str.format), so braces must be literal: the previous
# "{{}}" was a format-string leftover and told the model to emit "\boxed{{}}",
# which the judge's \boxed{...} regex would mis-handle.
system_prompt = """
You are an agent specialized in solving math problems with tools.
Please solve the math problem given to you.
You can write and execute Python code to perform calculation or verify your answer.
You should return your final answer within \\boxed{}.
"""
32+
33+
34+
class ExampleMathLearn(Workflow):
    """Math-solving workflow: a single ReAct agent with a Python-execution tool."""

    name: str = "math_agent_workflow"

    async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput:
        # Imported lazily so this module can be loaded without agentscope installed.
        from agentscope.agent import ReActAgent
        from agentscope.formatter import DashScopeChatFormatter
        from agentscope.memory import InMemoryMemory
        from agentscope.tool import Toolkit, execute_python_code

        # Equip the agent with exactly one tool: a Python code interpreter.
        toolkit = Toolkit()
        toolkit.register_tool_function(execute_python_code)
        self.toolkit = toolkit

        self.agent = ReActAgent(
            name="math_react_agent",
            sys_prompt=system_prompt,
            model=tuner.as_agentscope_model(),
            formatter=DashScopeChatFormatter(),
            toolkit=self.toolkit,
            memory=InMemoryMemory(),
            max_iters=2,
        )
        self.agent.set_console_output_enabled(False)

        # Run the agent on the task's query and hand the extracted answer to the judge.
        problem = workflow_task.task.main_query
        reply = await self.agent.reply(Msg("user", problem, role="user"))
        return WorkflowOutput(
            reward=None,
            metadata={"final_answer": extract_final_answer(reply)},
        )
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# ------------------ main configuration ------------------
2+
ajet:
3+
project_name: example_math_agent
4+
task_reader:
5+
type: huggingface_dat_repo # ✨✨✨✨ `env_service` or `dataset_file` or `huggingface_dat_repo`
6+
# effective when `type: huggingface_dat_repo`
7+
huggingface_dat_repo:
8+
dataset_path: '/mnt/data_cpfs/model_cache/modelscope/dataset/openai/gsm8k/main'
9+
training_split: "train"
10+
validation_split: "test"
11+
12+
task_judge:
13+
# ✨✨✨✨ define your evaluation function
14+
judge_protocol: tutorial.example_math_agent.math_answer_as_judge->MathAnswerAsJudge
15+
16+
model:
17+
# ✨✨✨✨ set the model to be trained
18+
path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct
19+
20+
rollout:
21+
user_workflow: "tutorial.example_math_agent.math_agent->ExampleMathLearn" # ✨✨✨✨ write and select workflow
22+
# user_workflow: "tutorial.example_math_agent.math_agent_langchain->ExampleMathLearn" # ✨if you prefer langchain version
23+
# user_workflow: "tutorial/example_math_agent/math_agent_oai_sdk.py->ExampleMathLearn_Simple_NoToolCall" # ✨if you prefer openai sdk version without toolcall
24+
# user_workflow: "tutorial/example_math_agent/math_agent_oai_sdk.py->ExampleMathLearn" # ✨if you prefer openai sdk version with toolcall
25+
# user_workflow: "tutorial/example_math_agent/math_agent_raw_http.py->ExampleMathLearn" # ✨if you do not want to use any agentic framework at all
26+
# user_workflow: "tutorial/example_math_agent/math_agent_simplify.py->MathToolWorkflow" # ✨if you prefer to compute reward inside workflow
27+
temperature: 1.0
28+
max_env_worker: 64
29+
num_repeat: 6
30+
agent_madness_reward: 0.0
31+
tensor_model_parallel_size: 1
32+
max_num_seqs: 40
33+
multi_turn:
34+
max_sample_per_task: 2
35+
compute_madness_checklist:
36+
- "nonsense"
37+
- "wrong_toolcall"
38+
max_response_length_in_one_turn: 1024
39+
max_model_len: 10000
40+
n_vllm_engine: 2
41+
42+
data:
43+
train_batch_size: 100
44+
max_prompt_length: 3000
45+
max_response_length: 7000
46+
47+
debug:
48+
debug_max_parallel: 1
49+
debug_first_n_tasks: 1
50+
51+
trainer_common:
52+
save_freq: 100
53+
test_freq: 100
54+
total_epochs: 100
55+
logger: swanlab
56+
val_before_train: true
57+
58+
actor_rollout_ref:
59+
model:
60+
lora_rank: 32
61+
lora_alpha: 32
62+
target_modules: all-linear
63+
actor:
64+
optim:
65+
lr: 3e-5
66+
fsdp_config:
67+
param_offload: true
68+
optimizer_offload: true
69+
rollout:
70+
load_format: safetensors
71+
72+
trinity:
73+
synchronizer:
74+
sync_offset: 1
75+
sync_method: nccl
76+
77+
78+
# ------------------ do not modify ------------------
79+
hydra:
80+
searchpath:
81+
- file://ajet/default_config
82+
- file://ajet/default_config/verl
83+
- file://ajet/default_config/trinity
84+
85+
# ------------------ do not modify ------------------
86+
defaults:
87+
- verl_default
88+
- trinity_default
89+
- ajet_default
90+
- _self_
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import re
2+
3+
from ajet.task_judge.base_judge import BaseJudge
4+
from ajet.task_rollout.dashscope_llm_bridge import create_external_llm_fn
5+
from ajet.workflow import WorkflowOutput, WorkflowTask
6+
7+
8+
class MathAnswerAsJudge(BaseJudge):
    """Rule-based judge: compares the agent's \\boxed{...} answer with the reference."""

    def __init__(self, config):
        self.config = config

    def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowOutput) -> tuple:
        """Score the rollout by exact string match against the ground truth.

        Args:
            workflow_task: Carries the ground truth in ``task.metadata["answer"]``
                (GSM8K style: the final answer follows ``####``).
            workflow_output: Must carry the agent's answer in
                ``metadata["final_answer"]``. By default there's no final_answer;
                register it by calling
                ``ajet_proxy.update_judge_input_dictionary(final_answer=final_answer)``
                in the workflow.

        Returns:
            tuple: ``(raw_reward, is_success)`` — ``(1.0, True)`` on a match,
            ``(0.0, False)`` otherwise.
        """
        final_answer = workflow_output.metadata["final_answer"]
        # Keep only the part after "####" (GSM8K answer format).
        reference_answer = workflow_task.task.metadata["answer"].split("####")[-1].strip()

        # NOTE(review): [^}]* does not handle nested braces such as
        # \boxed{\frac{1}{2}} — acceptable for GSM8K-style numeric answers.
        match = re.search(r"\\boxed\{([^}]*)\}", final_answer)
        # Fix: strip the extracted answer so "\boxed{ 42 }" still matches; the
        # reference side is already stripped above.
        is_success = bool(match) and match.group(1).strip() == reference_answer
        raw_reward = 1.0 if is_success else 0.0
        return raw_reward, is_success
30+
31+
32+
class MathAnswerAndLlmAsJudge(BaseJudge):
    """Judge that delegates the correctness check to an external ("alien") LLM."""

    def __init__(self, config):
        self.config = config

    def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowOutput) -> tuple:
        # By default there's no final_answer; register it by calling
        # ajet_proxy.update_judge_input_dictionary(final_answer=final_answer)
        # in the workflow.
        final_answer = workflow_output.metadata["final_answer"]
        # Keep only the part after "####" (GSM8K answer format).
        reference_answer = workflow_task.task.metadata["answer"]
        reference_answer = reference_answer.split("####")[-1].strip()

        # Build the external-LLM callable from the judge configuration.
        judge_cfg = self.config.ajet.task_judge
        external_llm_fn = create_external_llm_fn(
            alien_llm_model=judge_cfg.alien_llm_model,
            alien_llm_response_length=judge_cfg.alien_llm_response_length,
        )

        messages = [
            {
                "role": "system",
                "content": "Is my result correct? If correct, say <Correct>, otherwise say <NotCorrect>.",
            },
            {
                "role": "user",
                "content": f"Is my result correct?\n\n\n----\nMy result: {final_answer}\n\n\n----\nReal result: {reference_answer}",
            },
        ]
        verdict = external_llm_fn(messages=messages)

        # The grader is expected to answer with a <Correct>/<NotCorrect> tag.
        is_success = "<Correct>" in verdict["content"]
        raw_reward = 1.0 if is_success else 0.0
        return raw_reward, is_success

0 commit comments

Comments
 (0)