Skip to content

Commit f906d2e

Browse files
committed
test skill
1 parent c523a40 commit f906d2e

File tree

6 files changed

+82
-42
lines changed

6 files changed

+82
-42
lines changed

ajet/context_tracker/multiagent_tracking.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -603,26 +603,27 @@ def check_context_token_num_safe(
603603
add_generation_prompt=True,
604604
tokenize=False,
605605
)
606-
length = len(self.tokenizer(prompt_text, return_tensors="pt", padding=False)["input_ids"][0]) # type: ignore
606+
prompt_token_length = len(self.tokenizer(prompt_text, return_tensors="pt", padding=False)["input_ids"][0]) # type: ignore
607607
max_response_length_in_one_turn = self.config.ajet.rollout.max_response_length_in_one_turn
608608
max_model_len: int = self.config.ajet.rollout.max_model_len
609609
max_seq_length: int = max_model_len - max_response_length_in_one_turn
610-
# length: the length of current all previous context
610+
# prompt_token_length: the prompt_token_length of current all previous context
611611
# max_seq_length: max_model_len - max_response_length_in_one_turn
612-
if length < max_seq_length:
612+
if prompt_token_length < max_seq_length:
613613
token_overflow = False
614614
else:
615615
token_overflow = True
616616
if self.should_interrupt_soft_fn():
617617
ret = (False, token_overflow, "externally_interrupted")
618618
elif self.already_mad_flag and self.config.ajet.rollout.agent_madness_termination:
619619
ret = (False, token_overflow, "already_mad")
620-
elif length < max_seq_length:
620+
elif prompt_token_length < max_seq_length:
621621
ret = (
622622
True,
623623
token_overflow,
624-
f"safe[{length} < {max_model_len} - {max_response_length_in_one_turn}]",
624+
f"safe[{prompt_token_length} < {max_model_len} - {max_response_length_in_one_turn}]",
625625
)
626626
else:
627-
ret = (False, token_overflow, "token_overflow")
627+
ret = (False, token_overflow,
628+
f"token_overflow(prompt_token_length.{prompt_token_length}>=max_model_len.{max_model_len}-max_response_length_in_one_turn.{max_response_length_in_one_turn})")
628629
return ret

ajet/copilot/job.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ def __init__(
9292
raise ValueError(f"Configuration yaml is absent! {base_yaml_config}")
9393

9494
# Validate: max_prompt_length, max_response_length, max_model_len must all be None or all be non-None
95-
length_params = [max_prompt_length, max_response_length, max_model_len]
95+
length_params = [max_prompt_length, max_response_length, max_model_len, max_response_length_in_one_turn]
9696
if not (all(p is None for p in length_params) or all(p is not None for p in length_params)):
97-
raise ValueError("max_prompt_length, max_response_length, max_model_len must all be None or all be non-None")
97+
raise ValueError("(`max_prompt_length`, `max_response_length`, `max_model_len`, `max_response_length_in_one_turn`) must all be None or all be non-None")
9898

9999
self.config_as_dict: dict = self.build_job_from_yaml(base_yaml_config)
100100
self.config = Config.update_from_dict_recursive(Config(), self.config_as_dict)
@@ -159,6 +159,7 @@ def __init__(
159159

160160

161161
assert self.max_prompt_length + self.max_response_length <= self.max_model_len, "illegal token length"
162+
assert self.max_response_length_in_one_turn <= self.max_response_length
162163
if self.backbone == "trinity":
163164
raise NotImplementedError("Trinity backbone is not yet supported in AgentJetJob.")
164165

ajet/task_rollout/async_llm_bridge.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,11 @@ async def llm_chat_verl(
8484
add_generation_prompt=True,
8585
tokenize=False,
8686
)
87-
prompt_ids = self.tokenizer(prompt_text)["input_ids"]
87+
prompt_token_ids = self.tokenizer(prompt_text)["input_ids"]
8888

8989
final_res = await self.async_rollout_manager.generate(
9090
request_id=request_id,
91-
prompt_ids=prompt_ids,
91+
prompt_ids=prompt_token_ids,
9292
sampling_params=updated_sampling_params,
9393
)
9494

@@ -135,17 +135,19 @@ async def llm_chat_verl(
135135
max_response_length_in_one_turn = self.config.ajet.rollout.max_response_length_in_one_turn
136136
max_model_len: int = self.config.ajet.rollout.max_model_len
137137
max_seq_length: int = max_model_len - max_response_length_in_one_turn
138-
if len(prompt_ids) < max_seq_length:
138+
if len(prompt_token_ids) >= max_seq_length:
139139
finish_reason = "length"
140140
else:
141141
finish_reason = "stop"
142142
if tool_calls:
143143
finish_reason = "tool_calls"
144144
usage = {
145-
"prompt_tokens": len(prompt_ids),
145+
"prompt_tokens": len(prompt_token_ids),
146146
"completion_tokens": len(token_array), # type: ignore
147-
"total_tokens": len(prompt_ids) + len(token_array), # type: ignore
147+
"total_tokens": len(prompt_token_ids) + len(token_array), # type: ignore
148148
}
149+
print("====----====usage", usage)
150+
print("====----====finish_reason", finish_reason)
149151
return {
150152
"role": "assistant",
151153
"request_id": request_id,
@@ -243,7 +245,7 @@ async def main():
243245
max_response_length_in_one_turn = self.config.ajet.rollout.max_response_length_in_one_turn
244246
max_model_len: int = self.config.ajet.rollout.max_model_len
245247
max_seq_length: int = max_model_len - max_response_length_in_one_turn
246-
if len(prompt_token_ids) < max_seq_length:
248+
if len(prompt_token_ids) >= max_seq_length:
247249
finish_reason = "length"
248250
else:
249251
finish_reason = "stop"
@@ -371,7 +373,7 @@ async def run_infer(
371373
if token_overflow:
372374
# ajet_action_when_overflow = self.config.ajet.rollout.ajet_action_when_overflow
373375
# cannot proceed due to context overflow
374-
return self.construct_overflow_response()
376+
return self.construct_overflow_response(info)
375377
# else:
376378
# otherwise, for abnormal output, can still proceed, but we do not track output anymore
377379

@@ -383,12 +385,13 @@ async def run_infer(
383385
return llm_output
384386

385387

386-
def construct_overflow_response(self):
388+
def construct_overflow_response(self, info):
387389
return {
388390
"role": "assistant",
389391
"request_id": "overflow_response",
390-
"content": "ajet_proxy: Exceeded max model context length.",
392+
"content": f"AgentJet: Exceeded max model context length. {info}",
391393
"tool_calls": None,
394+
"finish_reason": "length",
392395
"tokens": [],
393396
}
394397

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# 非共享参数多智能体强化学习:学术翻译实战
1+
# 非共享参数多智能体强化学习实战
22

33
在传统的多智能体强化学习(MARL)系统中,所有智能体通常共享同一套模型参数——这意味着无论有多少个智能体,它们都共用一个"大脑"。这种设计虽然简单,但在实际应用中存在明显的局限性:不同智能体可能需要不同规模的模型来执行不同复杂度的任务。AgentJet 的 Swarm 训练模式突破了这一限制,实现了真正的**非共享参数多智能体强化学习**
44

@@ -176,6 +176,11 @@ sequenceDiagram
176176
4. 将各自的奖励汇报给对应的 Swarm Server
177177
5. 两个 Server 独立执行策略梯度更新
178178

179+
## 训练曲线
180+
181+
![alt text](https://img.alicdn.com/imgextra/i2/O1CN0161wtDk1zZwFmIX15x_!!6000000006729-2-tps-2978-1413.png)
182+
183+
179184
## 优势总结
180185

181186
与传统的单模型共享参数训练相比,非共享参数多智能体强化学习具有显著优势:

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ nav:
6464

6565
- • Blogs:
6666
- Swarm Intro (ZH): en/swarm_intro_blog_zh.md
67+
- Multi Model Training (ZH): en/example_train_multi_model.zh.md
6768

6869
plugins:
6970
- search:

tutorial/opencode_build_openclaw_agent/fake_vllm_endpoint.py

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
max_prompt_length=16000, # at least 16000
4949
max_response_length=8000,
5050
max_model_len=24000, # bigger than / equal to `max_prompt_length + max_response_length`
51+
max_response_length_in_one_turn=4000,
5152
)
5253

5354
class EpisodeResult(BaseModel):
@@ -102,6 +103,29 @@ async def proxy_chat_completion(base_url: str, api_key: str, request: Request, i
102103
return resp.json()
103104

104105

106+
def _check_finish_reason_length(response_data: Dict | List[bytes]) -> bool:
107+
"""Return True if any choice has finish_reason='length'."""
108+
if isinstance(response_data, list):
109+
for raw in response_data:
110+
line = raw.decode() if isinstance(raw, bytes) else raw
111+
if not line.startswith("data:"):
112+
continue
113+
payload = line[len("data:"):].strip()
114+
if payload == "[DONE]":
115+
break
116+
try:
117+
chunk = json.loads(payload)
118+
finish_reason = chunk.get("choices", [{}])[0].get("finish_reason")
119+
if finish_reason == "length":
120+
return True
121+
except Exception:
122+
pass
123+
return False
124+
else:
125+
choices = response_data.get("choices", [])
126+
return any(c.get("finish_reason") == "length" for c in choices)
127+
128+
105129
async def run_single_episode(episode_index: int, request: Request, is_stream: bool) -> EpisodeResult:
106130
"""Run a single episode."""
107131
assert swarm_client is not None
@@ -113,6 +137,18 @@ async def run_single_episode(episode_index: int, request: Request, is_stream: bo
113137
request=request,
114138
is_stream=is_stream,
115139
)
140+
if _check_finish_reason_length(response_data):
141+
raise HTTPException(
142+
status_code=400,
143+
detail={
144+
"error": {
145+
"message": "This model's maximum context length is exceeded. Please reduce the length of the messages.",
146+
"type": "invalid_request_error",
147+
"param": "messages",
148+
"code": "context_length_exceeded",
149+
}
150+
},
151+
)
116152
return EpisodeResult(episode_uuid=episode_uuid, response=response_data)
117153
except Exception as e:
118154
logger.error(f"Error in episode {episode_index}: {e}")
@@ -126,7 +162,10 @@ async def run_all_episodes(request: Request, is_stream: bool) -> List[EpisodeRes
126162
results = await asyncio.gather(*episode_tasks, return_exceptions=True)
127163
valid_results: List[EpisodeResult] = []
128164
for result in results:
129-
if isinstance(result, Exception):
165+
if isinstance(result, HTTPException) and result.status_code == 400:
166+
# Propagate context_length_exceeded directly to client
167+
raise result
168+
elif isinstance(result, Exception):
130169
logger.warning(f"Episode failed: {result}")
131170
elif isinstance(result, EpisodeResult):
132171
valid_results.append(result)
@@ -195,29 +234,19 @@ def start_engine_background():
195234
async def one2many_proxy(request: Request, path: str):
196235
"""Main proxy endpoint."""
197236
global REQUEST_COUNTER
198-
try:
199-
if request.method == "POST" and path == "chat/completions":
200-
REQUEST_COUNTER += 1
201-
request_id = f"req_{REQUEST_COUNTER}_{uuid.uuid4().hex[:8]}"
202-
logger.info(f"Received chat completion request {request_id}")
203-
response_data = await handle_one2many_request(request, request_id)
204-
if isinstance(response_data, list):
205-
async def stream_chunks(chunks: List[bytes]):
206-
for chunk in chunks:
207-
yield chunk + b"\n\n"
208-
return StreamingResponse(stream_chunks(response_data), media_type="text/event-stream")
209-
return response_data
210-
else:
211-
raise HTTPException(status_code=404, detail="Not Found")
212-
except httpx.TimeoutException:
213-
logger.error(f"Timeout proxying {request.method} {path}")
214-
raise HTTPException(status_code=504, detail="Gateway Timeout")
215-
except httpx.ConnectError:
216-
logger.error(f"Connection error proxying {request.method} {path}")
217-
raise HTTPException(status_code=502, detail="Bad Gateway")
218-
except Exception as e:
219-
logger.exception(f"Unexpected error proxying {request.method} {path}: {e}")
220-
raise HTTPException(status_code=500, detail="Internal Server Error")
237+
if request.method == "POST" and path == "chat/completions":
238+
REQUEST_COUNTER += 1
239+
request_id = f"req_{REQUEST_COUNTER}_{uuid.uuid4().hex[:8]}"
240+
logger.info(f"Received chat completion request {request_id}")
241+
response_data = await handle_one2many_request(request, request_id)
242+
if isinstance(response_data, list):
243+
async def stream_chunks(chunks: List[bytes]):
244+
for chunk in chunks:
245+
yield chunk + b"\n\n"
246+
return StreamingResponse(stream_chunks(response_data), media_type="text/event-stream")
247+
return response_data
248+
else:
249+
raise HTTPException(status_code=404, detail="Not Found")
221250

222251

223252
@app.get("/health")

0 commit comments

Comments
 (0)