Skip to content

Commit fe7f637

Browse files
committed
add-openclaw-training
1 parent 504c5e4 commit fe7f637

File tree

3 files changed

+431
-0
lines changed

3 files changed

+431
-0
lines changed
Lines changed: 346 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,346 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""
4+
A one-to-many proxy server.
5+
6+
------------------------------------------------------
7+
------------------------------------------------------
8+
When this server initializes, it does the following things:
9+
1. connect to swarm server, sync training config and start engine with specified AgentJetJob, wait until the engine is ready.
10+
```
11+
SWARM_URL = "http://localhost:10086"
12+
num_repeat = 8
13+
swarm_client = SwarmClient(SWARM_URL)
14+
swarm_client.auto_sync_train_config_and_start_engine(
15+
AgentJetJob(
16+
algorithm="grpo",
17+
project_name="ajet-swarm",
18+
experiment_name="test",
19+
n_gpu=8,
20+
model='/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct',
21+
batch_size=16,
22+
num_repeat=num_repeat,
23+
),
24+
)
25+
```
26+
27+
2. read a user preference: 我希望我的助手足够幽默
28+
29+
30+
------------------------------------------------------
31+
------------------------------------------------------
32+
This server does the following things when receiving an LLM request:
33+
0. init task:
34+
Task(task_id="{a_random_uuid}, main_query="{user_request_message}")
35+
1. keep a record of user requests, ordered by arrival time, and assign a unique request_id to each request.
36+
2. repeat `num_repeat` times (in parallel):
37+
2-1. get episode base-url and api-key of this `repeat`:
38+
```
39+
episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=120)
40+
```
41+
2-2. proxy chat completion request to the base-url with api-key, store (episode_response, episode_uuid).
42+
43+
3. when all `num_repeat` episodes finish:
44+
3-1. compare all episode_responses, generate score (-1 ~ +1) for each episode_response (write a random score generator for now)
45+
3-2. now we should have (episode_response_array, episode_uuid_array, episode_relative_reward_array)
46+
3-3. run for loop:
47+
for ... in (episode_response_array, episode_uuid_array, episode_relative_reward_array)
48+
workflow_output = WorkflowOutput(
49+
reward=relative_reward,
50+
metadata={},
51+
)
52+
swarm_worker.end_episode(task, episode_uuid, workflow_output)
53+
54+
4. select the episode_response with the highest score, return it to user (pretend it is a stream, although it is actually not).
55+
56+
5. end this request.
57+
58+
59+
------------------------
60+
61+
for swarm api, refer to tutorial/example_math_swarm/math.py
62+
63+
64+
# python -m ajet.tuner_lib.experimental.oai_model_one2many
65+
# python -m ajet.tuner_lib.experimental.oai_model_one2many_client
66+
67+
68+
"""
69+
70+
import os
71+
import uuid
72+
import random
73+
import asyncio
74+
import httpx
75+
import time
76+
from contextlib import asynccontextmanager
77+
from typing import Dict, List, Tuple, Optional
78+
from fastapi import FastAPI, Request, HTTPException
79+
from fastapi.responses import StreamingResponse
80+
from loguru import logger
81+
from pydantic import BaseModel
82+
83+
from ajet.schema.task import Task, WorkflowOutput
84+
from ajet.copilot.job import AgentJetJob
85+
from ajet.tuner_lib.experimental.swarm_client import SwarmClient
86+
from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey
87+
from ajet.tuner_lib.experimental.interchange_utils import (
88+
ClaimEpisodeRequest,
89+
ClaimEpisodeResponse,
90+
)
91+
92+
93+
# Swarm server location; override with the AJET_SWARM_URL environment variable.
SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086")
# Number of parallel episodes sampled per incoming user request.
NUM_REPEAT = int(os.getenv("NUM_REPEAT", "8"))
# In-memory log of handled requests, ordered by arrival time (served at /requests).
USER_REQUEST_RECORD: List[Dict] = []
# Monotonic counter used to build human-readable request ids.
REQUEST_COUNTER = 0
# User preference injected into every Task's metadata
# ("I want my assistant to be humorous enough").
USER_PREFERENCE = "我希望我的助手足够幽默"

# Module-level client so the name is bound at import time.
# NOTE(review): lifespan() creates a second SwarmClient on startup and rebinds
# this global — the instance built here appears to be discarded; confirm intent.
swarm_client: Optional[SwarmClient] = SwarmClient(SWARM_URL)
100+
101+
102+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    (Re)creates the global swarm client, then launches the train-config sync
    and engine startup in a background daemon thread so the proxy can begin
    accepting HTTP requests immediately.
    """
    global swarm_client
    logger.info(f"Initializing swarm client with URL: {SWARM_URL}")
    swarm_client = SwarmClient(SWARM_URL)

    job = AgentJetJob(
        algorithm="grpo",
        project_name="ajet-swarm",
        experiment_name="test",
        n_gpu=8,
        model='/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct',
        batch_size=16,
        num_repeat=NUM_REPEAT,
    )

    logger.info(f"Syncing train config and starting engine with num_repeat={NUM_REPEAT}")

    import threading

    def _sync_and_start() -> None:
        # Best-effort: a failure here must not prevent the proxy from serving.
        try:
            swarm_client.auto_sync_train_config_and_start_engine(  # type: ignore[union-attr]
                job,
                force_restart=False,
            )
            logger.info("Swarm engine is ready!")
        except Exception as e:
            logger.warning(f"Engine auto-sync skipped or failed: {e}")

    threading.Thread(target=_sync_and_start, daemon=True).start()

    yield
135+
136+
137+
# Register lifespan so the swarm client / engine bootstrap runs at startup.
app = FastAPI(title="One-to-Many Proxy Server", lifespan=lifespan)
138+
139+
140+
141+
class ChatCompletionRequest(BaseModel):
    """Minimal OpenAI-style chat-completion request body."""
    # Model name sent by the client; forwarded as-is to the episode endpoint.
    model: str = "fill_whatever_model"
    # Chat history; the last entry's "content" becomes the Task main_query.
    messages: List[Dict[str, str]]
    # If True, the (already complete) answer is replayed as a fake SSE stream.
    stream: bool = False
145+
146+
147+
class EpisodeResult(BaseModel):
    """Outcome of one proxied episode: its id, raw response, and reward."""
    episode_uuid: str
    # Parsed JSON body returned by the episode's /chat/completions endpoint.
    response: Dict
    # Relative score in [-1, 1]; currently a random placeholder.
    reward: float
151+
152+
153+
async def proxy_chat_completion(
    base_url: str,
    api_key: str,
    request_data: ChatCompletionRequest
) -> Dict:
    """Forward one chat-completion request to an episode endpoint.

    Args:
        base_url: OpenAI-compatible base URL of the claimed episode.
        api_key: Bearer token for that episode.
        request_data: The user's original request; ``stream`` is forced off
            because responses are aggregated before replying to the user.

    Returns:
        The parsed JSON body of the upstream response.

    Raises:
        httpx.HTTPStatusError: If the upstream returns a non-2xx status.
        httpx.TimeoutException / httpx.ConnectError: On transport failures.
    """
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "Connection": "close",
    }

    # Force stream=False for internal requests; streaming is emulated later.
    request_dict = request_data.model_dump()
    request_dict["stream"] = False

    # rstrip guards against a double slash when base_url ends with "/".
    url = f"{base_url.rstrip('/')}/chat/completions"
    async with httpx.AsyncClient(timeout=300.0) as client:
        resp = await client.post(
            url,
            json=request_dict,
            headers=headers,
        )
        resp.raise_for_status()
        return resp.json()
176+
177+
178+
def generate_random_score() -> float:
    """Placeholder scorer: draw a score uniformly from [-1.0, 1.0]."""
    # Equivalent to random.uniform(-1.0, 1.0), which computes a + (b-a)*random().
    return -1.0 + 2.0 * random.random()
180+
181+
182+
def begin_episode_direct(swarm_client: SwarmClient, discard_episode_timeout: int = 120, max_retries: int = 10) -> Tuple[str, OpenaiBaseUrlAndApiKey]:
    """Claim an episode directly from the swarm server.

    Unlike the stock client flow, this keeps the base_url host exactly as the
    server returned it (no host replacement).

    Args:
        swarm_client: Connected swarm client (shadows the module global on purpose).
        discard_episode_timeout: Seconds after which an unclaimed episode is discarded.
        max_retries: Claim attempts before giving up.

    Returns:
        Tuple of (episode_uuid, OpenaiBaseUrlAndApiKey for that episode).

    Raises:
        RuntimeError: If no episode could be claimed within max_retries attempts.
    """
    for attempt in range(max_retries):
        try:
            req_obj = ClaimEpisodeRequest(
                client_uuid=swarm_client.client_uuid,
                episode_type="train",
                discard_episode_timeout=discard_episode_timeout,
                throttle_policy=None
            )
            resp = swarm_client._http_client.post(
                f"{swarm_client.server_url}/claim_episode",
                json=req_obj.model_dump()
            )
            resp.raise_for_status()
            data = ClaimEpisodeResponse.model_validate(resp.json())

            if data.success:
                episode_uuid = data.episode_uuid
                logger.info(f"Claimed episode {episode_uuid}")
                return episode_uuid, OpenaiBaseUrlAndApiKey(
                    base_url=data.openai_base_url,
                    api_key=data.openai_api_key,
                    episode_uuid=episode_uuid
                )
            else:
                logger.warning(f"Failed to claim episode: {data.fail_cause}")
                # fail_cause may be None — guard before the substring test so
                # we don't raise TypeError (previously masked by the broad except).
                if "No available episodes" in (data.fail_cause or ""):
                    time.sleep(2)
                else:
                    time.sleep(5)
        except Exception as e:
            logger.error(f"Error claiming episode: {e}")
            time.sleep(2)

    raise RuntimeError(f"Failed to claim episode after {max_retries} attempts")
222+
223+
224+
async def handle_one2many_request(request_data: ChatCompletionRequest, request_id: str) -> Dict:
    """Fan one user request out to NUM_REPEAT episodes and return the best reply.

    Steps (mirroring the module docstring):
      0/1. Build a Task from the last user message and record the request.
      2.   Run NUM_REPEAT episodes in parallel, each proxying the request to a
           freshly claimed episode endpoint.
      3.   Score each response (random placeholder) and report every successful
           episode's reward back to the swarm via end_episode.
      4.   Return the highest-scoring response body.

    Raises:
        HTTPException: 500 when every episode fails.
    """
    global USER_REQUEST_RECORD

    task = Task(
        task_id=str(uuid.uuid4()),
        main_query=request_data.messages[-1]["content"] if request_data.messages else "",
        metadata={"user_preference": USER_PREFERENCE}
    )

    assert swarm_client is not None, "Swarm client not initialized"

    USER_REQUEST_RECORD.append({
        "request_id": request_id,
        "task_id": task.task_id,
        "query": task.main_query,
    })

    async def run_episode(episode_index: int) -> EpisodeResult:
        # begin_episode_direct blocks (retry loop with time.sleep), so run it
        # in the default executor. get_running_loop() replaces the deprecated
        # get_event_loop() call inside a coroutine.
        loop = asyncio.get_running_loop()
        episode_uuid, api_baseurl_key = await loop.run_in_executor(
            None, lambda: begin_episode_direct(swarm_client, 300)  # type: ignore[arg-type]
        )

        try:
            response_data = await proxy_chat_completion(
                base_url=api_baseurl_key.base_url,
                api_key=api_baseurl_key.api_key,
                request_data=request_data,
            )

            # Placeholder scoring; replace with a real response comparator later.
            reward = generate_random_score()

            return EpisodeResult(
                episode_uuid=episode_uuid,
                response=response_data,
                reward=reward,
            )
        except Exception as e:
            # Release the claimed episode so the swarm can reuse it.
            logger.error(f"Error in episode {episode_index}: {e}")
            swarm_client.abort_episode(episode_uuid)  # type: ignore[union-attr]
            raise

    tasks = [run_episode(i) for i in range(NUM_REPEAT)]
    episode_results: List[EpisodeResult | BaseException] = await asyncio.gather(*tasks, return_exceptions=True)

    valid_results: List[EpisodeResult] = []
    for result in episode_results:
        if isinstance(result, Exception):
            logger.error(f"Episode failed with exception: {result}")
            continue
        if isinstance(result, EpisodeResult):
            valid_results.append(result)

    # Report rewards for every successful episode, not just the winner.
    for result in valid_results:
        workflow_output = WorkflowOutput(
            reward=result.reward,
            metadata={},
        )
        swarm_client.end_episode(task, result.episode_uuid, workflow_output)  # type: ignore[union-attr]

    if not valid_results:
        raise HTTPException(status_code=500, detail="All episodes failed")

    best_result = max(valid_results, key=lambda x: x.reward)

    return best_result.response
290+
291+
292+
@app.api_route("/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
293+
async def one2many_proxy(request: Request, path: str):
294+
global REQUEST_COUNTER
295+
296+
try:
297+
if request.method == "POST" and path == "chat/completions":
298+
body = await request.json()
299+
request_data = ChatCompletionRequest(**body)
300+
301+
REQUEST_COUNTER += 1
302+
request_id = f"req_{REQUEST_COUNTER}_{uuid.uuid4().hex[:8]}"
303+
304+
logger.info(f"Received chat completion request {request_id}")
305+
306+
if request_data.stream:
307+
response_data = await handle_one2many_request(request_data, request_id)
308+
309+
async def stream_response():
310+
import json
311+
yield f"data: {json.dumps(response_data)}\n\n"
312+
yield "data: [DONE]\n\n"
313+
314+
return StreamingResponse(stream_response(), media_type="text/event-stream")
315+
else:
316+
response_data = await handle_one2many_request(request_data, request_id)
317+
return response_data
318+
else:
319+
raise HTTPException(status_code=404, detail="Not Found")
320+
321+
except httpx.TimeoutException:
322+
logger.error(f"Timeout proxying {request.method} {path}")
323+
raise HTTPException(status_code=504, detail="Gateway Timeout")
324+
325+
except httpx.ConnectError:
326+
logger.error(f"Connection error proxying {request.method} {path}")
327+
raise HTTPException(status_code=502, detail="Bad Gateway")
328+
329+
except Exception as e:
330+
logger.exception(f"Unexpected error proxying {request.method} {path}: {e}")
331+
raise HTTPException(status_code=500, detail="Internal Server Error")
332+
333+
334+
@app.get("/health")
335+
async def health_check():
336+
return {"status": "healthy", "user_preference": USER_PREFERENCE}
337+
338+
339+
@app.get("/requests")
340+
async def get_requests():
341+
return {"requests": USER_REQUEST_RECORD}
342+
343+
344+
if __name__ == "__main__":
345+
import uvicorn
346+
uvicorn.run(app, host="0.0.0.0", port=8000)

0 commit comments

Comments
 (0)