4848 max_prompt_length = 16000 , # at least 16000
4949 max_response_length = 8000 ,
5050 max_model_len = 24000 , # bigger than / equal to `max_prompt_length + max_response_length`
51+ max_response_length_in_one_turn = 4000 ,
5152)
5253
5354class EpisodeResult (BaseModel ):
@@ -102,6 +103,29 @@ async def proxy_chat_completion(base_url: str, api_key: str, request: Request, i
102103 return resp .json ()
103104
104105
106+ def _check_finish_reason_length (response_data : Dict | List [bytes ]) -> bool :
107+ """Return True if any choice has finish_reason='length'."""
108+ if isinstance (response_data , list ):
109+ for raw in response_data :
110+ line = raw .decode () if isinstance (raw , bytes ) else raw
111+ if not line .startswith ("data:" ):
112+ continue
113+ payload = line [len ("data:" ):].strip ()
114+ if payload == "[DONE]" :
115+ break
116+ try :
117+ chunk = json .loads (payload )
118+ finish_reason = chunk .get ("choices" , [{}])[0 ].get ("finish_reason" )
119+ if finish_reason == "length" :
120+ return True
121+ except Exception :
122+ pass
123+ return False
124+ else :
125+ choices = response_data .get ("choices" , [])
126+ return any (c .get ("finish_reason" ) == "length" for c in choices )
127+
128+
105129async def run_single_episode (episode_index : int , request : Request , is_stream : bool ) -> EpisodeResult :
106130 """Run a single episode."""
107131 assert swarm_client is not None
@@ -113,6 +137,18 @@ async def run_single_episode(episode_index: int, request: Request, is_stream: bo
113137 request = request ,
114138 is_stream = is_stream ,
115139 )
140+ if _check_finish_reason_length (response_data ):
141+ raise HTTPException (
142+ status_code = 400 ,
143+ detail = {
144+ "error" : {
145+ "message" : "This model's maximum context length is exceeded. Please reduce the length of the messages." ,
146+ "type" : "invalid_request_error" ,
147+ "param" : "messages" ,
148+ "code" : "context_length_exceeded" ,
149+ }
150+ },
151+ )
116152 return EpisodeResult (episode_uuid = episode_uuid , response = response_data )
117153 except Exception as e :
118154 logger .error (f"Error in episode { episode_index } : { e } " )
@@ -126,7 +162,10 @@ async def run_all_episodes(request: Request, is_stream: bool) -> List[EpisodeRes
126162 results = await asyncio .gather (* episode_tasks , return_exceptions = True )
127163 valid_results : List [EpisodeResult ] = []
128164 for result in results :
129- if isinstance (result , Exception ):
165+ if isinstance (result , HTTPException ) and result .status_code == 400 :
166+ # Propagate context_length_exceeded directly to client
167+ raise result
168+ elif isinstance (result , Exception ):
130169 logger .warning (f"Episode failed: { result } " )
131170 elif isinstance (result , EpisodeResult ):
132171 valid_results .append (result )
@@ -195,29 +234,19 @@ def start_engine_background():
async def one2many_proxy(request: Request, path: str):
    """Main proxy endpoint.

    Only POST /chat/completions is served; every other path gets a 404.
    Intentional HTTP errors raised downstream (e.g. the 400
    context_length_exceeded produced when a response hits the length limit)
    propagate to the client unchanged; unexpected failures are logged with a
    traceback and mapped to a generic 500 so internals are not leaked.

    Args:
        request: The incoming FastAPI request.
        path: The catch-all path segment after the route prefix.

    Returns:
        A StreamingResponse for streamed (list-of-chunks) results, otherwise
        the JSON response dict as-is.

    Raises:
        HTTPException: 404 for unsupported paths, 500 for unexpected errors,
            plus any HTTPException raised by handle_one2many_request.
    """
    global REQUEST_COUNTER
    if not (request.method == "POST" and path == "chat/completions"):
        raise HTTPException(status_code=404, detail="Not Found")
    REQUEST_COUNTER += 1
    request_id = f"req_{REQUEST_COUNTER}_{uuid.uuid4().hex[:8]}"
    logger.info(f"Received chat completion request {request_id}")
    try:
        response_data = await handle_one2many_request(request, request_id)
    except HTTPException:
        # Deliberate HTTP errors (e.g. context_length_exceeded) pass through
        # untouched — do NOT collapse them into a 500.
        raise
    except Exception as e:
        logger.exception(f"Unexpected error proxying {request.method} {path}: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
    if isinstance(response_data, list):
        # A list of raw SSE chunk lines: re-emit as an event stream with the
        # blank-line separator the SSE protocol requires.
        async def stream_chunks(chunks: List[bytes]):
            for chunk in chunks:
                yield chunk + b"\n\n"
        return StreamingResponse(stream_chunks(response_data), media_type="text/event-stream")
    return response_data
221250
222251
223252@app .get ("/health" )
0 commit comments