From bbe5950ed4aafe3355036f5e6d57888ad75c950e Mon Sep 17 00:00:00 2001 From: CI User Date: Mon, 23 Feb 2026 23:57:00 +0800 Subject: [PATCH 01/47] Fix UTF-8 surrogate handling in HTTP JSON requests Fallback to ASCII-escaped JSON encoding when payload contains lone surrogate code units so streaming/non-stream requests no longer crash before dispatch. Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode) Co-authored-by: Sisyphus --- src/httpx_client.py | 46 ++++++++++++++++++++++++++++++++++++++++---- test_httpx_client.py | 13 +++++++++++++ 2 files changed, 55 insertions(+), 4 deletions(-) create mode 100644 test_httpx_client.py diff --git a/src/httpx_client.py b/src/httpx_client.py index 9f0db35b2..95c355f27 100644 --- a/src/httpx_client.py +++ b/src/httpx_client.py @@ -5,6 +5,7 @@ """ from contextlib import asynccontextmanager +import json as jsonlib from typing import Any, AsyncGenerator, Dict, Optional import httpx @@ -16,9 +17,13 @@ class HttpxClientManager: """通用HTTP客户端管理器""" - async def get_client_kwargs(self, timeout: float = 30.0, **kwargs) -> Dict[str, Any]: + async def get_client_kwargs( + self, timeout: Optional[float] = 30.0, **kwargs + ) -> Dict[str, Any]: """获取httpx客户端的通用配置参数""" - client_kwargs = {"timeout": timeout, **kwargs} + client_kwargs = {**kwargs} + if timeout is not None: + client_kwargs["timeout"] = timeout # 动态读取代理配置,支持热更新 current_proxy_config = await get_proxy_config() @@ -39,7 +44,7 @@ async def get_client( @asynccontextmanager async def get_streaming_client( - self, timeout: float = None, **kwargs + self, timeout: Optional[float] = None, **kwargs ) -> AsyncGenerator[httpx.AsyncClient, None]: """获取用于流式请求的HTTP客户端(无超时限制)""" client_kwargs = await self.get_client_kwargs(timeout=timeout, **kwargs) @@ -60,6 +65,24 @@ async def get_streaming_client( http_client = HttpxClientManager() +def _encode_json_body_safely(payload: Any) -> bytes: + try: + return jsonlib.dumps( + payload, + ensure_ascii=False, + separators=(",", ":"), + allow_nan=False, + ).encode("utf-8") + except UnicodeEncodeError: + log.warning("检测到孤立代理项,使用ASCII转义发送JSON请求体") + return jsonlib.dumps( + payload, + ensure_ascii=True, + separators=(",", ":"), + allow_nan=False, + ).encode("utf-8") + + # 通用的异步方法 async def get_async( url: str, headers: Optional[Dict[str, str]] = None, timeout: float = 30.0, **kwargs @@ -79,6 +102,15 @@ async def post_async( ) -> httpx.Response: """通用异步POST请求""" async with http_client.get_client(timeout=timeout, **kwargs) as client: + if json is not None and data is None: + request_headers = dict(headers or {}) + request_headers.setdefault("Content-Type", "application/json") + safe_json_bytes = _encode_json_body_safely(json) + return await client.post( + url, + content=safe_json_bytes, + headers=request_headers, + ) return await client.post(url, data=data, json=json, headers=headers) @@ -91,10 +123,16 @@ async def stream_post_async( ): """流式异步POST请求""" async with http_client.get_streaming_client(**kwargs) as client: - async with client.stream("POST", url, json=body, headers=headers) as r: + request_headers = dict(headers or {}) + request_headers.setdefault("Content-Type", "application/json") + safe_json_bytes = _encode_json_body_safely(body) + async with client.stream( + "POST", url, content=safe_json_bytes, headers=request_headers + ) as r: # 错误直接返回 if r.status_code != 200: from fastapi import Response + yield Response(await r.aread(), r.status_code, dict(r.headers)) return diff --git a/test_httpx_client.py b/test_httpx_client.py new file mode 100644 index 000000000..e5b3b38fc --- /dev/null +++ b/test_httpx_client.py @@ -0,0 +1,13 @@ +from src.httpx_client import _encode_json_body_safely + + +def test_encode_json_body_safely_keeps_valid_unicode(): + payload = {"x": "😀"} + encoded = _encode_json_body_safely(payload) + assert encoded.decode("utf-8") == '{"x":"😀"}' + + +def test_encode_json_body_safely_falls_back_on_lone_surrogate(): + payload = {"x": "\ud83d"} + encoded = _encode_json_body_safely(payload) + assert encoded.decode("utf-8") == '{"x":"\\ud83d"}' From 2882d14271d986c67b752e145c7d1f8070039db2 Mon Sep 17 00:00:00 2001 From: su-kaka <3493227712@qq.com> Date: Tue, 24 Feb 2026 11:01:23 +0800 Subject: [PATCH 02/47] =?UTF-8?q?=E5=8F=8D=E9=87=8D=E5=8A=9B=E5=8F=96?= =?UTF-8?q?=E6=B6=88=E9=99=84=E5=8A=A0=E9=BB=98=E8=AE=A4=E7=B3=BB=E7=BB=9F?= =?UTF-8?q?=E6=8F=90=E7=A4=BA=E8=AF=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- log.py | 20 +++++++++++++++----- src/converter/gemini_fix.py | 3 +++ src/utils.py | 3 ++- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/log.py b/log.py index a43f1e2df..ea944a741 100644 --- a/log.py +++ b/log.py @@ -34,14 +34,17 @@ # ----------------------------------------------------------------- _cached_log_level: int = LOG_LEVELS["info"] _cached_log_file: str = "log.txt" +# ENABLE_LOG=0/false/no/off 时彻底关闭日志 +_log_enabled: bool = True def _refresh_config(): """从环境变量刷新缓存配置(模块加载时及需要时调用)""" - global _cached_log_level, _cached_log_file + global _cached_log_level, _cached_log_file, _log_enabled level = os.getenv("LOG_LEVEL", "info").lower() _cached_log_level = LOG_LEVELS.get(level, LOG_LEVELS["info"]) _cached_log_file = os.getenv("LOG_FILE", "log.txt") + _log_enabled = os.getenv("ENABLE_LOG", "1").strip().lower() not in ("0", "false", "no", "off") def _get_current_log_level() -> int: @@ -109,8 +112,8 @@ def _clear_log_file(): # ----------------------------------------------------------------- # Writer 线程:批量从 deque 取出并写入,减少系统调用次数 # ----------------------------------------------------------------- -_BATCH_SIZE = 200 # 单次最多批量写入条数 -_FLUSH_INTERVAL = 0.5 # 秒:无新消息时强制 flush 周期 +_BATCH_SIZE = 1000 # 单次最多批量写入条数 +_FLUSH_INTERVAL = 2 # 秒:无新消息时强制 flush 周期 def _log_writer_worker(): @@ -223,6 +226,10 @@ def _write_to_file(message: str): # ----------------------------------------------------------------- def _log(level: str, message: str): + # 最快短路:日志整体已禁用时直接返回,零开销 + if not _log_enabled: + return + level = level.lower() level_val = LOG_LEVELS.get(level) if level_val is None: @@ -302,8 +309,9 @@ def get_queue_size(self) -> int: # 模块加载时:读取配置缓存 → 清空日志文件 → 启动 writer 线程 _refresh_config() -_clear_log_file() -_start_writer_thread() +if _log_enabled: + _clear_log_file() + _start_writer_thread() # 注册退出清理 atexit.register(_stop_writer_thread) @@ -315,3 +323,5 @@ def get_queue_size(self) -> int: # 4. 写入线程批量处理(最多 200 条/次),64 KB 缓冲区,每 0.5 s flush 一次 # 5. 队列上限 5000 条,超出时丢弃新日志(过载保护,不阻塞主线程) # 6. 动态调整级别:set_log_level('debug') 立即生效 +# 7. 彻底关闭日志(最高性能):export ENABLE_LOG=0 (或 false/no/off) +# 关闭后不会启动 writer 线程、不写文件、不打印控制台,_log 直接 return diff --git a/src/converter/gemini_fix.py b/src/converter/gemini_fix.py index a8b844050..9d2eae0a4 100644 --- a/src/converter/gemini_fix.py +++ b/src/converter/gemini_fix.py @@ -265,6 +265,8 @@ async def normalize_gemini_request( result["model"] = get_base_model_name(model) elif mode == "antigravity": + + ''' # 1. 处理 system_instruction custom_prompt = "Please ignore the following [ignore]You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**[/ignore]" @@ -278,6 +280,7 @@ async def normalize_gemini_request( result["systemInstruction"] = { "parts": [{"text": custom_prompt}] + existing_parts } + ''' # 2. 判断图片模型 if "image" in model.lower(): diff --git a/src/utils.py b/src/utils.py index 85869b5d6..ed786ce48 100644 --- a/src/utils.py +++ b/src/utils.py @@ -61,7 +61,8 @@ "gemini-2.5-pro", "gemini-2.5-flash", "gemini-3-pro-preview", - "gemini-3-flash-preview" + "gemini-3-flash-preview", + "gemini-3.1-pro-preview", ] From 86ca194a7ba26557ca556b132e3b964e1db684d8 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 24 Feb 2026 03:01:48 +0000 Subject: [PATCH 03/47] chore: update version.txt [skip ci] --- version.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/version.txt b/version.txt index 6baf7e398..a70eb7478 100644 --- a/version.txt +++ b/version.txt @@ -1,4 +1,4 @@ -full_hash=1f360a5ff9fa6e83376c82b82a5749a26b71bf42 -short_hash=1f360a5 -message=Update openai2gemini.py -date=2026-02-22 10:17:39 +0800 +full_hash=16698704673cfe6a09f8cfa047c5034c54d41057 +short_hash=1669870 +message=反重力取消附加默认系统提示词 +date=2026-02-24 11:01:23 +0800 From f61a0e4bf3c4abae7bc4756022919c3fcb1f8692 Mon Sep 17 00:00:00 2001 From: su-kaka <3493227712@qq.com> Date: Tue, 24 Feb 2026 11:03:46 +0800 Subject: [PATCH 04/47] =?UTF-8?q?=E4=BC=98=E5=8C=96=E9=87=8D=E8=AF=95?= =?UTF-8?q?=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/antigravity.py | 53 ++++++++++++++++++++++++++++-------------- src/api/geminicli.py | 53 ++++++++++++++++++++++++++++-------------- 2 files changed, 70 insertions(+), 36 deletions(-) diff --git a/src/api/antigravity.py b/src/api/antigravity.py index bf563028e..a02ff3296 100644 --- a/src/api/antigravity.py +++ b/src/api/antigravity.py @@ -166,6 +166,7 @@ async def refresh_credential_fast(): for attempt in range(max_retries + 1): success_recorded = False # 标记是否已记录成功 need_retry = False # 标记是否需要重试 + keep_current_credential = False # 标记是否保留当前凭证重试(无cd的429) try: async for chunk in stream_post_async( @@ -190,15 +191,7 @@ async def refresh_credential_fast(): if status_code == 429 or status_code == 503 or status_code in DISABLE_ERROR_CODES: log.warning(f"[ANTIGRAVITY STREAM] 流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_body[:500] if error_body else '无'}") - # 并行预热下一个凭证,不阻塞当前处理 - if next_cred_task is None and attempt < max_retries: - next_cred_task = asyncio.create_task( - credential_manager.get_valid_credential( - mode="antigravity", model_name=model_name - ) - ) - - # 记录错误 + # 先解析冷却时间,再决定是否切换凭证 cooldown_until = None if (status_code == 429 or status_code == 503) and error_body: # 使用已缓存的error_body解析冷却时间 @@ -207,6 +200,16 @@ async def refresh_credential_fast(): except Exception: pass + # 对于没有触发cd的429错误,保留当前凭证重试;否则预热下一个凭证 + if status_code == 429 and cooldown_until is None: + keep_current_credential = True + elif next_cred_task is None and attempt < max_retries: + next_cred_task = asyncio.create_task( + credential_manager.get_valid_credential( + mode="antigravity", model_name=model_name + ) + ) + await record_api_call_error( credential_manager, current_file, status_code, cooldown_until, mode="antigravity", model_name=model_name, @@ -284,6 +287,12 @@ async def refresh_credential_fast(): if need_retry: log.info(f"[ANTIGRAVITY STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") + # 对于没有冷却时间的429错误,保留当前凭证重试 + if keep_current_credential: + log.info(f"[ANTIGRAVITY STREAM] 429无冷却时间,保留当前凭证重试: {current_file}") + await asyncio.sleep(retry_interval) + continue + # 使用预热的凭证任务,避免等待 if next_cred_task is not None: try: @@ -503,15 +512,7 @@ async def refresh_credential_fast(): if status_code == 429 or status_code == 503 or status_code in DISABLE_ERROR_CODES: log.warning(f"[ANTIGRAVITY] 非流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_text[:500] if error_text else '无'}") - # 并行预热下一个凭证,不阻塞当前处理 - if next_cred_task is None and attempt < max_retries: - next_cred_task = asyncio.create_task( - credential_manager.get_valid_credential( - mode="antigravity", model_name=model_name - ) - ) - - # 记录错误 + # 先解析冷却时间,再决定是否切换凭证 cooldown_until = None if status_code == 429 or status_code == 503 and error_text: # 使用已缓存的error_text解析冷却时间 @@ -520,6 +521,16 @@ async def refresh_credential_fast(): except Exception: pass + # 对于没有触发cd的429错误,不预热新凭证 + if not (status_code == 429 and cooldown_until is None): + # 并行预热下一个凭证,不阻塞当前处理 + if next_cred_task is None and attempt < max_retries: + next_cred_task = asyncio.create_task( + credential_manager.get_valid_credential( + mode="antigravity", model_name=model_name + ) + ) + await record_api_call_error( credential_manager, current_file, status_code, cooldown_until, mode="antigravity", model_name=model_name, @@ -535,6 +546,12 @@ async def refresh_credential_fast(): if should_retry and attempt < max_retries: need_retry = True + + # 对于没有冷却时间的429错误,保留当前凭证重试 + if status_code == 429 and cooldown_until is None: + log.info(f"[ANTIGRAVITY] 429无冷却时间,保留当前凭证重试: {current_file}") + await asyncio.sleep(retry_interval) + continue else: # 不重试,直接返回原始错误 log.error(f"[ANTIGRAVITY] 达到最大重试次数或不应重试,返回原始错误") diff --git a/src/api/geminicli.py b/src/api/geminicli.py index fc79e5a74..d08da054e 100644 --- a/src/api/geminicli.py +++ b/src/api/geminicli.py @@ -174,6 +174,7 @@ async def refresh_credential_fast(): for attempt in range(max_retries + 1): success_recorded = False # 标记是否已记录成功 need_retry = False # 标记是否需要重试 + keep_current_credential = False # 标记是否保留当前凭证重试(无cd的429) try: async for chunk in stream_post_async( @@ -198,15 +199,7 @@ async def refresh_credential_fast(): if status_code == 429 or status_code == 503 or status_code in DISABLE_ERROR_CODES: log.warning(f"[GEMINICLI STREAM] 流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_body[:500] if error_body else '无'}") - # 并行预热下一个凭证,不阻塞当前处理 - if next_cred_task is None and attempt < max_retries: - next_cred_task = asyncio.create_task( - credential_manager.get_valid_credential( - mode="geminicli", model_name=model_name - ) - ) - - # 记录错误 + # 先解析冷却时间,再决定是否切换凭证 cooldown_until = None if (status_code == 429 or status_code == 503) and error_body: # 使用已缓存的error_body解析冷却时间 @@ -215,6 +208,16 @@ async def refresh_credential_fast(): except Exception: pass + # 对于没有触发cd的429错误,保留当前凭证重试;否则预热下一个凭证 + if status_code == 429 and cooldown_until is None: + keep_current_credential = True + elif next_cred_task is None and attempt < max_retries: + next_cred_task = asyncio.create_task( + credential_manager.get_valid_credential( + mode="geminicli", model_name=model_name + ) + ) + await record_api_call_error( credential_manager, current_file, status_code, cooldown_until, mode="geminicli", model_name=model_name, @@ -303,6 +306,12 @@ async def refresh_credential_fast(): if need_retry: log.info(f"[GEMINICLI STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") + # 对于没有冷却时间的429错误,保留当前凭证重试 + if keep_current_credential: + log.info(f"[GEMINICLI STREAM] 429无冷却时间,保留当前凭证重试: {current_file}") + await asyncio.sleep(retry_interval) + continue + # 使用预热的凭证任务,避免等待 if next_cred_task is not None: try: @@ -492,15 +501,7 @@ async def refresh_credential_fast(): if status_code == 429 or status_code == 503 or status_code in DISABLE_ERROR_CODES: log.warning(f"[NON-STREAM] 非流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_text[:500] if error_text else '无'}") - # 并行预热下一个凭证,不阻塞当前处理 - if next_cred_task is None and attempt < max_retries: - next_cred_task = asyncio.create_task( - credential_manager.get_valid_credential( - mode="geminicli", model_name=model_name - ) - ) - - # 记录错误 + # 先解析冷却时间,再决定是否切换凭证 cooldown_until = None if (status_code == 429 or status_code == 503) and error_text: # 使用已缓存的error_text解析冷却时间 @@ -509,6 +510,16 @@ async def refresh_credential_fast(): except Exception: pass + # 对于没有触发cd的429错误,不预热新凭证 + if not (status_code == 429 and cooldown_until is None): + # 并行预热下一个凭证,不阻塞当前处理 + if next_cred_task is None and attempt < max_retries: + next_cred_task = asyncio.create_task( + credential_manager.get_valid_credential( + mode="geminicli", model_name=model_name + ) + ) + await record_api_call_error( credential_manager, current_file, status_code, cooldown_until, mode="geminicli", model_name=model_name, @@ -526,6 +537,12 @@ async def refresh_credential_fast(): # 重新获取凭证并重试 log.info(f"[NON-STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") + # 对于没有冷却时间的429错误,保留当前凭证重试 + if status_code == 429 and cooldown_until is None: + log.info(f"[NON-STREAM] 429无冷却时间,保留当前凭证重试: {current_file}") + await asyncio.sleep(retry_interval) + continue + # 使用预热的凭证任务,避免等待 if next_cred_task is not None: try: From 83bffcfb90d336d3a0b15567f4b832843b6c5e9e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 24 Feb 2026 03:03:55 +0000 Subject: [PATCH 05/47] chore: update version.txt [skip ci] --- version.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/version.txt b/version.txt index a70eb7478..09784c5e9 100644 --- a/version.txt +++ b/version.txt @@ -1,4 +1,4 @@ -full_hash=16698704673cfe6a09f8cfa047c5034c54d41057 -short_hash=1669870 -message=反重力取消附加默认系统提示词 -date=2026-02-24 11:01:23 +0800 +full_hash=538325d3d1583201fddfd73a4e606ef1c44eb67e +short_hash=538325d +message=优化重试机制 +date=2026-02-24 11:03:46 +0800 From 1cc231d1f4f4d99c5ad21d780e3f54e015045419 Mon Sep 17 00:00:00 2001 From: su-kaka <3493227712@qq.com> Date: Tue, 24 Feb 2026 11:39:35 +0800 Subject: [PATCH 06/47] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=80=A7=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/antigravity.py | 38 ++++++++++++++------------ src/api/geminicli.py | 38 ++++++++++++++------------ src/credential_manager.py | 39 +++++++++++--------------- src/storage/mongodb_manager.py | 42 ++++++++++++++++++++++++++++ src/storage/sqlite_manager.py | 50 ++++++++++++++++++++++++++++++++++ 5 files changed, 150 insertions(+), 57 deletions(-) diff --git a/src/api/antigravity.py b/src/api/antigravity.py index a02ff3296..e2cd9f35a 100644 --- a/src/api/antigravity.py +++ b/src/api/antigravity.py @@ -166,7 +166,7 @@ async def refresh_credential_fast(): for attempt in range(max_retries + 1): success_recorded = False # 标记是否已记录成功 need_retry = False # 标记是否需要重试 - keep_current_credential = False # 标记是否保留当前凭证重试(无cd的429) + keep_current_credential = False # 标记是否保留当前凭证重试(无cd的429/503) try: async for chunk in stream_post_async( @@ -200,8 +200,8 @@ async def refresh_credential_fast(): except Exception: pass - # 对于没有触发cd的429错误,保留当前凭证重试;否则预热下一个凭证 - if status_code == 429 and cooldown_until is None: + # 对于没有触发cd的429/503错误,保留当前凭证重试;否则预热下一个凭证 + if (status_code == 429 or status_code == 503) and cooldown_until is None: keep_current_credential = True elif next_cred_task is None and attempt < max_retries: next_cred_task = asyncio.create_task( @@ -210,11 +210,13 @@ async def refresh_credential_fast(): ) ) - await record_api_call_error( - credential_manager, current_file, status_code, - cooldown_until, mode="antigravity", model_name=model_name, - error_message=error_body - ) + # 无cd的429/503保留当前凭证重试,无需记录错误 + if not keep_current_credential: + await record_api_call_error( + credential_manager, current_file, status_code, + cooldown_until, mode="antigravity", model_name=model_name, + error_message=error_body + ) # 检查是否应该重试 should_retry = await handle_error_with_retry( @@ -289,7 +291,7 @@ async def refresh_credential_fast(): # 对于没有冷却时间的429错误,保留当前凭证重试 if keep_current_credential: - log.info(f"[ANTIGRAVITY STREAM] 429无冷却时间,保留当前凭证重试: {current_file}") + log.info(f"[ANTIGRAVITY STREAM] {status_code}无冷却时间,保留当前凭证重试: {current_file}") await asyncio.sleep(retry_interval) continue @@ -522,7 +524,7 @@ async def refresh_credential_fast(): pass # 对于没有触发cd的429错误,不预热新凭证 - if not (status_code == 429 and cooldown_until is None): + if not ((status_code == 429 or status_code == 503) and cooldown_until is None): # 并行预热下一个凭证,不阻塞当前处理 if next_cred_task is None and attempt < max_retries: next_cred_task = asyncio.create_task( @@ -531,11 +533,13 @@ async def refresh_credential_fast(): ) ) - await record_api_call_error( - credential_manager, current_file, status_code, - cooldown_until, mode="antigravity", model_name=model_name, - error_message=error_text - ) + # 无cd的429/503保留当前凭证重试,无需记录错误 + if not ((status_code == 429 or status_code == 503) and cooldown_until is None): + await record_api_call_error( + credential_manager, current_file, status_code, + cooldown_until, mode="antigravity", model_name=model_name, + error_message=error_text + ) # 检查是否应该重试 should_retry = await handle_error_with_retry( @@ -548,8 +552,8 @@ async def refresh_credential_fast(): need_retry = True # 对于没有冷却时间的429错误,保留当前凭证重试 - if status_code == 429 and cooldown_until is None: - log.info(f"[ANTIGRAVITY] 429无冷却时间,保留当前凭证重试: {current_file}") + if (status_code == 429 or status_code == 503) and cooldown_until is None: + log.info(f"[ANTIGRAVITY] {status_code}无冷却时间,保留当前凭证重试: {current_file}") await asyncio.sleep(retry_interval) continue else: diff --git a/src/api/geminicli.py b/src/api/geminicli.py index d08da054e..434ea6cdf 100644 --- a/src/api/geminicli.py +++ b/src/api/geminicli.py @@ -174,7 +174,7 @@ async def refresh_credential_fast(): for attempt in range(max_retries + 1): success_recorded = False # 标记是否已记录成功 need_retry = False # 标记是否需要重试 - keep_current_credential = False # 标记是否保留当前凭证重试(无cd的429) + keep_current_credential = False # 标记是否保留当前凭证重试(无cd的429/503) try: async for chunk in stream_post_async( @@ -208,8 +208,8 @@ async def refresh_credential_fast(): except Exception: pass - # 对于没有触发cd的429错误,保留当前凭证重试;否则预热下一个凭证 - if status_code == 429 and cooldown_until is None: + # 对于没有触发cd的429/503错误,保留当前凭证重试;否则预热下一个凭证 + if (status_code == 429 or status_code == 503) and cooldown_until is None: keep_current_credential = True elif next_cred_task is None and attempt < max_retries: next_cred_task = asyncio.create_task( @@ -218,11 +218,13 @@ async def refresh_credential_fast(): ) ) - await record_api_call_error( - credential_manager, current_file, status_code, - cooldown_until, mode="geminicli", model_name=model_name, - error_message=error_body - ) + # 无cd的429/503保留当前凭证重试,无需记录错误 + if not keep_current_credential: + await record_api_call_error( + credential_manager, current_file, status_code, + cooldown_until, mode="geminicli", model_name=model_name, + error_message=error_body + ) # 检查是否应该重试 should_retry = await handle_error_with_retry( @@ -308,7 +310,7 @@ async def refresh_credential_fast(): # 对于没有冷却时间的429错误,保留当前凭证重试 if keep_current_credential: - log.info(f"[GEMINICLI STREAM] 429无冷却时间,保留当前凭证重试: {current_file}") + log.info(f"[GEMINICLI STREAM] {status_code}无冷却时间,保留当前凭证重试: {current_file}") await asyncio.sleep(retry_interval) continue @@ -511,7 +513,7 @@ async def refresh_credential_fast(): pass # 对于没有触发cd的429错误,不预热新凭证 - if not (status_code == 429 and cooldown_until is None): + if not ((status_code == 429 or status_code == 503) and cooldown_until is None): # 并行预热下一个凭证,不阻塞当前处理 if next_cred_task is None and attempt < max_retries: next_cred_task = asyncio.create_task( @@ -520,11 +522,13 @@ async def refresh_credential_fast(): ) ) - await record_api_call_error( - credential_manager, current_file, status_code, - cooldown_until, mode="geminicli", model_name=model_name, - error_message=error_text - ) + # 无cd的429/503保留当前凭证重试,无需记录错误 + if not ((status_code == 429 or status_code == 503) and cooldown_until is None): + await record_api_call_error( + credential_manager, current_file, status_code, + cooldown_until, mode="geminicli", model_name=model_name, + error_message=error_text + ) # 检查是否应该重试(会自动处理禁用逻辑) should_retry = await handle_error_with_retry( @@ -538,8 +542,8 @@ async def refresh_credential_fast(): log.info(f"[NON-STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") # 对于没有冷却时间的429错误,保留当前凭证重试 - if status_code == 429 and cooldown_until is None: - log.info(f"[NON-STREAM] 429无冷却时间,保留当前凭证重试: {current_file}") + if (status_code == 429 or status_code == 503) and cooldown_until is None: + log.info(f"[NON-STREAM] {status_code}无冷却时间,保留当前凭证重试: {current_file}") await asyncio.sleep(retry_interval) continue diff --git a/src/credential_manager.py b/src/credential_manager.py index d5c116dfe..861c09418 100644 --- a/src/credential_manager.py +++ b/src/credential_manager.py @@ -2,6 +2,7 @@ 凭证管理器 """ +import asyncio import time from datetime import datetime, timezone from typing import Any, Dict, List, Optional, Tuple @@ -263,34 +264,29 @@ async def record_api_call_result( """ await self._ensure_initialized() try: - state_updates = {} - if success: - state_updates["last_success"] = time.time() - # 清除错误码和错误信息 - state_updates["error_codes"] = [] - state_updates["error_messages"] = [] - - # 如果提供了 model_name,清除该模型的冷却 - if model_name: - if hasattr(self._storage_adapter._backend, 'set_model_cooldown'): - await self._storage_adapter._backend.set_model_cooldown( - credential_name, model_name, None, mode=mode - ) + # 条件写入:仅当凭证有错误状态或模型冷却时才写 DB,零内存缓存 + # fire-and-forget,不阻塞请求链路 + asyncio.create_task( + self._storage_adapter._backend.record_success( + credential_name, model_name=model_name, mode=mode + ) + ) elif error_code: - # 记录错误码和错误信息(覆盖模式) - error_codes = [error_code] - - # 保存错误信息(使用字典覆盖模式,与 panel/creds.py 保持一致) + # 记录错误码和错误信息 error_messages = {} if error_message: error_messages[str(error_code)] = error_message - state_updates["error_codes"] = error_codes - state_updates["error_messages"] = error_messages + state_updates = { + "error_codes": [error_code], + "error_messages": error_messages, + } + + await self.update_credential_state(credential_name, state_updates, mode=mode) - # 如果提供了冷却时间和模型名,设置模型级冷却 + # 设置模型级冷却 if cooldown_until is not None and model_name: if hasattr(self._storage_adapter._backend, 'set_model_cooldown'): await self._storage_adapter._backend.set_model_cooldown( @@ -301,9 +297,6 @@ async def record_api_call_result( f"冷却至: {datetime.fromtimestamp(cooldown_until, timezone.utc).isoformat()}" ) - if state_updates: - await self.update_credential_state(credential_name, state_updates, mode=mode) - except Exception as e: log.error(f"Error recording API call result for {credential_name}: {e}") diff --git a/src/storage/mongodb_manager.py b/src/storage/mongodb_manager.py index cceff0e48..fa5adac66 100644 --- a/src/storage/mongodb_manager.py +++ b/src/storage/mongodb_manager.py @@ -1008,3 +1008,45 @@ async def set_model_cooldown( except Exception as e: log.error(f"Error setting model cooldown for {filename}: {e}") return False + + async def record_success( + self, + filename: str, + model_name: Optional[str] = None, + mode: str = "geminicli" + ) -> None: + """ + 成功调用后的条件写入: + - 只有当前 error_codes 非空时才清除错误并写 last_success + - 只有当前存在该模型的冷却键时才清除 + 通过 MongoDB 服务端条件匹配实现 + """ + self._ensure_initialized() + filename = os.path.basename(filename) + + try: + collection_name = self._get_collection_name(mode) + collection = self._db[collection_name] + now = time.time() + + # 条件写入:只有 error_codes 非空时才触发,避免无意义的写 IO + await collection.update_one( + {"filename": filename, "error_codes": {"$ne": []}}, + {"$set": { + "last_success": now, + "error_codes": [], + "error_messages": {}, + "updated_at": now, + }} + ) + + # 条件删除模型冷却:只有该键存在时才写入 + if model_name: + escaped = self._escape_model_name(model_name) + await collection.update_one( + {"filename": filename, f"model_cooldowns.{escaped}": {"$exists": True}}, + {"$unset": {f"model_cooldowns.{escaped}": ""}, "$set": {"updated_at": now}} + ) + + except Exception as e: + log.error(f"Error recording success for {filename}: {e}") \ No newline at end of file diff --git a/src/storage/sqlite_manager.py b/src/storage/sqlite_manager.py index 65f6104a0..59b4cfd64 100644 --- a/src/storage/sqlite_manager.py +++ b/src/storage/sqlite_manager.py @@ -1270,3 +1270,53 @@ async def set_model_cooldown( except Exception as e: log.error(f"Error setting model cooldown for {filename}: {e}") return False + + async def record_success( + self, + filename: str, + model_name: Optional[str] = None, + mode: str = "geminicli" + ) -> None: + """ + 成功调用后的条件写入: + - 只有当前 error_codes 非空时才清除错误并写 last_success + - 只有当前存在该模型的冷却键时才清除 + 通过 SQL WHERE 条件匹配实现 + """ + self._ensure_initialized() + filename = os.path.basename(filename) + + try: + table_name = self._get_table_name(mode) + async with aiosqlite.connect(self._db_path) as db: + # 条件写入:只有 error_codes 非空时才触发 + await db.execute(f""" + UPDATE {table_name} + SET last_success = unixepoch(), + error_codes = '[]', + error_messages = '{{}}', + updated_at = unixepoch() + WHERE filename = ? + AND (error_codes IS NOT NULL AND error_codes != '[]' AND error_codes != '') + """, (filename,)) + + # 条件删除模型冷却:只有模型键存在时才写入 + if model_name: + async with db.execute(f""" + SELECT model_cooldowns FROM {table_name} WHERE filename = ? + """, (filename,)) as cursor: + row = await cursor.fetchone() + if row: + cooldowns = json.loads(row[0] or '{}') + if model_name in cooldowns: + cooldowns.pop(model_name) + await db.execute(f""" + UPDATE {table_name} + SET model_cooldowns = ?, updated_at = unixepoch() + WHERE filename = ? + """, (json.dumps(cooldowns), filename)) + + await db.commit() + + except Exception as e: + log.error(f"Error recording success for {filename}: {e}") \ No newline at end of file From 7c0b378e193816201c804eb50f97a192fe38794e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 24 Feb 2026 03:39:46 +0000 Subject: [PATCH 07/47] chore: update version.txt [skip ci] --- version.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/version.txt b/version.txt index 09784c5e9..f4cc791d8 100644 --- a/version.txt +++ b/version.txt @@ -1,4 +1,4 @@ -full_hash=538325d3d1583201fddfd73a4e606ef1c44eb67e -short_hash=538325d -message=优化重试机制 -date=2026-02-24 11:03:46 +0800 +full_hash=6cca1630cd98304b5c9d30b8e574348413f56e5a +short_hash=6cca163 +message=优化性能 +date=2026-02-24 11:39:35 +0800 From 461218e81dc773544e31ff714e695c495b9d2650 Mon Sep 17 00:00:00 2001 From: su-kaka <3493227712@qq.com> Date: Tue, 24 Feb 2026 12:01:23 +0800 Subject: [PATCH 08/47] =?UTF-8?q?=E4=BC=98=E5=8C=96mongodb=E6=9F=A5?= =?UTF-8?q?=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/storage/mongodb_manager.py | 70 +++++++++++----------------------- 1 file changed, 23 insertions(+), 47 deletions(-) diff --git a/src/storage/mongodb_manager.py b/src/storage/mongodb_manager.py index fa5adac66..396a21dc5 100644 --- a/src/storage/mongodb_manager.py +++ b/src/storage/mongodb_manager.py @@ -3,6 +3,7 @@ """ import os +import random import time from typing import Any, Dict, List, Optional @@ -205,7 +206,7 @@ async def get_next_available_credential( - 如果模型名包含 "flash": 直接混用所有可用凭证,不区分 preview 状态 - 如果模型名不包含 "preview" 且不包含 "flash": 优先使用 preview=False 的凭证,没有时才使用 preview=True - 对于 antigravity: 不检查 preview 状态 - - 使用聚合管道在数据库层面过滤冷却状态,性能更优 + - 使用 count + random skip + limit(1) 替代 $sample,避免全集合扫描 """ self._ensure_initialized() @@ -214,56 +215,31 @@ async def get_next_available_credential( collection = self._db[collection_name] current_time = time.time() - # 构建聚合管道 - pipeline = [ - # 第一步: 筛选未禁用的凭证 - {"$match": {"disabled": False}}, - ] + # 构建普通查询(避免 $sample 聚合导致全集合扫描) + match_query: Dict[str, Any] = {"disabled": False} - # 如果提供了 model_name,添加冷却检查 + # 冷却检查:直接用 MongoDB 查询表达,无需 $addFields if model_name: - # 转义模型名中的点号 escaped_model_name = self._escape_model_name(model_name) - pipeline.extend([ - # 第二步: 添加冷却状态字段 - { - "$addFields": { - "is_available": { - "$or": [ - # model_cooldowns 中没有该 model_name - {"$not": {"$ifNull": [f"$model_cooldowns.{escaped_model_name}", False]}}, - # 或者冷却时间已过期 - {"$lte": [f"$model_cooldowns.{escaped_model_name}", current_time]} - ] - } - } - }, - # 第三步: 只保留可用的凭证 - {"$match": {"is_available": True}}, - ]) - - # 对于 geminicli 模式,根据模型名的 preview 状态筛选凭证 - if mode == "geminicli" and model_name: - is_preview_model = "preview" in model_name.lower() - - if is_preview_model: - # 模型名包含 preview,只能使用 preview=True 的凭证 - pipeline.append({"$match": {"preview": True}}) - - # 随机抽取一个 - pipeline.append({"$sample": {"size": 1}}) - - # 只投影需要的字段 - pipeline.append({ - "$project": { - "filename": 1, - "credential_data": 1, - "_id": 0 - } - }) + field = f"model_cooldowns.{escaped_model_name}" + match_query["$or"] = [ + {field: {"$exists": False}}, + {field: {"$lte": current_time}}, + ] + + # geminicli preview 筛选 + if mode == "geminicli" and model_name and "preview" in model_name.lower(): + match_query["preview"] = True + + # 统计符合条件的凭证总数(走索引,极快) + count = await collection.count_documents(match_query) + if count == 0: + return None - # 执行聚合 - docs = await collection.aggregate(pipeline).to_list(length=1) + # 随机偏移 + limit(1),替代 $sample,避免全集合随机排序 + skip_n = random.randint(0, count - 1) + projection = {"filename": 1, "credential_data": 1, "_id": 0} + docs = await collection.find(match_query, projection).skip(skip_n).limit(1).to_list(1) if docs: doc = docs[0] From c4f907bb98f5c267e7d7a049e4fc42b38f7ea2bd Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 24 Feb 2026 04:01:34 +0000 Subject: [PATCH 09/47] chore: update version.txt [skip ci] --- version.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/version.txt b/version.txt index f4cc791d8..b7106228a 100644 --- a/version.txt +++ b/version.txt @@ -1,4 +1,4 @@ -full_hash=6cca1630cd98304b5c9d30b8e574348413f56e5a -short_hash=6cca163 -message=优化性能 -date=2026-02-24 11:39:35 +0800 +full_hash=24f4f86a0f136820e54c6e5d5df2a61ea257ed3c +short_hash=24f4f86 +message=优化mongodb查询 +date=2026-02-24 12:01:23 +0800 From 4839cac2d602a3e874c4fbef17b2b9aad672a668 Mon Sep 17 00:00:00 2001 From: su-kaka <3493227712@qq.com> Date: Tue, 24 Feb 2026 20:22:55 +0800 Subject: [PATCH 10/47] =?UTF-8?q?=E6=9B=B4=E6=96=B0redis=E7=BC=93=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements-termux.txt | 3 +- requirements.txt | 1 + src/storage/mongodb_manager.py | 236 ++++++++++++++++++++++++++++++++- 3 files changed, 237 insertions(+), 3 deletions(-) diff --git a/requirements-termux.txt b/requirements-termux.txt index f265e54ba..4e970a6eb 100644 --- a/requirements-termux.txt +++ b/requirements-termux.txt @@ -9,4 +9,5 @@ PyJWT oauthlib motor pypinyin -aiosqlite \ No newline at end of file +aiosqlite +redis \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 2945c0c96..b6d668577 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ oauthlib>=3.3.1 motor>=3.7.1 aiosqlite>=0.20.0 pypinyin>=0.51.0 +redis>=4.2.0 diff --git a/src/storage/mongodb_manager.py b/src/storage/mongodb_manager.py index 396a21dc5..c2e7aa384 100644 --- a/src/storage/mongodb_manager.py +++ b/src/storage/mongodb_manager.py @@ -48,6 +48,10 @@ def __init__(self): self._config_cache: Dict[str, Any] = {} self._config_loaded = False + # Redis 缓存(仅当 REDIS_URL 环境变量存在时启用) + self._redis = None + self._redis_enabled: bool = False + async def initialize(self) -> None: """初始化 MongoDB 连接""" if self._initialized: @@ -75,6 +79,9 @@ async def initialize(self) -> None: self._initialized = True log.info(f"MongoDB storage initialized (database: {database_name})") + # 尝试初始化 Redis(可选) + await self._init_redis() + except Exception as e: log.error(f"Error initializing MongoDB: {e}") raise @@ -162,8 +169,177 @@ async def _load_config_cache(self): log.error(f"Error loading config cache: {e}") self._config_cache = {} + # ============ Redis 缓存(可选,仅当 REDIS_URL 存在时启用)============ + + async def _init_redis(self) -> None: + """初始化 Redis 连接并重建凭证池缓存(若 REDIS_URL 存在)""" + redis_url = os.getenv("REDIS_URL") + if not redis_url: + return + + try: + import redis.asyncio as aioredis # type: ignore + except ImportError: + log.warning("redis package not installed, Redis cache disabled. Run: pip install redis") + return + + try: + self._redis = aioredis.from_url(redis_url, decode_responses=True) + await self._redis.ping() + self._redis_enabled = True + log.info("Redis connected, rebuilding credential pool cache...") + + # 并行重建两个 mode 的缓存 + import asyncio + await asyncio.gather( + self._rebuild_redis_cache("geminicli"), + self._rebuild_redis_cache("antigravity"), + ) + log.info("Redis credential pool cache ready") + except Exception as e: + log.warning(f"Redis init failed, falling back to MongoDB-only mode: {e}") + self._redis = None + self._redis_enabled = False + + # ---- Redis key 工具 ---- + + def _rk_avail(self, mode: str) -> str: + """所有未禁用凭证的 Redis Set key""" + return f"gcli:avail:{mode}" + + def _rk_preview(self, mode: str) -> str: + """未禁用且 preview=True 的凭证 Redis Set key(仅 geminicli)""" + return f"gcli:preview:{mode}" + + def _rk_cd(self, mode: str, filename: str, escaped_model: str) -> str: + """模型冷却 Redis key(带 TTL)""" + return f"gcli:cd:{mode}:{filename}:{escaped_model}" + + # ---- Redis 缓存维护 ---- + + async def _rebuild_redis_cache(self, mode: str) -> None: + """从 MongoDB 重建指定 mode 的 Redis 凭证池缓存""" + if not self._redis: + return + try: + collection = self._db[self._get_collection_name(mode)] + projection: Dict[str, Any] = {"filename": 1, "disabled": 1, "_id": 0} + if mode == "geminicli": + projection["preview"] = 1 + + avail: List[str] = [] + preview: List[str] = [] + async for doc in collection.find({}, projection=projection): + if not doc.get("disabled", False): + avail.append(doc["filename"]) + if mode == "geminicli" and doc.get("preview", True): + preview.append(doc["filename"]) + + pipe = self._redis.pipeline() + pipe.delete(self._rk_avail(mode)) + pipe.delete(self._rk_preview(mode)) + if avail: + pipe.sadd(self._rk_avail(mode), *avail) + if mode == "geminicli" and preview: + pipe.sadd(self._rk_preview(mode), *preview) + await pipe.execute() + log.debug(f"Redis cache rebuilt [{mode}]: {len(avail)} avail, {len(preview)} preview") + except Exception as e: + log.warning(f"Redis rebuild cache error [{mode}]: {e}") + + async def _redis_add_cred(self, mode: str, filename: str, preview: bool = True) -> None: + """将凭证加入 Redis 可用池""" + if not self._redis_enabled: + return + try: + pipe = self._redis.pipeline() + pipe.sadd(self._rk_avail(mode), filename) + if mode == "geminicli" and preview: + pipe.sadd(self._rk_preview(mode), filename) + await pipe.execute() + except Exception as e: + log.warning(f"Redis add_cred error: {e}") + + async def _redis_remove_cred(self, mode: str, filename: str) -> None: + """从 Redis 所有池中移除凭证""" + if not self._redis_enabled: + return + try: + pipe = self._redis.pipeline() + pipe.srem(self._rk_avail(mode), filename) + pipe.srem(self._rk_preview(mode), filename) + await pipe.execute() + except Exception as e: + log.warning(f"Redis remove_cred error: {e}") + + async def _redis_sync_cred(self, mode: str, filename: str, disabled: bool, preview: bool) -> None: + """根据最新状态同步单个凭证在 Redis 中的集合成员""" + if not self._redis_enabled: + return + try: + pipe = self._redis.pipeline() + if disabled: + pipe.srem(self._rk_avail(mode), filename) + pipe.srem(self._rk_preview(mode), filename) + else: + pipe.sadd(self._rk_avail(mode), filename) + if mode == "geminicli": + if preview: + pipe.sadd(self._rk_preview(mode), filename) + else: + pipe.srem(self._rk_preview(mode), filename) + await pipe.execute() + except Exception as e: + log.warning(f"Redis sync_cred error: {e}") + + async def _get_next_available_from_redis( + self, mode: str, model_name: Optional[str] + ) -> Optional[tuple]: + """ + Redis 快速路径:随机取候选凭证,跳过冷却中的,返回 (filename, credential_data)。 + 失败或池为空时返回 None,由调用方降级到 MongoDB。 + """ + try: + # 选择候选池 + if mode == "geminicli" and model_name and "preview" in model_name.lower(): + pool_key = self._rk_preview(mode) + else: + pool_key = self._rk_avail(mode) + + pool_size = await self._redis.scard(pool_key) + if pool_size == 0: + return None + + # 一次取多个随机成员,减少 round-trip + sample_size = min(pool_size, 10) + candidates = await self._redis.srandmember(pool_key, sample_size) + if not candidates: + return None + + # 过滤冷却中的凭证 + if model_name: + escaped = self._escape_model_name(model_name) + for filename in candidates: + cd_key = self._rk_cd(mode, filename, escaped) + if not await self._redis.exists(cd_key): + credential_data = await self.get_credential(filename, mode) + return filename, credential_data + # 所有候选都在冷却中,降级到 MongoDB + return None + else: + filename = candidates[0] + credential_data = await self.get_credential(filename, mode) + return filename, credential_data + except Exception as e: + log.warning(f"Redis get_next_available error: {e}") + return None + async def close(self) -> None: """关闭 MongoDB 连接""" + if self._redis: + await self._redis.aclose() + self._redis = None + self._redis_enabled = False if self._client: self._client.close() self._client = None @@ -206,10 +382,19 @@ async def get_next_available_credential( - 如果模型名包含 "flash": 直接混用所有可用凭证,不区分 preview 状态 - 如果模型名不包含 "preview" 且不包含 "flash": 优先使用 preview=False 的凭证,没有时才使用 preview=True - 对于 antigravity: 不检查 preview 状态 - - 使用 count + random skip + limit(1) 替代 $sample,避免全集合扫描 + - 开启 Redis 时:利用 Redis Set 随机选凭证 + TTL key 判断冷却 + - 未开启 Redis 时:使用 count + random skip + limit(1) """ self._ensure_initialized() + # Redis 快速路径 + if self._redis_enabled: + result = await self._get_next_available_from_redis(mode, model_name) + if result is not None: + return result + # result 为 None 有两种可能:池为空或所有候选都冷却中 + # 后者需降级到 MongoDB 以得到更大的样本空间 + try: collection_name = self._get_collection_name(mode) collection = self._db[collection_name] @@ -337,10 +522,12 @@ async def store_credential(self, filename: str, credential_data: Dict[str, Any], new_credential["preview"] = True await collection.insert_one(new_credential) + # 新凭证插入成功,添加到 Redis 可用池 + await self._redis_add_cred(mode, filename, preview=True) except Exception as insert_error: # 处理并发插入导致的重复键错误 if "duplicate key" in str(insert_error).lower(): - # 重试更新 + # 重试更新(已存在的凭证,无需更新 Redis) await collection.update_one( {"filename": filename}, {"$set": {"credential_data": credential_data, "updated_at": current_ts}} @@ -417,6 +604,8 @@ async def delete_credential(self, filename: str, mode: str = "geminicli") -> boo deleted_count = result.deleted_count if deleted_count > 0: + # 从 Redis 池中移除 + await self._redis_remove_cred(mode, filename) log.debug(f"Deleted {deleted_count} credential(s): {filename} (mode={mode})") return True else: @@ -541,6 +730,33 @@ async def update_credential_state( ) updated_count = result.modified_count + result.matched_count + # 如果 disabled 或 preview 发生变化,同步 Redis 池成员关系 + if self._redis_enabled and ("disabled" in valid_updates or "preview" in valid_updates): + if "disabled" in valid_updates and valid_updates["disabled"]: + # 直接禁用:从两个集合中移除 + await self._redis_remove_cred(mode, filename) + else: + # 启用或修改 preview:需知道最新的 disabled + preview 状态 + # 从 valid_updates 中取,无则向 MongoDB 查一次 + if "disabled" in valid_updates and "preview" in valid_updates: + await self._redis_sync_cred( + mode, filename, + disabled=bool(valid_updates["disabled"]), + preview=bool(valid_updates["preview"]), + ) + else: + # 只知部分信息,查一次 MongoDB 获取完整状态 + snap = await collection.find_one( + {"filename": filename}, + {"disabled": 1, "preview": 1, "_id": 0} + ) + if snap: + await self._redis_sync_cred( + mode, filename, + disabled=bool(snap.get("disabled", False)), + preview=bool(snap.get("preview", True)), + ) + return updated_count > 0 except Exception as e: @@ -978,6 +1194,19 @@ async def set_model_cooldown( log.warning(f"Credential {filename} not found") return False + # 同步写入 Redis TTL key + if self._redis_enabled: + cd_key = self._rk_cd(mode, filename, escaped_model_name) + if cooldown_until is None: + await self._redis.delete(cd_key) + else: + ttl = int(cooldown_until - time.time()) + if ttl > 0: + await self._redis.setex(cd_key, ttl, str(cooldown_until)) + else: + # 冷却已经过期,确保清除 + await self._redis.delete(cd_key) + log.debug(f"Set model cooldown: {filename}, model_name={model_name}, cooldown_until={cooldown_until}") return True @@ -1023,6 +1252,9 @@ async def record_success( {"filename": filename, f"model_cooldowns.{escaped}": {"$exists": True}}, {"$unset": {f"model_cooldowns.{escaped}": ""}, "$set": {"updated_at": now}} ) + # 同步删除 Redis 冷却 key + if self._redis_enabled: + await self._redis.delete(self._rk_cd(mode, filename, escaped)) except Exception as e: log.error(f"Error recording success for {filename}: {e}") \ No newline at end of file From 857c6b879f34a37c7e309c1b551698983e83a6c6 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 24 Feb 2026 12:23:06 +0000 Subject: [PATCH 11/47] chore: update version.txt [skip ci] --- version.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/version.txt b/version.txt index b7106228a..d4cfeef66 100644 --- a/version.txt +++ b/version.txt @@ -1,4 +1,4 @@ -full_hash=24f4f86a0f136820e54c6e5d5df2a61ea257ed3c -short_hash=24f4f86 -message=优化mongodb查询 -date=2026-02-24 12:01:23 +0800 +full_hash=70c6fb4acc3bdd8fc5efbb702a3dceff5d3880c4 +short_hash=70c6fb4 +message=更新redis缓存 +date=2026-02-24 20:22:55 +0800 From 07f82502360b89c783323d075bc78ae2bc54c8a9 Mon Sep 17 00:00:00 2001 From: su-kaka <3493227712@qq.com> Date: Tue, 24 Feb 2026 20:46:42 +0800 Subject: [PATCH 12/47] Update mongodb_manager.py --- src/storage/mongodb_manager.py | 39 +++++++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/src/storage/mongodb_manager.py b/src/storage/mongodb_manager.py index c2e7aa384..2b1e5217f 100644 --- a/src/storage/mongodb_manager.py +++ b/src/storage/mongodb_manager.py @@ -218,7 +218,11 @@ def _rk_cd(self, mode: str, filename: str, escaped_model: str) -> str: # ---- Redis 缓存维护 ---- async def _rebuild_redis_cache(self, mode: str) -> None: - """从 MongoDB 重建指定 mode 的 Redis 凭证池缓存""" + """ + 从 MongoDB 重建指定 mode 的 Redis 凭证池缓存。 + + 使用临时 key + RENAME 原子替换 + """ if not self._redis: return try: @@ -235,14 +239,34 @@ async def _rebuild_redis_cache(self, mode: str) -> None: if mode == "geminicli" and doc.get("preview", True): preview.append(doc["filename"]) + tmp_avail = self._rk_avail(mode) + ":tmp" + tmp_preview = self._rk_preview(mode) + ":tmp" + pipe = self._redis.pipeline() - pipe.delete(self._rk_avail(mode)) - pipe.delete(self._rk_preview(mode)) + # 先写临时 key(此时正式 key 仍完整可用) + pipe.delete(tmp_avail) + pipe.delete(tmp_preview) if avail: - pipe.sadd(self._rk_avail(mode), *avail) + pipe.sadd(tmp_avail, *avail) if mode == "geminicli" and preview: - pipe.sadd(self._rk_preview(mode), *preview) + pipe.sadd(tmp_preview, *preview) await pipe.execute() + + # RENAME 是原子操作:瞬间切换,不存在空窗 + pipe2 = self._redis.pipeline() + if avail: + pipe2.rename(tmp_avail, self._rk_avail(mode)) + else: + pipe2.delete(self._rk_avail(mode)) + pipe2.delete(tmp_avail) + if mode == "geminicli": + if preview: + pipe2.rename(tmp_preview, self._rk_preview(mode)) + else: + pipe2.delete(self._rk_preview(mode)) + pipe2.delete(tmp_preview) + await pipe2.execute() + log.debug(f"Redis cache rebuilt [{mode}]: {len(avail)} avail, {len(preview)} preview") except Exception as e: log.warning(f"Redis rebuild cache error [{mode}]: {e}") @@ -308,6 +332,7 @@ async def _get_next_available_from_redis( pool_size = await self._redis.scard(pool_key) if pool_size == 0: + log.debug(f"[Redis MISS] mode={mode} pool_key={pool_key}: pool empty, fallback to MongoDB") return None # 一次取多个随机成员,减少 round-trip @@ -323,12 +348,15 @@ async def _get_next_available_from_redis( cd_key = self._rk_cd(mode, filename, escaped) if not await self._redis.exists(cd_key): credential_data = await self.get_credential(filename, mode) + log.debug(f"[Redis HIT] mode={mode} model={model_name} -> {filename}") return filename, credential_data # 所有候选都在冷却中,降级到 MongoDB + log.debug(f"[Redis MISS] mode={mode} model={model_name}: all {len(candidates)} candidates in cooldown, fallback to MongoDB") return None else: filename = candidates[0] credential_data = await self.get_credential(filename, mode) + log.debug(f"[Redis HIT] mode={mode} -> {filename}") return filename, credential_data except Exception as e: log.warning(f"Redis get_next_available error: {e}") @@ -394,6 +422,7 @@ async def get_next_available_credential( return result # result 为 None 有两种可能:池为空或所有候选都冷却中 # 后者需降级到 MongoDB 以得到更大的样本空间 + log.debug(f"[MongoDB fallback] mode={mode} model={model_name}") try: collection_name = self._get_collection_name(mode) From b95143799dc6117d3941e440285cdc985dfaa2f2 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 24 Feb 2026 12:46:53 +0000 Subject: [PATCH 13/47] chore: update version.txt [skip ci] --- version.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/version.txt b/version.txt index d4cfeef66..af089ad5b 100644 --- a/version.txt +++ b/version.txt @@ -1,4 +1,4 @@ -full_hash=70c6fb4acc3bdd8fc5efbb702a3dceff5d3880c4 -short_hash=70c6fb4 -message=更新redis缓存 -date=2026-02-24 20:22:55 +0800 +full_hash=1a936049ccde5afcc1d3dc7b4622b32ef8b7ae85 +short_hash=1a93604 +message=Update mongodb_manager.py +date=2026-02-24 20:46:42 +0800 From facccca6305019cc0d3fdb0b85e5b92b382004cd Mon Sep 17 00:00:00 2001 From: su-kaka <3493227712@qq.com> Date: Tue, 24 Feb 2026 20:55:28 +0800 Subject: [PATCH 14/47] Update pyproject.toml --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 7693d8cc3..99702a83a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "python-multipart>=0.0.20", "pypinyin>=0.51.0", "aiosqlite>=0.20.0", + "redis>=7.2.0", ] [project.optional-dependencies] From 3db1545768efd2dba688a079444652a4b200cafa Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 24 Feb 2026 12:55:39 +0000 Subject: [PATCH 15/47] chore: update version.txt [skip ci] --- version.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/version.txt b/version.txt index af089ad5b..ff787ad83 100644 --- a/version.txt +++ b/version.txt @@ -1,4 +1,4 @@ -full_hash=1a936049ccde5afcc1d3dc7b4622b32ef8b7ae85 -short_hash=1a93604 -message=Update mongodb_manager.py -date=2026-02-24 20:46:42 +0800 +full_hash=1c0b17ad3d36831564a712983d01e264680d1cf1 +short_hash=1c0b17a +message=Update pyproject.toml +date=2026-02-24 20:55:28 +0800 From 0f3a4a21ba74fa22c48fce1474f3c0c818de94d4 Mon Sep 17 00:00:00 2001 From: su-kaka <3493227712@qq.com> Date: Tue, 24 Feb 2026 21:01:19 +0800 Subject: [PATCH 16/47] Update .env.example --- .env.example | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.env.example b/.env.example index 5620277d8..1f3b7babb 100644 --- a/.env.example +++ b/.env.example @@ -43,6 +43,13 @@ PASSWORD=pwd # MongoDB 分布式存储模式配置 (第二优先级) # 设置 MONGODB_URI 后自动启用 MongoDB 模式,不再使用本地文件存储 +# Redis 缓存存储配置 +# 设置 REDIS_URL 后自动启用 Redis 模式,性能最佳,可大幅降低 MongoDB 的读写压力 +# 本地 Redis: redis://127.0.0.1:6379/0 +# 带密码: redis://:password@127.0.0.1:6379/0 +# 默认: 无 (不启用 Redis 缓存) +REDIS_URL=redis://127.0.0.1:6379/0 + # MongoDB 连接字符串 (设置后启用 MongoDB 分布式存储模式) # 本地 MongoDB: mongodb://localhost:27017 # 带认证: mongodb://admin:password@localhost:27017/admin From 05956675057050efef1a7a5a3d22746a42d0375c Mon Sep 17 00:00:00 2001 From: su-kaka <3493227712@qq.com> Date: Tue, 24 Feb 2026 21:18:01 +0800 Subject: [PATCH 17/47] Update mongodb_manager.py --- src/storage/mongodb_manager.py | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/src/storage/mongodb_manager.py b/src/storage/mongodb_manager.py index 2b1e5217f..db527592a 100644 --- a/src/storage/mongodb_manager.py +++ b/src/storage/mongodb_manager.py @@ -227,17 +227,31 @@ async def _rebuild_redis_cache(self, mode: str) -> None: return try: collection = self._db[self._get_collection_name(mode)] - projection: Dict[str, Any] = {"filename": 1, "disabled": 1, "_id": 0} + # 同时投影 model_cooldowns,以便重建冷却 TTL Key + projection: Dict[str, Any] = {"filename": 1, "disabled": 1, "model_cooldowns": 1, "_id": 0} if mode == "geminicli": projection["preview"] = 1 avail: List[str] = [] preview: List[str] = [] + cooldown_entries: List[tuple] = [] # (cd_key, ttl_seconds, value) + current_time = time.time() + async for doc in collection.find({}, projection=projection): if not doc.get("disabled", False): - avail.append(doc["filename"]) + filename = doc["filename"] + avail.append(filename) if mode == "geminicli" and doc.get("preview", True): - preview.append(doc["filename"]) + preview.append(filename) + + # 收集未过期的模型冷却,重建 Redis TTL Key + model_cooldowns = doc.get("model_cooldowns") or {} + for escaped_model, cooldown_until in model_cooldowns.items(): + if isinstance(cooldown_until, (int, float)) and cooldown_until > current_time: + ttl = int(cooldown_until - current_time) + if ttl > 0: + cd_key = self._rk_cd(mode, filename, escaped_model) + cooldown_entries.append((cd_key, ttl, str(cooldown_until))) tmp_avail = self._rk_avail(mode) + ":tmp" tmp_preview = self._rk_preview(mode) + ":tmp" @@ -267,7 +281,19 @@ async def _rebuild_redis_cache(self, mode: str) -> None: pipe2.delete(tmp_preview) await pipe2.execute() - log.debug(f"Redis cache rebuilt [{mode}]: {len(avail)} avail, {len(preview)} preview") + # 批量恢复未过期的模型冷却 TTL Key + # 这一步必须在重建池之后执行,否则 Redis 重启后冷却 key 丢失, + # 导致 Redis 快速路径选出仍处于冷却中的凭证 + if cooldown_entries: + pipe3 = self._redis.pipeline() + for cd_key, ttl, value in cooldown_entries: + pipe3.setex(cd_key, ttl, value) + await pipe3.execute() + + log.debug( + f"Redis cache rebuilt [{mode}]: {len(avail)} avail, {len(preview)} preview, " + f"{len(cooldown_entries)} cooldown key(s) restored" + ) except Exception as e: log.warning(f"Redis rebuild cache error [{mode}]: {e}") From b68da131d0b3f85b4a151b2cddfe4aa97d651ab4 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 24 Feb 2026 13:01:28 +0000 Subject: [PATCH 18/47] chore: update version.txt [skip ci] --- version.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/version.txt b/version.txt index ff787ad83..f5deca0d1 100644 --- a/version.txt +++ b/version.txt @@ -1,4 +1,4 @@ -full_hash=1c0b17ad3d36831564a712983d01e264680d1cf1 -short_hash=1c0b17a -message=Update pyproject.toml -date=2026-02-24 20:55:28 +0800 +full_hash=2b504d74da95ee4e291d526328d76c3c0df68270 +short_hash=2b504d7 +message=Update .env.example +date=2026-02-24 21:01:19 +0800 From de10836e6bf23b32ff43812e57252fe9df850745 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 24 Feb 2026 13:18:46 +0000 Subject: [PATCH 19/47] chore: update version.txt [skip ci] --- version.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/version.txt b/version.txt index f5deca0d1..f9765fa56 100644 --- a/version.txt +++ b/version.txt @@ -1,4 +1,4 @@ -full_hash=2b504d74da95ee4e291d526328d76c3c0df68270 -short_hash=2b504d7 -message=Update .env.example -date=2026-02-24 21:01:19 +0800 +full_hash=773bff80df681336fcf913f4eda5810a91f211fe +short_hash=773bff8 +message=Merge branch 'master' of https://github.com/su-kaka/gcli2api +date=2026-02-24 21:18:30 +0800 From 8bc83346dc14eda0028baf557d3b5f5036134809 Mon Sep 17 00:00:00 2001 From: haosenwang1018 Date: Wed, 25 Feb 2026 06:19:47 +0000 Subject: [PATCH 20/47] fix: replace 2 bare excepts with except Exception --- src/converter/openai2gemini.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/converter/openai2gemini.py b/src/converter/openai2gemini.py index 8c496dd4a..7f7187a19 100644 --- a/src/converter/openai2gemini.py +++ b/src/converter/openai2gemini.py @@ -1231,7 +1231,7 @@ def convert_gemini_to_openai_response( return json.loads(str(body)) else: return {"error": str(gemini_response)} - except: + except Exception: return {"error": str(gemini_response)} # 确保是字典格式 @@ -1247,7 +1247,7 @@ def convert_gemini_to_openai_response( gemini_response = json.loads(str(body)) else: gemini_response = json.loads(str(gemini_response)) - except: + except Exception: return {"error": "Invalid response format"} # 处理 GeminiCLI 的 response 包装格式 From 4247d026ce5f60810b8ed6ffc3fa7af882a50026 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 25 Feb 2026 07:34:26 +0000 Subject: [PATCH 21/47] chore: update version.txt [skip ci] --- version.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/version.txt b/version.txt index f9765fa56..f4f8ba25e 100644 --- a/version.txt +++ b/version.txt @@ -1,4 +1,4 @@ -full_hash=773bff80df681336fcf913f4eda5810a91f211fe -short_hash=773bff8 -message=Merge branch 'master' of https://github.com/su-kaka/gcli2api -date=2026-02-24 21:18:30 +0800 +full_hash=7781092553206cce0ae8361862db281f98934ea9 +short_hash=7781092 +message=Merge pull request #340 from haosenwang1018/fix/bare-excepts +date=2026-02-25 15:34:16 +0800 From 76715339073d6dcc575fe30f27489f8579892e35 Mon Sep 17 00:00:00 2001 From: su-kaka <3493227712@qq.com> Date: Wed, 25 Feb 2026 17:34:45 +0800 Subject: [PATCH 22/47] Update anti_truncation.py --- src/converter/anti_truncation.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/converter/anti_truncation.py b/src/converter/anti_truncation.py index 5f7c78db3..b8543023a 100644 --- a/src/converter/anti_truncation.py +++ b/src/converter/anti_truncation.py @@ -240,6 +240,21 @@ async def process_stream(self) -> AsyncGenerator[bytes, None]: yield line continue + # 处理上游生成器 yield 出 Response 对象的情况(错误响应) + from fastapi import Response as FastAPIResponse + if isinstance(line, FastAPIResponse): + log.error(f"Anti-truncation: Received Response object from stream (status={line.status_code}), treating as error") + error_chunk = { + "error": { + "message": line.body.decode('utf-8', errors='ignore') if hasattr(line, 'body') and line.body else "Upstream error", + "type": "api_error", + "code": line.status_code, + } + } + yield f"data: {json.dumps(error_chunk)}\n\n".encode() + yield b"data: [DONE]\n\n" + return + # 处理 bytes 类型的流式数据 if isinstance(line, bytes): # 解码 bytes 为字符串 From a526efaeabdebd0790f73ff60be01972c7e0b97c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 25 Feb 2026 09:34:58 +0000 Subject: [PATCH 23/47] chore: update version.txt [skip ci] --- version.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/version.txt b/version.txt index f4f8ba25e..e83f1c974 100644 --- a/version.txt +++ b/version.txt @@ -1,4 +1,4 @@ -full_hash=7781092553206cce0ae8361862db281f98934ea9 -short_hash=7781092 -message=Merge pull request #340 from haosenwang1018/fix/bare-excepts -date=2026-02-25 15:34:16 +0800 +full_hash=5cb68349f003978398add8e9d364df5eef9b6fd1 +short_hash=5cb6834 +message=Update anti_truncation.py +date=2026-02-25 17:34:45 +0800 From 21e94486c328137c1948fbb61bf4d3d0e1fe2885 Mon Sep 17 00:00:00 2001 From: su-kaka <3493227712@qq.com> Date: Fri, 27 Feb 2026 12:30:17 +0800 Subject: [PATCH 24/47] Update gemini_fix.py --- src/converter/gemini_fix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/converter/gemini_fix.py b/src/converter/gemini_fix.py index 9d2eae0a4..513bc164f 100644 --- a/src/converter/gemini_fix.py +++ b/src/converter/gemini_fix.py @@ -48,7 +48,7 @@ def prepare_image_generation_request( if image_size: image_config["imageSize"] = image_size - request_body["model"] = "gemini-3-pro-image" # 统一使用基础模型名 + request_body["model"] = "gemini-3.1-flash-image" # 统一使用基础模型名 request_body["generationConfig"] = { "candidateCount": 1, "imageConfig": image_config From 77bbd2114d93a787803816131022da9804375d97 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 27 Feb 2026 04:30:29 +0000 Subject: [PATCH 25/47] chore: update version.txt [skip ci] --- version.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/version.txt b/version.txt index e83f1c974..282c7296e 100644 --- a/version.txt +++ b/version.txt @@ -1,4 +1,4 @@ -full_hash=5cb68349f003978398add8e9d364df5eef9b6fd1 -short_hash=5cb6834 -message=Update anti_truncation.py -date=2026-02-25 17:34:45 +0800 +full_hash=e00d7ecabba4be42fe1869af2bb7022b56816508 +short_hash=e00d7ec +message=Update gemini_fix.py +date=2026-02-27 12:30:17 +0800 From ceefc00b6072e68f4375f844c350da7ad77626e9 Mon Sep 17 00:00:00 2001 From: su-kaka <3493227712@qq.com> Date: Fri, 27 Feb 2026 17:11:48 +0800 Subject: [PATCH 26/47] =?UTF-8?q?=E9=85=8D=E7=BD=AE=E9=A1=B9=EF=BC=9A?= =?UTF-8?q?=E9=87=8D=E8=AF=95=E6=97=B6=E6=98=AF=E5=90=A6=E5=88=87=E6=8D=A2?= =?UTF-8?q?=E5=87=AD=E8=AF=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.py | 20 ++++++++++++++++++++ front/common.js | 2 ++ front/control_panel.html | 20 ++++++++++++++++---- front/control_panel_mobile.html | 10 ++++++++++ src/api/antigravity.py | 22 +++++++++++++--------- src/api/geminicli.py | 23 +++++++++++++---------- src/panel/config_routes.py | 5 +++++ 7 files changed, 79 insertions(+), 23 deletions(-) diff --git a/config.py b/config.py index 11114c4df..d776e7a59 100644 --- a/config.py +++ b/config.py @@ -35,6 +35,7 @@ "RETRY_429_MAX_RETRIES": "retry_429_max_retries", "RETRY_429_ENABLED": "retry_429_enabled", "RETRY_429_INTERVAL": "retry_429_interval", + "RETRY_429_KEEP_CREDENTIAL": "retry_429_keep_credential", "ANTI_TRUNCATION_MAX_ATTEMPTS": "anti_truncation_max_attempts", "COMPATIBILITY_MODE": "compatibility_mode_enabled", "RETURN_THOUGHTS_TO_FRONTEND": "return_thoughts_to_frontend", @@ -179,6 +180,25 @@ async def get_retry_429_interval() -> float: return float(await get_config_value("retry_429_interval", 0.1)) +async def get_retry_429_keep_credential() -> bool: + """ + Get 429/503 keep credential setting. + + 控制当 429/503 错误没有冷却时间(无 cd)时,重试是否保持当前凭证。 + - True(默认):保留当前凭证重试,不切换 + - False:切换到下一个可用凭证重试 + + Environment variable: RETRY_429_KEEP_CREDENTIAL + Database config key: retry_429_keep_credential + Default: True + """ + env_value = os.getenv("RETRY_429_KEEP_CREDENTIAL") + if env_value: + return env_value.lower() in ("true", "1", "yes", "on") + + return bool(await get_config_value("retry_429_keep_credential", True)) + + async def get_anti_truncation_max_attempts() -> int: """ Get maximum attempts for anti-truncation continuation. diff --git a/front/common.js b/front/common.js index 871dbc3e0..8f9b6d593 100644 --- a/front/common.js +++ b/front/common.js @@ -2689,6 +2689,7 @@ function populateConfigForm() { document.getElementById('retry429Enabled').checked = Boolean(c.retry_429_enabled); setConfigField('retry429MaxRetries', c.retry_429_max_retries || 20); setConfigField('retry429Interval', c.retry_429_interval || 0.1); + document.getElementById('retry429KeepCredential').checked = Boolean(c.retry_429_keep_credential !== false); document.getElementById('compatibilityModeEnabled').checked = Boolean(c.compatibility_mode_enabled); document.getElementById('returnThoughtsToFrontend').checked = Boolean(c.return_thoughts_to_frontend !== false); @@ -2740,6 +2741,7 @@ async function saveConfig() { retry_429_enabled: getChecked('retry429Enabled'), retry_429_max_retries: getInt('retry429MaxRetries', 20), retry_429_interval: getFloat('retry429Interval', 0.1), + retry_429_keep_credential: getChecked('retry429KeepCredential'), compatibility_mode_enabled: getChecked('compatibilityModeEnabled'), return_thoughts_to_frontend: getChecked('returnThoughtsToFrontend'), antigravity_stream2nostream: getChecked('antigravityStream2nostream'), diff --git a/front/control_panel.html b/front/control_panel.html index 0f0ebe7ea..938458b50 100644 --- a/front/control_panel.html +++ b/front/control_panel.html @@ -1160,6 +1160,7 @@ from { opacity: 0; } + to { opacity: 1; } @@ -1182,6 +1183,7 @@ transform: translateY(-50px); opacity: 0; } + to { transform: translateY(0); opacity: 1; @@ -1336,7 +1338,8 @@

GCLI2API 管理面板

diff --git a/src/api/antigravity.py b/src/api/antigravity.py index e2cd9f35a..63800064f 100644 --- a/src/api/antigravity.py +++ b/src/api/antigravity.py @@ -14,6 +14,7 @@ get_antigravity_api_url, get_antigravity_stream2nostream, get_auto_ban_error_codes, + get_retry_429_keep_credential, ) from log import log @@ -144,6 +145,7 @@ async def stream_request( DISABLE_ERROR_CODES = await get_auto_ban_error_codes() # 禁用凭证的错误码 last_error_response = None # 记录最后一次的错误响应 next_cred_task = None # 预热的下一个凭证任务 + keep_credential_config = await get_retry_429_keep_credential() # 无cd时是否保留凭证 # 内部函数:快速更新凭证(只更新token和project_id,避免重建整个请求) async def refresh_credential_fast(): @@ -200,10 +202,10 @@ async def refresh_credential_fast(): except Exception: pass - # 对于没有触发cd的429/503错误,保留当前凭证重试;否则预热下一个凭证 + # 对于没有触发cd的429/503错误,根据配置决定是否保留当前凭证重试 if (status_code == 429 or status_code == 503) and cooldown_until is None: - keep_current_credential = True - elif next_cred_task is None and attempt < max_retries: + keep_current_credential = keep_credential_config + if not keep_current_credential and next_cred_task is None and attempt < max_retries: next_cred_task = asyncio.create_task( credential_manager.get_valid_credential( mode="antigravity", model_name=model_name @@ -430,6 +432,7 @@ async def non_stream_request( DISABLE_ERROR_CODES = await get_auto_ban_error_codes() # 禁用凭证的错误码 last_error_response = None # 记录最后一次的错误响应 next_cred_task = None # 预热的下一个凭证任务 + keep_credential_config = await get_retry_429_keep_credential() # 无cd时是否保留凭证 # 内部函数:快速更新凭证(只更新token和project_id,避免重建整个请求) async def refresh_credential_fast(): @@ -523,8 +526,9 @@ async def refresh_credential_fast(): except Exception: pass - # 对于没有触发cd的429错误,不预热新凭证 - if not ((status_code == 429 or status_code == 503) and cooldown_until is None): + # 对于没有触发cd且配置为保留凭证时,不预热新凭证 + no_cd_keep = (status_code == 429 or status_code == 503) and cooldown_until is None and keep_credential_config + if not no_cd_keep: # 并行预热下一个凭证,不阻塞当前处理 if next_cred_task is None and attempt < max_retries: next_cred_task = asyncio.create_task( @@ -533,8 +537,8 @@ async def refresh_credential_fast(): ) ) - # 无cd的429/503保留当前凭证重试,无需记录错误 - if not ((status_code == 429 or status_code == 503) and cooldown_until is None): + # 无cd且保留凭证时不记录错误 + if not no_cd_keep: await record_api_call_error( credential_manager, current_file, status_code, cooldown_until, mode="antigravity", model_name=model_name, @@ -551,8 +555,8 @@ async def refresh_credential_fast(): if should_retry and attempt < max_retries: need_retry = True - # 对于没有冷却时间的429错误,保留当前凭证重试 - if (status_code == 429 or status_code == 503) and cooldown_until is None: + # 对于没有冷却时间且配置为保留凭证,保留当前凭证重试 + if no_cd_keep: log.info(f"[ANTIGRAVITY] {status_code}无冷却时间,保留当前凭证重试: {current_file}") await asyncio.sleep(retry_interval) continue diff --git a/src/api/geminicli.py b/src/api/geminicli.py index 434ea6cdf..21984e2c9 100644 --- a/src/api/geminicli.py +++ b/src/api/geminicli.py @@ -18,7 +18,7 @@ from typing import Any, Dict, Optional from fastapi import Response -from config import get_code_assist_endpoint, get_auto_ban_error_codes +from config import get_code_assist_endpoint, get_auto_ban_error_codes, get_retry_429_keep_credential from log import log from src.credential_manager import credential_manager @@ -147,6 +147,7 @@ async def stream_request( DISABLE_ERROR_CODES = await get_auto_ban_error_codes() # 禁用凭证的错误码 last_error_response = None # 记录最后一次的错误响应 next_cred_task = None # 预热的下一个凭证任务 + keep_credential_config = await get_retry_429_keep_credential() # 无cd时是否保留凭证 # 内部函数:快速更新凭证(只更新token和project_id,避免重建整个请求) async def refresh_credential_fast(): @@ -208,10 +209,10 @@ async def refresh_credential_fast(): except Exception: pass - # 对于没有触发cd的429/503错误,保留当前凭证重试;否则预热下一个凭证 + # 对于没有触发cd的429/503错误,根据配置决定是否保留当前凭证重试 if (status_code == 429 or status_code == 503) and cooldown_until is None: - keep_current_credential = True - elif next_cred_task is None and attempt < max_retries: + keep_current_credential = keep_credential_config + if not keep_current_credential and next_cred_task is None and attempt < max_retries: next_cred_task = asyncio.create_task( credential_manager.get_valid_credential( mode="geminicli", model_name=model_name @@ -428,6 +429,7 @@ async def non_stream_request( DISABLE_ERROR_CODES = await get_auto_ban_error_codes() # 禁用凭证的错误码 last_error_response = None # 记录最后一次的错误响应 next_cred_task = None # 预热的下一个凭证任务 + keep_credential_config = await get_retry_429_keep_credential() # 无cd时是否保留凭证 # 内部函数:快速更新凭证(只更新token和project_id,避免重建整个请求) async def refresh_credential_fast(): @@ -512,8 +514,9 @@ async def refresh_credential_fast(): except Exception: pass - # 对于没有触发cd的429错误,不预热新凭证 - if not ((status_code == 429 or status_code == 503) and cooldown_until is None): + # 对于没有触发cd且配置为保留凭证时,不预热新凭证 + no_cd_keep = (status_code == 429 or status_code == 503) and cooldown_until is None and keep_credential_config + if not no_cd_keep: # 并行预热下一个凭证,不阻塞当前处理 if next_cred_task is None and attempt < max_retries: next_cred_task = asyncio.create_task( @@ -522,8 +525,8 @@ async def refresh_credential_fast(): ) ) - # 无cd的429/503保留当前凭证重试,无需记录错误 - if not ((status_code == 429 or status_code == 503) and cooldown_until is None): + # 无cd且保留凭证时不记录错误 + if not no_cd_keep: await record_api_call_error( credential_manager, current_file, status_code, cooldown_until, mode="geminicli", model_name=model_name, @@ -541,8 +544,8 @@ async def refresh_credential_fast(): # 重新获取凭证并重试 log.info(f"[NON-STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") - # 对于没有冷却时间的429错误,保留当前凭证重试 - if (status_code == 429 or status_code == 503) and cooldown_until is None: + # 对于没有冷却时间且配置为保留凭证,保留当前凭证重试 + if no_cd_keep: log.info(f"[NON-STREAM] {status_code}无冷却时间,保留当前凭证重试: {current_file}") await asyncio.sleep(retry_interval) continue diff --git a/src/panel/config_routes.py b/src/panel/config_routes.py index 66916593c..8f82aa9d8 100644 --- a/src/panel/config_routes.py +++ b/src/panel/config_routes.py @@ -48,6 +48,7 @@ async def get_config(token: str = Depends(verify_panel_token)): current_config["retry_429_max_retries"] = await config.get_retry_429_max_retries() current_config["retry_429_enabled"] = await config.get_retry_429_enabled() current_config["retry_429_interval"] = await config.get_retry_429_interval() + current_config["retry_429_keep_credential"] = await config.get_retry_429_keep_credential() # 抗截断配置 current_config["anti_truncation_max_attempts"] = await config.get_anti_truncation_max_attempts() @@ -140,6 +141,10 @@ async def save_config(request: ConfigSaveRequest, token: str = Depends(verify_pa if not isinstance(new_config["antigravity_stream2nostream"], bool): raise HTTPException(status_code=400, detail="Antigravity流式转非流式开关必须是布尔值") + if "retry_429_keep_credential" in new_config: + if not isinstance(new_config["retry_429_keep_credential"], bool): + raise HTTPException(status_code=400, detail="429/503无cd时保持凭证开关必须是布尔值") + # 验证服务器配置 if "host" in new_config: if not isinstance(new_config["host"], str) or not new_config["host"].strip(): From 274ec710b980a4c6f4c4218c0540266324f3788c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 27 Feb 2026 09:11:58 +0000 Subject: [PATCH 27/47] chore: update version.txt [skip ci] --- version.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/version.txt b/version.txt index 282c7296e..95b7fb7bd 100644 --- a/version.txt +++ b/version.txt @@ -1,4 +1,4 @@ -full_hash=e00d7ecabba4be42fe1869af2bb7022b56816508 -short_hash=e00d7ec -message=Update gemini_fix.py -date=2026-02-27 12:30:17 +0800 +full_hash=ee5ea5dd651b2525dfb39f6d621c5f9d9c1e703d +short_hash=ee5ea5d +message=配置项:重试时是否切换凭证 +date=2026-02-27 17:11:48 +0800 From 2e3821625ffe71aff58c38507183025319785168 Mon Sep 17 00:00:00 2001 From: su-kaka <3493227712@qq.com> Date: Fri, 27 Feb 2026 17:13:02 +0800 Subject: [PATCH 28/47] 1 --- front/control_panel.html | 3 ++- front/control_panel_mobile.html | 9 ++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/front/control_panel.html b/front/control_panel.html index 938458b50..857d5430a 100644 --- a/front/control_panel.html +++ b/front/control_panel.html @@ -2059,7 +2059,8 @@

重试配置

启用后,429/503时保留当前凭证继续重试,而非切换到下一个凭证 diff --git a/front/control_panel_mobile.html b/front/control_panel_mobile.html index 32a6d1730..45562041a 100644 --- a/front/control_panel_mobile.html +++ b/front/control_panel_mobile.html @@ -805,6 +805,7 @@ from { opacity: 0; } + to { opacity: 1; } @@ -827,6 +828,7 @@ transform: translateY(-50px); opacity: 0; } + to { transform: translateY(0); opacity: 1; @@ -1788,8 +1790,8 @@

错误重试配置

兼容性配置 ✓ 支持热更新
- 🔄 说明:针对Antigravity模式的优化选项。启用后,即使客户端请求非流式响应,后端也会使用流式API获取数据并收集完整后再返回。 + 🔄 + 说明:针对Antigravity模式的优化选项。启用后,即使客户端请求非流式响应,后端也会使用流式API获取数据并收集完整后再返回。
适用场景:某些情况下流式API比非流式API更稳定,启用此选项可以提高响应质量。
默认:已启用
From fe0e7d893e3ff8b936ff30090b686026f322cb57 Mon Sep 17 00:00:00 2001 From: su-kaka <3493227712@qq.com> Date: Fri, 27 Feb 2026 17:32:00 +0800 Subject: [PATCH 29/47] =?UTF-8?q?render=E4=BF=9D=E6=B4=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.py | 36 ++++++++++++++ front/common.js | 7 ++- front/control_panel.html | 19 +++++++ front/control_panel_mobile.html | 17 +++++++ src/keeplive.py | 88 +++++++++++++++++++++++++++++++++ src/panel/config_routes.py | 25 ++++++++++ web.py | 13 +++++ 7 files changed, 204 insertions(+), 1 deletion(-) create mode 100644 src/keeplive.py diff --git a/config.py b/config.py index d776e7a59..dd2fc020e 100644 --- a/config.py +++ b/config.py @@ -45,6 +45,8 @@ "API_PASSWORD": "api_password", "PANEL_PASSWORD": "panel_password", "PASSWORD": "password", + "KEEPALIVE_URL": "keepalive_url", + "KEEPALIVE_INTERVAL": "keepalive_interval", } @@ -459,3 +461,37 @@ async def get_antigravity_api_url() -> str: "ANTIGRAVITY_API_URL", ) ) + + +async def get_keepalive_url() -> str: + """ + Get keep-alive URL setting. + + 配置后保活服务会定期向该URL发送GET请求。 + 留空表示禁用保活服务。 + + Environment variable: KEEPALIVE_URL + Database config key: keepalive_url + Default: "" (disabled) + """ + return str(await get_config_value("keepalive_url", "", "KEEPALIVE_URL")) + + +async def get_keepalive_interval() -> int: + """ + Get keep-alive interval in seconds. + + 保活请求发送间隔(秒)。 + + Environment variable: KEEPALIVE_INTERVAL + Database config key: keepalive_interval + Default: 60 + """ + env_value = os.getenv("KEEPALIVE_INTERVAL") + if env_value: + try: + return int(env_value) + except ValueError: + pass + + return int(await get_config_value("keepalive_interval", 60)) diff --git a/front/common.js b/front/common.js index 8f9b6d593..d6288c5a4 100644 --- a/front/common.js +++ b/front/common.js @@ -2696,6 +2696,9 @@ function populateConfigForm() { document.getElementById('antigravityStream2nostream').checked = Boolean(c.antigravity_stream2nostream !== false); setConfigField('antiTruncationMaxAttempts', c.anti_truncation_max_attempts || 3); + + setConfigField('keepaliveUrl', c.keepalive_url || ''); + setConfigField('keepaliveInterval', c.keepalive_interval || 60); } function setConfigField(fieldId, value) { @@ -2745,7 +2748,9 @@ async function saveConfig() { compatibility_mode_enabled: getChecked('compatibilityModeEnabled'), return_thoughts_to_frontend: getChecked('returnThoughtsToFrontend'), antigravity_stream2nostream: getChecked('antigravityStream2nostream'), - anti_truncation_max_attempts: getInt('antiTruncationMaxAttempts', 3) + anti_truncation_max_attempts: getInt('antiTruncationMaxAttempts', 3), + keepalive_url: getValue('keepaliveUrl'), + keepalive_interval: getInt('keepaliveInterval', 60) }; const response = await fetch('./config/save', { diff --git a/front/control_panel.html b/front/control_panel.html index 857d5430a..671a80269 100644 --- a/front/control_panel.html +++ b/front/control_panel.html @@ -2138,6 +2138,25 @@

抗截断配置

+
+

保活配置

+ +
+ + + 配置后服务会定期向该 URL 发送 GET 请求以保持在线;留空则禁用保活 ✓ 支持热更新 +
+ +
+ + + 两次保活请求之间的等待时间,范围 5 - 86400 秒,默认 60 ✓ 支持热更新 +
+
+

配置热更新说明

diff --git a/front/control_panel_mobile.html b/front/control_panel_mobile.html index 45562041a..130d4ff41 100644 --- a/front/control_panel_mobile.html +++ b/front/control_panel_mobile.html @@ -1875,6 +1875,23 @@

抗截断配置

+
+

保活配置

+ +
+ + + 配置后服务会定期向该 URL 发送 GET 请求以保持在线;留空则禁用保活 ✓ 支持热更新 +
+ +
+ + + 两次保活请求之间的等待时间,范围 5 - 86400 秒,默认 60 ✓ 支持热更新 +
+
+

配置热更新说明

diff --git a/src/keeplive.py b/src/keeplive.py new file mode 100644 index 000000000..0844d70f7 --- /dev/null +++ b/src/keeplive.py @@ -0,0 +1,88 @@ +""" +保活服务模块 +定期向配置的URL发送GET请求,保持服务在线 +未配置保活URL时不启动任何任务,零资源占用 +""" + +import asyncio +from typing import Optional + +from config import get_keepalive_interval, get_keepalive_url +from log import log +from src.httpx_client import get_async + + +class KeepAliveService: + """保活服务:定期向指定URL发送GET请求""" + + def __init__(self): + self._task: Optional[asyncio.Task] = None + + async def _run(self, url: str, interval: int): + """保活循环,读取到有效URL才会被调用""" + log.info(f"[KeepAlive] 保活任务启动,URL={url},间隔={interval}s") + while True: + try: + response = await get_async(url, timeout=30.0) + log.info(f"[KeepAlive] GET {url} -> {response.status_code}") + except asyncio.CancelledError: + raise + except Exception as e: + log.warning(f"[KeepAlive] GET {url} 失败: {e}") + + try: + await asyncio.sleep(interval) + except asyncio.CancelledError: + raise + + async def start(self): + """ + 启动保活服务。 + 仅当配置了有效的保活URL时才创建后台任务,否则零开销。 + """ + if self._task and not self._task.done(): + # 已有任务在运行,不重复启动 + return + + url = await get_keepalive_url() + interval = await get_keepalive_interval() + + if not url or not url.strip(): + log.debug("[KeepAlive] 未配置保活URL,保活服务不启动") + return + + if interval <= 0: + log.warning(f"[KeepAlive] 保活间隔无效({interval}),保活服务不启动") + return + + self._task = asyncio.create_task( + self._run(url.strip(), interval), name="keepalive_service" + ) + + async def stop(self): + """停止保活服务""" + if self._task and not self._task.done(): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + log.info("[KeepAlive] 保活服务已停止") + self._task = None + + async def restart(self): + """ + 重启保活服务。 + 配置变更时调用,会停止旧任务并根据最新配置决定是否启动新任务。 + """ + await self.stop() + await self.start() + + @property + def is_running(self) -> bool: + """当前保活任务是否在运行""" + return self._task is not None and not self._task.done() + + +# 全局保活服务实例 +keepalive_service = KeepAliveService() diff --git a/src/panel/config_routes.py b/src/panel/config_routes.py index 8f82aa9d8..a565c7bf7 100644 --- a/src/panel/config_routes.py +++ b/src/panel/config_routes.py @@ -9,6 +9,7 @@ import config from log import log +from src.keeplive import keepalive_service from src.models import ConfigSaveRequest from src.storage_adapter import get_storage_adapter from src.utils import verify_panel_token @@ -62,6 +63,10 @@ async def get_config(token: str = Depends(verify_panel_token)): # Antigravity流式转非流式配置 current_config["antigravity_stream2nostream"] = await config.get_antigravity_stream2nostream() + # 保活配置 + current_config["keepalive_url"] = await config.get_keepalive_url() + current_config["keepalive_interval"] = await config.get_keepalive_interval() + # 服务器配置 current_config["host"] = await config.get_server_host() current_config["port"] = await config.get_server_port() @@ -144,7 +149,19 @@ async def save_config(request: ConfigSaveRequest, token: str = Depends(verify_pa if "retry_429_keep_credential" in new_config: if not isinstance(new_config["retry_429_keep_credential"], bool): raise HTTPException(status_code=400, detail="429/503无cd时保持凭证开关必须是布尔值") + # 验证保活配置 + if "keepalive_url" in new_config: + if not isinstance(new_config["keepalive_url"], str): + raise HTTPException(status_code=400, detail="保活URL必须是字符串") + if "keepalive_interval" in new_config: + try: + interval = int(new_config["keepalive_interval"]) + if interval < 5 or interval > 86400: + raise HTTPException(status_code=400, detail="保活间隔必须在 5-86400 秒之间") + new_config["keepalive_interval"] = interval + except (ValueError, TypeError): + raise HTTPException(status_code=400, detail="保活间隔必须是有效整数") # 验证服务器配置 if "host" in new_config: if not isinstance(new_config["host"], str) or not new_config["host"].strip(): @@ -184,6 +201,14 @@ async def save_config(request: ConfigSaveRequest, token: str = Depends(verify_pa # 重新加载配置缓存(关键!) await config.reload_config() + # 如果保活相关配置发生变化,立即重启保活服务 + keepalive_keys = {"keepalive_url", "keepalive_interval"} + if keepalive_keys & set(new_config.keys()): + try: + await keepalive_service.restart() + except Exception as e: + log.warning(f"重启保活服务失败: {e}") + # 验证保存后的结果 test_api_password = await config.get_api_password() test_panel_password = await config.get_panel_password() diff --git a/web.py b/web.py index cc635f353..17b0fced8 100644 --- a/web.py +++ b/web.py @@ -27,6 +27,7 @@ from src.router.geminicli.model_list import router as geminicli_model_list_router from src.task_manager import shutdown_all_tasks from src.panel import router as panel_router +from src.keeplive import keepalive_service # 全局凭证管理器 global_credential_manager = None @@ -59,11 +60,23 @@ async def lifespan(app: FastAPI): # OAuth回调服务器将在需要时按需启动 + # 启动保活服务(未配置URL时自动跳过,零开销) + try: + await keepalive_service.start() + except Exception as e: + log.error(f"保活服务启动失败: {e}") + yield # 清理资源 log.info("开始关闭 GCLI2API 主服务") + # 停止保活服务 + try: + await keepalive_service.stop() + except Exception as e: + log.error(f"关闭保活服务时出错: {e}") + # 首先关闭所有异步任务 try: await shutdown_all_tasks(timeout=10.0) From 53ff29fed2b34d66159c20e9047fa5805ff25b23 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 27 Feb 2026 09:13:13 +0000 Subject: [PATCH 30/47] chore: update version.txt [skip ci] --- version.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/version.txt b/version.txt index 95b7fb7bd..5b99d35d0 100644 --- a/version.txt +++ b/version.txt @@ -1,4 +1,4 @@ -full_hash=ee5ea5dd651b2525dfb39f6d621c5f9d9c1e703d -short_hash=ee5ea5d -message=配置项:重试时是否切换凭证 -date=2026-02-27 17:11:48 +0800 +full_hash=ec3a427311f85361065fa698ed9ebb1f8056e80a +short_hash=ec3a427 +message=1 +date=2026-02-27 17:13:02 +0800 From fedf2cd3a2e3fed146c91ea3e40ed756c66036f4 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 27 Feb 2026 12:35:12 +0000 Subject: [PATCH 31/47] chore: update version.txt [skip ci] --- version.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/version.txt b/version.txt index 5b99d35d0..920c094a6 100644 --- a/version.txt +++ b/version.txt @@ -1,4 +1,4 @@ -full_hash=ec3a427311f85361065fa698ed9ebb1f8056e80a -short_hash=ec3a427 -message=1 -date=2026-02-27 17:13:02 +0800 +full_hash=169accc1f76c046746118bd7ab087774920d6c32 +short_hash=169accc +message=Merge branch 'master' of https://github.com/su-kaka/gcli2api +date=2026-02-27 20:34:56 +0800 From dd5b63083870f058207f6ad07caafd963666094b Mon Sep 17 00:00:00 2001 From: su-kaka <3493227712@qq.com> Date: Fri, 27 Feb 2026 23:04:41 +0800 Subject: [PATCH 32/47] =?UTF-8?q?=E7=A7=BB=E9=99=A4=E9=87=8D=E8=AF=95?= =?UTF-8?q?=E6=97=B6=E4=B8=8D=E5=88=87=E6=8D=A2=E5=87=AD=E8=AF=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.py | 22 +---------- front/common.js | 2 - front/control_panel.html | 9 ----- front/control_panel_mobile.html | 10 ----- src/api/antigravity.py | 69 +++++++++++--------------------- src/api/geminicli.py | 70 +++++++++++---------------------- src/panel/config_routes.py | 7 ---- 7 files changed, 46 insertions(+), 143 deletions(-) diff --git a/config.py b/config.py index dd2fc020e..77b5639ba 100644 --- a/config.py +++ b/config.py @@ -35,7 +35,6 @@ "RETRY_429_MAX_RETRIES": "retry_429_max_retries", "RETRY_429_ENABLED": "retry_429_enabled", "RETRY_429_INTERVAL": "retry_429_interval", - "RETRY_429_KEEP_CREDENTIAL": "retry_429_keep_credential", "ANTI_TRUNCATION_MAX_ATTEMPTS": "anti_truncation_max_attempts", "COMPATIBILITY_MODE": "compatibility_mode_enabled", "RETURN_THOUGHTS_TO_FRONTEND": "return_thoughts_to_frontend", @@ -179,26 +178,7 @@ async def get_retry_429_interval() -> float: except ValueError: pass - return float(await get_config_value("retry_429_interval", 0.1)) - - -async def get_retry_429_keep_credential() -> bool: - """ - Get 429/503 keep credential setting. - - 控制当 429/503 错误没有冷却时间(无 cd)时,重试是否保持当前凭证。 - - True(默认):保留当前凭证重试,不切换 - - False:切换到下一个可用凭证重试 - - Environment variable: RETRY_429_KEEP_CREDENTIAL - Database config key: retry_429_keep_credential - Default: True - """ - env_value = os.getenv("RETRY_429_KEEP_CREDENTIAL") - if env_value: - return env_value.lower() in ("true", "1", "yes", "on") - - return bool(await get_config_value("retry_429_keep_credential", True)) + return float(await get_config_value("retry_429_interval", 1)) async def get_anti_truncation_max_attempts() -> int: diff --git a/front/common.js b/front/common.js index d6288c5a4..16e5017dc 100644 --- a/front/common.js +++ b/front/common.js @@ -2689,7 +2689,6 @@ function populateConfigForm() { document.getElementById('retry429Enabled').checked = Boolean(c.retry_429_enabled); setConfigField('retry429MaxRetries', c.retry_429_max_retries || 20); setConfigField('retry429Interval', c.retry_429_interval || 0.1); - document.getElementById('retry429KeepCredential').checked = Boolean(c.retry_429_keep_credential !== false); document.getElementById('compatibilityModeEnabled').checked = Boolean(c.compatibility_mode_enabled); document.getElementById('returnThoughtsToFrontend').checked = Boolean(c.return_thoughts_to_frontend !== false); @@ -2744,7 +2743,6 @@ async function saveConfig() { retry_429_enabled: getChecked('retry429Enabled'), retry_429_max_retries: getInt('retry429MaxRetries', 20), retry_429_interval: getFloat('retry429Interval', 0.1), - retry_429_keep_credential: getChecked('retry429KeepCredential'), compatibility_mode_enabled: getChecked('compatibilityModeEnabled'), return_thoughts_to_frontend: getChecked('returnThoughtsToFrontend'), antigravity_stream2nostream: getChecked('antigravityStream2nostream'), diff --git a/front/control_panel.html b/front/control_panel.html index 671a80269..0ac18ed90 100644 --- a/front/control_panel.html +++ b/front/control_panel.html @@ -2056,15 +2056,6 @@

重试配置

step="0.01" /> 遇到错误时每两次重试间的等待时间
- -
- - 启用后,429/503时保留当前凭证继续重试,而非切换到下一个凭证 -
diff --git a/front/control_panel_mobile.html b/front/control_panel_mobile.html index 130d4ff41..4360ac995 100644 --- a/front/control_panel_mobile.html +++ b/front/control_panel_mobile.html @@ -1787,16 +1787,6 @@

错误重试配置

遇到错误时每两次重试间的等待时间 - -
- - 启用后,429/503时保留当前凭证继续重试,而非切换到下一个凭证 -
diff --git a/src/api/antigravity.py b/src/api/antigravity.py index 63800064f..b55ba5788 100644 --- a/src/api/antigravity.py +++ b/src/api/antigravity.py @@ -14,7 +14,6 @@ get_antigravity_api_url, get_antigravity_stream2nostream, get_auto_ban_error_codes, - get_retry_429_keep_credential, ) from log import log @@ -145,7 +144,6 @@ async def stream_request( DISABLE_ERROR_CODES = await get_auto_ban_error_codes() # 禁用凭证的错误码 last_error_response = None # 记录最后一次的错误响应 next_cred_task = None # 预热的下一个凭证任务 - keep_credential_config = await get_retry_429_keep_credential() # 无cd时是否保留凭证 # 内部函数:快速更新凭证(只更新token和project_id,避免重建整个请求) async def refresh_credential_fast(): @@ -168,7 +166,6 @@ async def refresh_credential_fast(): for attempt in range(max_retries + 1): success_recorded = False # 标记是否已记录成功 need_retry = False # 标记是否需要重试 - keep_current_credential = False # 标记是否保留当前凭证重试(无cd的429/503) try: async for chunk in stream_post_async( @@ -193,32 +190,28 @@ async def refresh_credential_fast(): if status_code == 429 or status_code == 503 or status_code in DISABLE_ERROR_CODES: log.warning(f"[ANTIGRAVITY STREAM] 流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_body[:500] if error_body else '无'}") - # 先解析冷却时间,再决定是否切换凭证 + # 解析冷却时间 cooldown_until = None if (status_code == 429 or status_code == 503) and error_body: - # 使用已缓存的error_body解析冷却时间 try: cooldown_until = await parse_and_log_cooldown(error_body, mode="antigravity") except Exception: pass - # 对于没有触发cd的429/503错误,根据配置决定是否保留当前凭证重试 - if (status_code == 429 or status_code == 503) and cooldown_until is None: - keep_current_credential = keep_credential_config - if not keep_current_credential and next_cred_task is None and attempt < max_retries: + # 预热下一个凭证 + if next_cred_task is None and attempt < max_retries: next_cred_task = asyncio.create_task( credential_manager.get_valid_credential( mode="antigravity", model_name=model_name ) ) - # 无cd的429/503保留当前凭证重试,无需记录错误 - if not keep_current_credential: - await record_api_call_error( - credential_manager, current_file, status_code, - cooldown_until, mode="antigravity", model_name=model_name, - error_message=error_body - ) + # 记录错误并切换凭证 + await record_api_call_error( + credential_manager, current_file, status_code, + cooldown_until, mode="antigravity", model_name=model_name, + error_message=error_body + ) # 检查是否应该重试 should_retry = await handle_error_with_retry( @@ -291,12 +284,6 @@ async def refresh_credential_fast(): if need_retry: log.info(f"[ANTIGRAVITY STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") - # 对于没有冷却时间的429错误,保留当前凭证重试 - if keep_current_credential: - log.info(f"[ANTIGRAVITY STREAM] {status_code}无冷却时间,保留当前凭证重试: {current_file}") - await asyncio.sleep(retry_interval) - continue - # 使用预热的凭证任务,避免等待 if next_cred_task is not None: try: @@ -432,7 +419,6 @@ async def non_stream_request( DISABLE_ERROR_CODES = await get_auto_ban_error_codes() # 禁用凭证的错误码 last_error_response = None # 记录最后一次的错误响应 next_cred_task = None # 预热的下一个凭证任务 - keep_credential_config = await get_retry_429_keep_credential() # 无cd时是否保留凭证 # 内部函数:快速更新凭证(只更新token和project_id,避免重建整个请求) async def refresh_credential_fast(): @@ -517,34 +503,29 @@ async def refresh_credential_fast(): if status_code == 429 or status_code == 503 or status_code in DISABLE_ERROR_CODES: log.warning(f"[ANTIGRAVITY] 非流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_text[:500] if error_text else '无'}") - # 先解析冷却时间,再决定是否切换凭证 + # 解析冷却时间 cooldown_until = None if status_code == 429 or status_code == 503 and error_text: - # 使用已缓存的error_text解析冷却时间 try: cooldown_until = await parse_and_log_cooldown(error_text, mode="antigravity") except Exception: pass - # 对于没有触发cd且配置为保留凭证时,不预热新凭证 - no_cd_keep = (status_code == 429 or status_code == 503) and cooldown_until is None and keep_credential_config - if not no_cd_keep: - # 并行预热下一个凭证,不阻塞当前处理 - if next_cred_task is None and attempt < max_retries: - next_cred_task = asyncio.create_task( - credential_manager.get_valid_credential( - mode="antigravity", model_name=model_name - ) + # 并行预热下一个凭证,不阻塞当前处理 + if next_cred_task is None and attempt < max_retries: + next_cred_task = asyncio.create_task( + credential_manager.get_valid_credential( + mode="antigravity", model_name=model_name ) - - # 无cd且保留凭证时不记录错误 - if not no_cd_keep: - await record_api_call_error( - credential_manager, current_file, status_code, - cooldown_until, mode="antigravity", model_name=model_name, - error_message=error_text ) + # 记录错误并切换凭证 + await record_api_call_error( + credential_manager, current_file, status_code, + cooldown_until, mode="antigravity", model_name=model_name, + error_message=error_text + ) + # 检查是否应该重试 should_retry = await handle_error_with_retry( credential_manager, status_code, current_file, @@ -554,12 +535,6 @@ async def refresh_credential_fast(): if should_retry and attempt < max_retries: need_retry = True - - # 对于没有冷却时间且配置为保留凭证,保留当前凭证重试 - if no_cd_keep: - log.info(f"[ANTIGRAVITY] {status_code}无冷却时间,保留当前凭证重试: {current_file}") - await asyncio.sleep(retry_interval) - continue else: # 不重试,直接返回原始错误 log.error(f"[ANTIGRAVITY] 达到最大重试次数或不应重试,返回原始错误") diff --git a/src/api/geminicli.py b/src/api/geminicli.py index 21984e2c9..82fa0b3c8 100644 --- a/src/api/geminicli.py +++ b/src/api/geminicli.py @@ -18,7 +18,7 @@ from typing import Any, Dict, Optional from fastapi import Response -from config import get_code_assist_endpoint, get_auto_ban_error_codes, get_retry_429_keep_credential +from config import get_code_assist_endpoint, get_auto_ban_error_codes from log import log from src.credential_manager import credential_manager @@ -147,7 +147,6 @@ async def stream_request( DISABLE_ERROR_CODES = await get_auto_ban_error_codes() # 禁用凭证的错误码 last_error_response = None # 记录最后一次的错误响应 next_cred_task = None # 预热的下一个凭证任务 - keep_credential_config = await get_retry_429_keep_credential() # 无cd时是否保留凭证 # 内部函数:快速更新凭证(只更新token和project_id,避免重建整个请求) async def refresh_credential_fast(): @@ -175,7 +174,6 @@ async def refresh_credential_fast(): for attempt in range(max_retries + 1): success_recorded = False # 标记是否已记录成功 need_retry = False # 标记是否需要重试 - keep_current_credential = False # 标记是否保留当前凭证重试(无cd的429/503) try: async for chunk in stream_post_async( @@ -200,32 +198,28 @@ async def refresh_credential_fast(): if status_code == 429 or status_code == 503 or status_code in DISABLE_ERROR_CODES: log.warning(f"[GEMINICLI STREAM] 流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_body[:500] if error_body else '无'}") - # 先解析冷却时间,再决定是否切换凭证 + # 解析冷却时间 cooldown_until = None if (status_code == 429 or status_code == 503) and error_body: - # 使用已缓存的error_body解析冷却时间 try: cooldown_until = await parse_and_log_cooldown(error_body, mode="geminicli") except Exception: pass - # 对于没有触发cd的429/503错误,根据配置决定是否保留当前凭证重试 - if (status_code == 429 or status_code == 503) and cooldown_until is None: - keep_current_credential = keep_credential_config - if not keep_current_credential and next_cred_task is None and attempt < max_retries: + # 预热下一个凭证 + if next_cred_task is None and attempt < max_retries: next_cred_task = asyncio.create_task( credential_manager.get_valid_credential( mode="geminicli", model_name=model_name ) ) - # 无cd的429/503保留当前凭证重试,无需记录错误 - if not keep_current_credential: - await record_api_call_error( - credential_manager, current_file, status_code, - cooldown_until, mode="geminicli", model_name=model_name, - error_message=error_body - ) + # 记录错误并切换凭证 + await record_api_call_error( + credential_manager, current_file, status_code, + cooldown_until, mode="geminicli", model_name=model_name, + error_message=error_body + ) # 检查是否应该重试 should_retry = await handle_error_with_retry( @@ -309,12 +303,6 @@ async def refresh_credential_fast(): if need_retry: log.info(f"[GEMINICLI STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") - # 对于没有冷却时间的429错误,保留当前凭证重试 - if keep_current_credential: - log.info(f"[GEMINICLI STREAM] {status_code}无冷却时间,保留当前凭证重试: {current_file}") - await asyncio.sleep(retry_interval) - continue - # 使用预热的凭证任务,避免等待 if next_cred_task is not None: try: @@ -429,7 +417,6 @@ async def non_stream_request( DISABLE_ERROR_CODES = await get_auto_ban_error_codes() # 禁用凭证的错误码 last_error_response = None # 记录最后一次的错误响应 next_cred_task = None # 预热的下一个凭证任务 - keep_credential_config = await get_retry_429_keep_credential() # 无cd时是否保留凭证 # 内部函数:快速更新凭证(只更新token和project_id,避免重建整个请求) async def refresh_credential_fast(): @@ -505,34 +492,29 @@ async def refresh_credential_fast(): if status_code == 429 or status_code == 503 or status_code in DISABLE_ERROR_CODES: log.warning(f"[NON-STREAM] 非流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_text[:500] if error_text else '无'}") - # 先解析冷却时间,再决定是否切换凭证 + # 解析冷却时间 cooldown_until = None if (status_code == 429 or status_code == 503) and error_text: - # 使用已缓存的error_text解析冷却时间 try: cooldown_until = await parse_and_log_cooldown(error_text, mode="geminicli") except Exception: pass - # 对于没有触发cd且配置为保留凭证时,不预热新凭证 - no_cd_keep = (status_code == 429 or status_code == 503) and cooldown_until is None and keep_credential_config - if not no_cd_keep: - # 并行预热下一个凭证,不阻塞当前处理 - if next_cred_task is None and attempt < max_retries: - next_cred_task = asyncio.create_task( - credential_manager.get_valid_credential( - mode="geminicli", model_name=model_name - ) + # 并行预热下一个凭证,不阻塞当前处理 + if next_cred_task is None and attempt < max_retries: + next_cred_task = asyncio.create_task( + credential_manager.get_valid_credential( + mode="geminicli", model_name=model_name ) - - # 无cd且保留凭证时不记录错误 - if not no_cd_keep: - await record_api_call_error( - credential_manager, current_file, status_code, - cooldown_until, mode="geminicli", model_name=model_name, - error_message=error_text ) + # 记录错误并切换凭证 + await record_api_call_error( + credential_manager, current_file, status_code, + cooldown_until, mode="geminicli", model_name=model_name, + error_message=error_text + ) + # 检查是否应该重试(会自动处理禁用逻辑) should_retry = await handle_error_with_retry( credential_manager, status_code, current_file, @@ -544,12 +526,6 @@ async def refresh_credential_fast(): # 重新获取凭证并重试 log.info(f"[NON-STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") - # 对于没有冷却时间且配置为保留凭证,保留当前凭证重试 - if no_cd_keep: - log.info(f"[NON-STREAM] {status_code}无冷却时间,保留当前凭证重试: {current_file}") - await asyncio.sleep(retry_interval) - continue - # 使用预热的凭证任务,避免等待 if next_cred_task is not None: try: diff --git a/src/panel/config_routes.py b/src/panel/config_routes.py index a565c7bf7..5c090fa7b 100644 --- a/src/panel/config_routes.py +++ b/src/panel/config_routes.py @@ -2,8 +2,6 @@ 配置路由模块 - 处理 /config/* 相关的HTTP请求 """ -import os - from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import JSONResponse @@ -49,8 +47,6 @@ async def get_config(token: str = Depends(verify_panel_token)): current_config["retry_429_max_retries"] = await config.get_retry_429_max_retries() current_config["retry_429_enabled"] = await config.get_retry_429_enabled() current_config["retry_429_interval"] = await config.get_retry_429_interval() - current_config["retry_429_keep_credential"] = await config.get_retry_429_keep_credential() - # 抗截断配置 current_config["anti_truncation_max_attempts"] = await config.get_anti_truncation_max_attempts() @@ -146,9 +142,6 @@ async def save_config(request: ConfigSaveRequest, token: str = Depends(verify_pa if not isinstance(new_config["antigravity_stream2nostream"], bool): raise HTTPException(status_code=400, detail="Antigravity流式转非流式开关必须是布尔值") - if "retry_429_keep_credential" in new_config: - if not isinstance(new_config["retry_429_keep_credential"], bool): - raise HTTPException(status_code=400, detail="429/503无cd时保持凭证开关必须是布尔值") # 验证保活配置 if "keepalive_url" in new_config: if not isinstance(new_config["keepalive_url"], str): From bc4dbd5a57ac0dd1db249154aaecba61e0b231d1 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 27 Feb 2026 15:04:54 +0000 Subject: [PATCH 33/47] chore: update version.txt [skip ci] --- version.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/version.txt b/version.txt index 920c094a6..a7c3ba429 100644 --- a/version.txt +++ b/version.txt @@ -1,4 +1,4 @@ -full_hash=169accc1f76c046746118bd7ab087774920d6c32 -short_hash=169accc -message=Merge branch 'master' of https://github.com/su-kaka/gcli2api -date=2026-02-27 20:34:56 +0800 +full_hash=209b3503cd21cfeb9e0eda8d802e8ef3651ea567 +short_hash=209b350 +message=移除重试时不切换凭证 +date=2026-02-27 23:04:41 +0800 From fe1081957c9480479527e2fd1dbadaaaeb45301d Mon Sep 17 00:00:00 2001 From: su-kaka <3493227712@qq.com> Date: Fri, 27 Feb 2026 23:32:41 +0800 Subject: [PATCH 34/47] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=B5=81=E5=BC=8F?= =?UTF-8?q?=E4=BC=A0=E8=BE=93=E6=97=B6=E7=9A=84=E6=8A=A5=E9=94=99=E6=98=BE?= =?UTF-8?q?=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/router/antigravity/anthropic.py | 18 +++++++++--------- src/router/antigravity/gemini.py | 10 +++++++--- src/router/antigravity/openai.py | 18 +++++++++--------- src/router/geminicli/anthropic.py | 18 +++++++++--------- src/router/geminicli/gemini.py | 25 +++++++++++++++---------- src/router/geminicli/openai.py | 18 +++++++++--------- 6 files changed, 58 insertions(+), 49 deletions(-) diff --git a/src/router/antigravity/anthropic.py b/src/router/antigravity/anthropic.py index 50e74ac20..aeadf3d05 100644 --- a/src/router/antigravity/anthropic.py +++ b/src/router/antigravity/anthropic.py @@ -193,12 +193,12 @@ async def get_response(): # 错误响应 - 提取错误信息并以SSE格式返回 log.error(f"Fake streaming got error response: status={response.status_code}") - if hasattr(response, "body"): - error_body = response.body.decode() if isinstance(response.body, bytes) else response.body - elif hasattr(response, "content"): - error_body = response.content.decode() if isinstance(response.content, bytes) else response.content - else: - error_body = str(response) + raw = None + if hasattr(response, "body") and response.body: + raw = response.body.decode('utf-8') if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content") and response.content: + raw = response.content.decode('utf-8') if isinstance(response.content, bytes) else response.content + error_body = raw or "" try: error_data = json.loads(error_body) @@ -212,8 +212,7 @@ async def get_response(): yield f"data: {json.dumps(anthropic_error)}\n\n".encode() except Exception: # 如果无法解析为JSON,包装成错误对象 - yield f"data: {json.dumps({'error': error_body})}\n\n".encode() - + yield f"data: {json.dumps({'error': {'code': response.status_code, 'message': error_body or 'upstream error', 'status': 'ERROR'}})}\n\n".encode() yield "data: [DONE]\n\n".encode() return @@ -328,8 +327,8 @@ async def gemini_chunk_wrapper(): # 检查是否是Response对象(错误情况) if isinstance(chunk, Response): # 错误响应,不进行转换,直接传递 - error_content = chunk.body if isinstance(chunk.body, bytes) else chunk.body.encode('utf-8') try: + error_content = chunk.body if isinstance(chunk.body, bytes) else (chunk.body or b'').encode('utf-8') gemini_error = json.loads(error_content.decode('utf-8')) from src.converter.anthropic2gemini import gemini_to_anthropic_response anthropic_error = gemini_to_anthropic_response( @@ -340,6 +339,7 @@ async def gemini_chunk_wrapper(): yield f"data: {json.dumps(anthropic_error)}\n\n".encode('utf-8') except Exception: yield f"data: {json.dumps({'type': 'error', 'error': {'type': 'api_error', 'message': 'Stream error'}})}\n\n".encode('utf-8') + yield b"data: [DONE]\n\n" return else: # 确保是bytes类型 diff --git a/src/router/antigravity/gemini.py b/src/router/antigravity/gemini.py index 8b5713f1f..97e370b69 100644 --- a/src/router/antigravity/gemini.py +++ b/src/router/antigravity/gemini.py @@ -356,10 +356,14 @@ async def normal_stream_generator(): # 检查是否是Response对象(错误情况) if isinstance(chunk, Response): # 将Response转换为SSE格式的错误消息 - error_content = chunk.body if isinstance(chunk.body, bytes) else chunk.body.encode('utf-8') - error_json = json.loads(error_content.decode('utf-8')) - # 以SSE格式返回错误 + try: + error_content = chunk.body if isinstance(chunk.body, bytes) else (chunk.body or b'').encode('utf-8') + error_json = json.loads(error_content.decode('utf-8')) + except Exception: + error_json = {"error": {"code": chunk.status_code, "message": "upstream error", "status": "ERROR"}} + log.error(f"[ANTIGRAVITY STREAM] 返回错误给客户端: status={chunk.status_code}, error={str(error_json)[:200]}") yield f"data: {json.dumps(error_json)}\n\n".encode('utf-8') + yield b"data: [DONE]\n\n" return # 处理SSE格式的chunk diff --git a/src/router/antigravity/openai.py b/src/router/antigravity/openai.py index ac8b017da..76713e355 100644 --- a/src/router/antigravity/openai.py +++ b/src/router/antigravity/openai.py @@ -190,12 +190,12 @@ async def get_response(): # 错误响应 - 提取错误信息并以SSE格式返回 log.error(f"Fake streaming got error response: status={response.status_code}") - if hasattr(response, "body"): - error_body = response.body.decode() if isinstance(response.body, bytes) else response.body - elif hasattr(response, "content"): - error_body = response.content.decode() if isinstance(response.content, bytes) else response.content - else: - error_body = str(response) + raw = None + if hasattr(response, "body") and response.body: + raw = response.body.decode('utf-8') if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content") and response.content: + raw = response.content.decode('utf-8') if isinstance(response.content, bytes) else response.content + error_body = raw or "" try: error_data = json.loads(error_body) @@ -209,8 +209,7 @@ async def get_response(): yield f"data: {json.dumps(openai_error)}\n\n".encode() except Exception: # 如果无法解析为JSON,包装成错误对象 - yield f"data: {json.dumps({'error': error_body})}\n\n".encode() - + yield f"data: {json.dumps({'error': {'code': response.status_code, 'message': error_body or 'upstream error', 'status': 'ERROR'}})}\n\n".encode() yield "data: [DONE]\n\n".encode() return @@ -354,8 +353,8 @@ async def normal_stream_generator(): # 检查是否是Response对象(错误情况) if isinstance(chunk, Response): # 将Response转换为SSE格式的错误消息 - error_content = chunk.body if isinstance(chunk.body, bytes) else chunk.body.encode('utf-8') try: + error_content = chunk.body if isinstance(chunk.body, bytes) else (chunk.body or b'').encode('utf-8') gemini_error = json.loads(error_content.decode('utf-8')) # 转换为 OpenAI 格式错误 from src.converter.openai2gemini import convert_gemini_to_openai_response @@ -367,6 +366,7 @@ async def normal_stream_generator(): yield f"data: {json.dumps(openai_error)}\n\n".encode('utf-8') except Exception: yield f"data: {json.dumps({'error': 'Stream error'})}\n\n".encode('utf-8') + yield b"data: [DONE]\n\n" return else: # 正常的bytes数据,转换为 OpenAI 格式 diff --git a/src/router/geminicli/anthropic.py b/src/router/geminicli/anthropic.py index 33f3391cf..a01ee21b7 100644 --- a/src/router/geminicli/anthropic.py +++ b/src/router/geminicli/anthropic.py @@ -193,12 +193,12 @@ async def get_response(): # 错误响应 - 提取错误信息并以SSE格式返回 log.error(f"Fake streaming got error response: status={response.status_code}") - if hasattr(response, "body"): - error_body = response.body.decode() if isinstance(response.body, bytes) else response.body - elif hasattr(response, "content"): - error_body = response.content.decode() if isinstance(response.content, bytes) else response.content - else: - error_body = str(response) + raw = None + if hasattr(response, "body") and response.body: + raw = response.body.decode('utf-8') if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content") and response.content: + raw = response.content.decode('utf-8') if isinstance(response.content, bytes) else response.content + error_body = raw or "" try: error_data = json.loads(error_body) @@ -212,8 +212,7 @@ async def get_response(): yield f"data: {json.dumps(anthropic_error)}\n\n".encode() except Exception: # 如果无法解析为JSON,包装成错误对象 - yield f"data: {json.dumps({'error': error_body})}\n\n".encode() - + yield f"data: {json.dumps({'error': {'code': response.status_code, 'message': error_body or 'upstream error', 'status': 'ERROR'}})}\n\n".encode() yield "data: [DONE]\n\n".encode() return @@ -328,8 +327,8 @@ async def gemini_chunk_wrapper(): # 检查是否是Response对象(错误情况) if isinstance(chunk, Response): # 错误响应,不进行转换,直接传递 - error_content = chunk.body if isinstance(chunk.body, bytes) else chunk.body.encode('utf-8') try: + error_content = chunk.body if isinstance(chunk.body, bytes) else (chunk.body or b'').encode('utf-8') gemini_error = json.loads(error_content.decode('utf-8')) from src.converter.anthropic2gemini import gemini_to_anthropic_response anthropic_error = gemini_to_anthropic_response( @@ -340,6 +339,7 @@ async def gemini_chunk_wrapper(): yield f"data: {json.dumps(anthropic_error)}\n\n".encode('utf-8') except Exception: yield f"data: {json.dumps({'type': 'error', 'error': {'type': 'api_error', 'message': 'Stream error'}})}\n\n".encode('utf-8') + yield b"data: [DONE]\n\n" return else: # 确保是bytes类型 diff --git a/src/router/geminicli/gemini.py b/src/router/geminicli/gemini.py index 0b2ef9238..7957ce5ea 100644 --- a/src/router/geminicli/gemini.py +++ b/src/router/geminicli/gemini.py @@ -203,12 +203,12 @@ async def get_response(): # 错误响应 - 提取错误信息并以SSE格式返回 log.error(f"Fake streaming got error response: status={response.status_code}") - if hasattr(response, "body"): - error_body = response.body.decode() if isinstance(response.body, bytes) else response.body - elif hasattr(response, "content"): - error_body = response.content.decode() if isinstance(response.content, bytes) else response.content - else: - error_body = str(response) + raw = None + if hasattr(response, "body") and response.body: + raw = response.body.decode('utf-8') if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content") and response.content: + raw = response.content.decode('utf-8') if isinstance(response.content, bytes) else response.content + error_body = raw or "" try: error_data = json.loads(error_body) @@ -216,7 +216,7 @@ async def get_response(): yield f"data: {json.dumps(error_data)}\n\n".encode() except Exception: # 如果无法解析为JSON,包装成错误对象 - yield f"data: {json.dumps({'error': error_body})}\n\n".encode() + yield f"data: {json.dumps({'error': {'code': response.status_code, 'message': error_body or 'upstream error', 'status': 'ERROR'}})}\n\n".encode() yield "data: [DONE]\n\n".encode() return @@ -355,10 +355,15 @@ async def normal_stream_generator(): # 检查是否是Response对象(错误情况) if isinstance(chunk, Response): # 将Response转换为SSE格式的错误消息 - error_content = chunk.body if isinstance(chunk.body, bytes) else chunk.body.encode('utf-8') - error_json = json.loads(error_content.decode('utf-8')) - # 以SSE格式返回错误 + try: + error_content = chunk.body if isinstance(chunk.body, bytes) else (chunk.body or b'').encode('utf-8') + error_json = json.loads(error_content.decode('utf-8')) + except Exception: + error_json = {"error": {"code": chunk.status_code, "message": "upstream error", "status": "ERROR"}} + log.error(f"[GEMINICLI STREAM] 返回错误给客户端: status={chunk.status_code}, error={str(error_json)[:200]}") + # 以SSE格式返回错误,并以[DONE]结束 yield f"data: {json.dumps(error_json)}\n\n".encode('utf-8') + yield b"data: [DONE]\n\n" return # 处理SSE格式的chunk diff --git a/src/router/geminicli/openai.py b/src/router/geminicli/openai.py index df4fa1895..3d58b4891 100644 --- a/src/router/geminicli/openai.py +++ b/src/router/geminicli/openai.py @@ -190,12 +190,12 @@ async def get_response(): # 错误响应 - 提取错误信息并以SSE格式返回 log.error(f"Fake streaming got error response: status={response.status_code}") - if hasattr(response, "body"): - error_body = response.body.decode() if isinstance(response.body, bytes) else response.body - elif hasattr(response, "content"): - error_body = response.content.decode() if isinstance(response.content, bytes) else response.content - else: - error_body = str(response) + raw = None + if hasattr(response, "body") and response.body: + raw = response.body.decode('utf-8') if isinstance(response.body, bytes) else response.body + elif hasattr(response, "content") and response.content: + raw = response.content.decode('utf-8') if isinstance(response.content, bytes) else response.content + error_body = raw or "" try: error_data = json.loads(error_body) @@ -209,8 +209,7 @@ async def get_response(): yield f"data: {json.dumps(openai_error)}\n\n".encode() except Exception: # 如果无法解析为JSON,包装成错误对象 - yield f"data: {json.dumps({'error': error_body})}\n\n".encode() - + yield f"data: {json.dumps({'error': {'code': response.status_code, 'message': error_body or 'upstream error', 'status': 'ERROR'}})}\n\n".encode() yield "data: [DONE]\n\n".encode() return @@ -354,8 +353,8 @@ async def normal_stream_generator(): # 检查是否是Response对象(错误情况) if isinstance(chunk, Response): # 将Response转换为SSE格式的错误消息 - error_content = chunk.body if isinstance(chunk.body, bytes) else chunk.body.encode('utf-8') try: + error_content = chunk.body if isinstance(chunk.body, bytes) else (chunk.body or b'').encode('utf-8') gemini_error = json.loads(error_content.decode('utf-8')) # 转换为 OpenAI 格式错误 from src.converter.openai2gemini import convert_gemini_to_openai_response @@ -367,6 +366,7 @@ async def normal_stream_generator(): yield f"data: {json.dumps(openai_error)}\n\n".encode('utf-8') except Exception: yield f"data: {json.dumps({'error': 'Stream error'})}\n\n".encode('utf-8') + yield b"data: [DONE]\n\n" return else: # 正常的bytes数据,转换为 OpenAI 格式 From 3af729ee3c952e5b59afd537036434cbf42a45fd Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 27 Feb 2026 15:33:04 +0000 Subject: [PATCH 35/47] chore: update version.txt [skip ci] --- version.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/version.txt b/version.txt index a7c3ba429..0307d9bc4 100644 --- a/version.txt +++ b/version.txt @@ -1,4 +1,4 @@ -full_hash=209b3503cd21cfeb9e0eda8d802e8ef3651ea567 -short_hash=209b350 -message=移除重试时不切换凭证 -date=2026-02-27 23:04:41 +0800 +full_hash=3a2aec4900e6e2139a2156e3cb50ab351e5fb7ab +short_hash=3a2aec4 +message=优化流式传输时的报错显示 +date=2026-02-27 23:32:41 +0800 From f67c1e12532b416e642de26c7da547c138bc1955 Mon Sep 17 00:00:00 2001 From: afu6609 Date: Sat, 28 Feb 2026 11:34:44 +0800 Subject: [PATCH 36/47] fix: add patternProperties/dependencies/propertyNames to _clean_schema_for_claude unsupported_keys Google's Gemini API uses a protobuf-based schema parser that only accepts a whitelist of JSON Schema fields. When the model name contains 'claude', tools are cleaned via _clean_schema_for_claude() instead of _clean_schema_for_gemini(). However, _clean_schema_for_claude() was missing patternProperties, dependencies, and propertyNames from its unsupported_keys set, causing Google API to reject requests with: Invalid JSON payload received. Unknown name "patternProperties" This affects any OpenAI-compatible client (e.g. OpenClaw) that includes patternProperties in tool parameter schemas when routed through the Claude/Antigravity path. Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/reference/rest/v1/Schema --- src/converter/openai2gemini.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/converter/openai2gemini.py b/src/converter/openai2gemini.py index 7f7187a19..739b4bfee 100644 --- a/src/converter/openai2gemini.py +++ b/src/converter/openai2gemini.py @@ -285,6 +285,7 @@ def _clean_schema_for_claude(schema: Any, root_schema: Optional[Dict[str, Any]] "const", # const 可能导致问题 "contentEncoding", "contentMediaType", "oneOf", # oneOf 可能导致问题,用 anyOf 替代 + "patternProperties", "dependencies", "propertyNames", # Google API 不支持 } for key in list(result.keys()): From 85a0188f8dc8dae9786bd0170ab3ee953ead8fce Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sat, 28 Feb 2026 07:39:57 +0000 Subject: [PATCH 37/47] chore: update version.txt [skip ci] --- version.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/version.txt b/version.txt index 0307d9bc4..b4fe1f1d4 100644 --- a/version.txt +++ b/version.txt @@ -1,4 +1,4 @@ -full_hash=3a2aec4900e6e2139a2156e3cb50ab351e5fb7ab -short_hash=3a2aec4 -message=优化流式传输时的报错显示 -date=2026-02-27 23:32:41 +0800 +full_hash=43e82b5fbd4930dcc96cd02205b032df16cf70ec +short_hash=43e82b5 +message=Merge pull request #347 from afu6609/fix/clean-schema-claude-patternProperties +date=2026-02-28 15:39:48 +0800 From 400d1c4df85b62304483f4a7dd45f531460fc9a3 Mon Sep 17 00:00:00 2001 From: CI User Date: Thu, 26 Feb 2026 01:56:37 +0800 Subject: [PATCH 38/47] fix: enable HTTP/2, handle MODEL_CAPACITY_EXHAUSTED with exponential backoff, raise retry interval default - Enable HTTP/2 in httpx client to match Google cloudcode-pa endpoint expectations (fixes intermittent 'All connection attempts failed' errors) - Detect MODEL_CAPACITY_EXHAUSTED errors in 429 responses and apply exponential backoff with jitter (capped at 12s) instead of flat retry interval - When capacity is exhausted, rotate to next credential instead of hammering the same one (prevents thundering herd amplification) - Raise default RETRY_429_INTERVAL from 0.1s to 1.0s - Improve exception handler logging with type/repr for better diagnostics - Fix CredentialManager type hints in utils.py (Any instead of concrete class) --- config.py | 19 +- src/api/geminicli.py | 483 ++++++++++++++++++++++++++++++++----------- src/api/utils.py | 185 ++++++++++------- src/httpx_client.py | 7 +- 4 files changed, 503 insertions(+), 191 deletions(-) diff --git a/config.py b/config.py index 77b5639ba..fd53b3787 100644 --- a/config.py +++ b/config.py @@ -51,6 +51,7 @@ # ====================== 配置系统 ====================== + async def init_config(): """初始化配置缓存(启动时调用一次)""" global _config_cache, _config_initialized @@ -60,6 +61,7 @@ async def init_config(): try: from src.storage_adapter import get_storage_adapter + storage_adapter = await get_storage_adapter() _config_cache = await storage_adapter.get_all_config() _config_initialized = True @@ -75,10 +77,11 @@ async def reload_config(): try: from src.storage_adapter import get_storage_adapter + storage_adapter = await get_storage_adapter() # 如果后端支持 reload_config_cache,调用它 - if hasattr(storage_adapter._backend, 'reload_config_cache'): + if hasattr(storage_adapter._backend, "reload_config_cache"): await storage_adapter._backend.reload_config_cache() # 重新加载配置缓存 @@ -93,7 +96,9 @@ def _get_cached_config(key: str, default: Any = None) -> Any: return _config_cache.get(key, default) -async def get_config_value(key: str, default: Any = None, env_var: Optional[str] = None) -> Any: +async def get_config_value( + key: str, default: Any = None, env_var: Optional[str] = None +) -> Any: """Get configuration value with priority: ENV > Storage > default.""" # 确保配置已初始化 if not _config_initialized: @@ -178,7 +183,7 @@ async def get_retry_429_interval() -> float: except ValueError: pass - return float(await get_config_value("retry_429_interval", 1)) + return float(await get_config_value("retry_429_interval", 1.0)) async def get_anti_truncation_max_attempts() -> int: @@ -295,7 +300,9 @@ async def get_code_assist_endpoint() -> str: """ return str( await get_config_value( - "code_assist_endpoint", "https://cloudcode-pa.googleapis.com", "CODE_ASSIST_ENDPOINT" + "code_assist_endpoint", + "https://cloudcode-pa.googleapis.com", + "CODE_ASSIST_ENDPOINT", ) ) @@ -419,7 +426,9 @@ async def get_service_usage_api_url() -> str: """ return str( await get_config_value( - "service_usage_api_url", "https://serviceusage.googleapis.com", "SERVICE_USAGE_API_URL" + "service_usage_api_url", + "https://serviceusage.googleapis.com", + "SERVICE_USAGE_API_URL", ) ) diff --git a/src/api/geminicli.py b/src/api/geminicli.py index 82fa0b3c8..e0c56d46c 100644 --- a/src/api/geminicli.py +++ b/src/api/geminicli.py @@ -15,6 +15,8 @@ import asyncio import json +import random +import time from typing import Any, Dict, Optional from fastapi import Response @@ -34,6 +36,59 @@ ) from src.utils import GEMINICLI_USER_AGENT + +def _extract_gemini_error_info( + error_text: Optional[str], +) -> tuple[Optional[str], Optional[str]]: + """提取 Gemini 错误状态与 ErrorInfo.reason,兼容非标准错误结构。""" + if not error_text: + return None, None + try: + error_data = json.loads(error_text) + except Exception: + return None, None + + error_obj = error_data.get("error", {}) if isinstance(error_data, dict) else {} + error_status = error_obj.get("status") if isinstance(error_obj, dict) else None + details = error_obj.get("details", []) if isinstance(error_obj, dict) else [] + error_message = error_obj.get("message") if isinstance(error_obj, dict) else None + + # 兼容非标准错误结构(例如仅返回 message 文本) + top_message = error_data.get("message") if isinstance(error_data, dict) else None + message_candidates = [ + error_message if isinstance(error_message, str) else None, + top_message if isinstance(top_message, str) else None, + ] + merged_message = " ".join([m for m in message_candidates if m]).lower() + + reason = None + if isinstance(details, list): + for detail in details: + if not isinstance(detail, dict): + continue + if detail.get("@type") == "type.googleapis.com/google.rpc.ErrorInfo": + reason = detail.get("reason") + if reason: + break + + # 文本兜底:识别容量耗尽的非标准消息 + if not reason and merged_message: + if ( + "no capacity available" in merged_message + or "model is overloaded" in merged_message + ): + reason = "MODEL_CAPACITY_EXHAUSTED" + + return error_status, reason + + +def _compute_capacity_retry_delay(base_interval: float, attempt: int) -> float: + """容量耗尽时的退避等待(指数退避 + 抖动),上限 12 秒。""" + exp_backoff = base_interval * (2**attempt) + jitter = random.uniform(0.0, max(0.2, base_interval)) + return min(12.0, exp_backoff + jitter) + + # ==================== 全局凭证管理器 ==================== # 使用全局单例 credential_manager,自动初始化 @@ -41,20 +96,21 @@ # ==================== 请求准备 ==================== + async def prepare_request_headers_and_payload( payload: dict, credential_data: dict, target_url: str ): """ 从凭证数据准备请求头和最终payload - + Args: payload: 原始请求payload credential_data: 凭证数据字典 target_url: 目标URL - + Returns: 元组: (headers, final_payload, target_url) - + Raises: Exception: 如果凭证中缺少必要字段 """ @@ -84,6 +140,7 @@ async def prepare_request_headers_and_payload( # ==================== 新的流式和非流式请求函数 ==================== + async def stream_request( body: Dict[str, Any], native: bool = False, @@ -113,7 +170,7 @@ async def stream_request( yield Response( content=json.dumps({"error": "当前无可用凭证"}), status_code=500, - media_type="application/json" + media_type="application/json", ) return @@ -121,9 +178,14 @@ async def stream_request( # 2. 构建URL和请求头 try: - auth_headers, final_payload, target_url = await prepare_request_headers_and_payload( - body, credential_data, - f"{await get_code_assist_endpoint()}/v1internal:streamGenerateContent?alt=sse" + ( + auth_headers, + final_payload, + target_url, + ) = await prepare_request_headers_and_payload( + body, + credential_data, + f"{await get_code_assist_endpoint()}/v1internal:streamGenerateContent?alt=sse", ) # 合并自定义headers @@ -135,7 +197,7 @@ async def stream_request( yield Response( content=json.dumps({"error": f"准备请求失败: {str(e)}"}), status_code=500, - media_type="application/json" + media_type="application/json", ) return @@ -159,7 +221,9 @@ async def refresh_credential_fast(): current_file, credential_data = cred_result try: # 只更新token和project_id,不重建整个headers和payload - token = credential_data.get("token") or credential_data.get("access_token", "") + token = credential_data.get("token") or credential_data.get( + "access_token", "" + ) project_id = credential_data.get("project_id", "") if not token or not project_id: return None @@ -177,10 +241,7 @@ async def refresh_credential_fast(): try: async for chunk in stream_post_async( - url=target_url, - body=final_payload, - native=native, - headers=auth_headers + url=target_url, body=final_payload, native=native, headers=auth_headers ): # 判断是否是Response对象 if isinstance(chunk, Response): @@ -190,42 +251,88 @@ async def refresh_credential_fast(): # 缓存错误解析结果,避免重复decode error_body = None try: - error_body = chunk.body.decode('utf-8') if isinstance(chunk.body, bytes) else str(chunk.body) + error_body = ( + chunk.body.decode("utf-8") + if isinstance(chunk.body, bytes) + else str(chunk.body) + ) except Exception: error_body = "" # 如果错误码是429、503或者在禁用码当中,做好记录后进行重试 - if status_code == 429 or status_code == 503 or status_code in DISABLE_ERROR_CODES: - log.warning(f"[GEMINICLI STREAM] 流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_body[:500] if error_body else '无'}") + if ( + status_code == 429 + or status_code == 503 + or status_code in DISABLE_ERROR_CODES + ): + log.warning( + f"[GEMINICLI STREAM] 流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_body[:500] if error_body else '无'}" + ) - # 解析冷却时间 + # 先解析冷却时间与错误类型,再决定是否切换凭证 cooldown_until = None + error_status = None + error_reason = None if (status_code == 429 or status_code == 503) and error_body: try: - cooldown_until = await parse_and_log_cooldown(error_body, mode="geminicli") + cooldown_until = await parse_and_log_cooldown( + error_body, mode="geminicli" + ) except Exception: pass + error_status, error_reason = _extract_gemini_error_info( + error_body + ) - # 预热下一个凭证 - if next_cred_task is None and attempt < max_retries: + # 429 MODEL_CAPACITY_EXHAUSTED 视为容量拥塞,不保留同一凭证立即重试 + is_model_capacity_exhausted = ( + status_code == 429 + and error_reason == "MODEL_CAPACITY_EXHAUSTED" + ) + + # 对于没有触发cd且非容量耗尽的429/503错误,保留当前凭证重试;否则预热下一个凭证 + if ( + (status_code == 429 or status_code == 503) + and cooldown_until is None + and not is_model_capacity_exhausted + ): + keep_current_credential = True + elif next_cred_task is None and attempt < max_retries: next_cred_task = asyncio.create_task( credential_manager.get_valid_credential( mode="geminicli", model_name=model_name ) ) - # 记录错误并切换凭证 - await record_api_call_error( - credential_manager, current_file, status_code, - cooldown_until, mode="geminicli", model_name=model_name, - error_message=error_body - ) + # 容量耗尽没有 quotaResetTimeStamp 时,设置短模型级冷却 + if is_model_capacity_exhausted and cooldown_until is None: + cooldown_until = ( + time.time() + + _compute_capacity_retry_delay(retry_interval, attempt) + ) + + # 无cd的429/503保留当前凭证重试,无需记录错误 + if not keep_current_credential: + await record_api_call_error( + credential_manager, + current_file, + status_code, + cooldown_until, + mode="geminicli", + model_name=model_name, + error_message=error_body, + ) # 检查是否应该重试 should_retry = await handle_error_with_retry( - credential_manager, status_code, current_file, - retry_config["retry_enabled"], attempt, max_retries, retry_interval, - mode="geminicli" + credential_manager, + status_code, + current_file, + retry_config["retry_enabled"], + attempt, + max_retries, + retry_interval, + mode="geminicli", ) if should_retry and attempt < max_retries: @@ -233,27 +340,39 @@ async def refresh_credential_fast(): break # 跳出内层循环,准备重试 else: # 不重试,直接返回原始错误 - log.error(f"[GEMINICLI STREAM] 达到最大重试次数或不应重试,返回原始错误") + log.error( + f"[GEMINICLI STREAM] 达到最大重试次数或不应重试,返回原始错误" + ) yield chunk return elif status_code == 404 and "preview" in model_name.lower(): # 特殊处理:preview模型返回404,说明该凭证不支持preview模型 - log.warning(f"[GEMINICLI STREAM] Preview模型404错误,凭证不支持preview: {current_file}") + log.warning( + f"[GEMINICLI STREAM] Preview模型404错误,凭证不支持preview: {current_file}" + ) # 将该凭证的preview状态设置为False try: await credential_manager.update_credential_state( current_file, {"preview": False}, mode="geminicli" ) - log.info(f"[GEMINICLI STREAM] 已将凭证 {current_file} 的preview状态设置为False") + log.info( + f"[GEMINICLI STREAM] 已将凭证 {current_file} 的preview状态设置为False" + ) except Exception as e: - log.error(f"[GEMINICLI STREAM] 更新凭证preview状态失败: {e}") + log.error( + f"[GEMINICLI STREAM] 更新凭证preview状态失败: {e}" + ) # 记录404错误 await record_api_call_error( - credential_manager, current_file, status_code, - None, mode="geminicli", model_name=model_name, - error_message=error_body + credential_manager, + current_file, + status_code, + None, + mode="geminicli", + model_name=model_name, + error_message=error_body, ) # 预热下一个凭证(会自动跳过preview=False的凭证) @@ -269,16 +388,24 @@ async def refresh_credential_fast(): need_retry = True break else: - log.error(f"[GEMINICLI STREAM] 达到最大重试次数,返回404错误") + log.error( + f"[GEMINICLI STREAM] 达到最大重试次数,返回404错误" + ) yield chunk return else: # 错误码不在禁用码当中,直接返回,无需重试 - log.error(f"[GEMINICLI STREAM] 流式请求失败,非重试错误码 (status={status_code}), 凭证: {current_file}, 响应: {error_body[:500] if error_body else '无'}") + log.error( + f"[GEMINICLI STREAM] 流式请求失败,非重试错误码 (status={status_code}), 凭证: {current_file}, 响应: {error_body[:500] if error_body else '无'}" + ) await record_api_call_error( - credential_manager, current_file, status_code, - None, mode="geminicli", model_name=model_name, - error_message=error_body + credential_manager, + current_file, + status_code, + None, + mode="geminicli", + model_name=model_name, + error_message=error_body, ) yield chunk return @@ -287,10 +414,15 @@ async def refresh_credential_fast(): # 只在第一个chunk时记录成功 if not success_recorded: await record_api_call_success( - credential_manager, current_file, mode="geminicli", model_name=model_name + credential_manager, + current_file, + mode="geminicli", + model_name=model_name, ) success_recorded = True - log.debug(f"[GEMINICLI STREAM] 开始接收流式响应,模型: {model_name}") + log.debug( + f"[GEMINICLI STREAM] 开始接收流式响应,模型: {model_name}" + ) yield chunk @@ -301,7 +433,17 @@ async def refresh_credential_fast(): # 统一处理重试 if need_retry: - log.info(f"[GEMINICLI STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") + log.info( + f"[GEMINICLI STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})..." + ) + + # 对于没有冷却时间的429错误,保留当前凭证重试 + if keep_current_credential: + log.info( + f"[GEMINICLI STREAM] {status_code}无冷却时间,保留当前凭证重试: {current_file}" + ) + await asyncio.sleep(retry_interval) + continue # 使用预热的凭证任务,避免等待 if next_cred_task is not None: @@ -312,7 +454,9 @@ async def refresh_credential_fast(): if cred_result: current_file, credential_data = cred_result # 使用快速更新方式 - token = credential_data.get("token") or credential_data.get("access_token", "") + token = credential_data.get("token") or credential_data.get( + "access_token", "" + ) project_id = credential_data.get("project_id", "") if token and project_id: auth_headers["Authorization"] = f"Bearer {token}" @@ -331,20 +475,30 @@ async def refresh_credential_fast(): yield Response( content=json.dumps({"error": "当前无可用凭证"}), status_code=500, - media_type="application/json" + media_type="application/json", ) return continue # 重试 except Exception as e: - log.error(f"[GEMINICLI STREAM] 流式请求异常: {e}, 凭证: {current_file}") + log.error( + f"[GEMINICLI STREAM] 流式请求异常: type={type(e).__name__}, " + f"detail={repr(e)}, 凭证: {current_file}" + ) if attempt < max_retries: - log.info(f"[GEMINICLI STREAM] 异常后重试 (attempt {attempt + 2}/{max_retries + 1})...") - await asyncio.sleep(retry_interval) + delay = _compute_capacity_retry_delay(retry_interval, attempt) + log.info( + f"[GEMINICLI STREAM] 异常后重试 (attempt {attempt + 2}/{max_retries + 1}), " + f"等待 {delay:.1f}s..." + ) + await asyncio.sleep(delay) continue else: # 所有重试都失败,返回最后一次的错误(如果有) - log.error(f"[GEMINICLI STREAM] 所有重试均失败,最后异常: {e}") + log.error( + f"[GEMINICLI STREAM] 所有重试均失败,最后异常: type={type(e).__name__}, " + f"detail={repr(e)}" + ) if last_error_response: yield last_error_response else: @@ -352,7 +506,7 @@ async def refresh_credential_fast(): yield Response( content=json.dumps({"error": f"流式请求异常: {str(e)}"}), status_code=500, - media_type="application/json" + media_type="application/json", ) return @@ -385,16 +539,21 @@ async def non_stream_request( return Response( content=json.dumps({"error": "当前无可用凭证"}), status_code=500, - media_type="application/json" + media_type="application/json", ) current_file, credential_data = cred_result # 2. 构建URL和请求头 try: - auth_headers, final_payload, target_url = await prepare_request_headers_and_payload( - body, credential_data, - f"{await get_code_assist_endpoint()}/v1internal:generateContent" + ( + auth_headers, + final_payload, + target_url, + ) = await prepare_request_headers_and_payload( + body, + credential_data, + f"{await get_code_assist_endpoint()}/v1internal:generateContent", ) # 合并自定义headers @@ -406,7 +565,7 @@ async def non_stream_request( return Response( content=json.dumps({"error": f"准备请求失败: {str(e)}"}), status_code=500, - media_type="application/json" + media_type="application/json", ) # 3. 调用post_async进行请求 @@ -429,7 +588,9 @@ async def refresh_credential_fast(): current_file, credential_data = cred_result try: # 只更新token和project_id,不重建整个headers和payload - token = credential_data.get("token") or credential_data.get("access_token", "") + token = credential_data.get("token") or credential_data.get( + "access_token", "" + ) project_id = credential_data.get("project_id", "") if not token or not project_id: return None @@ -444,10 +605,7 @@ async def refresh_credential_fast(): for attempt in range(max_retries + 1): try: response = await post_async( - url=target_url, - json=final_payload, - headers=auth_headers, - timeout=300.0 + url=target_url, json=final_payload, headers=auth_headers, timeout=300.0 ) status_code = response.status_code @@ -455,29 +613,28 @@ async def refresh_credential_fast(): # 成功 if status_code == 200: await record_api_call_success( - credential_manager, current_file, mode="geminicli", model_name=model_name + credential_manager, + current_file, + mode="geminicli", + model_name=model_name, ) # 创建响应头,移除压缩相关的header避免重复解压 response_headers = dict(response.headers) - response_headers.pop('content-encoding', None) - response_headers.pop('content-length', None) + response_headers.pop("content-encoding", None) + response_headers.pop("content-length", None) return Response( - content=response.content, - status_code=200, - headers=response_headers + content=response.content, status_code=200, headers=response_headers ) # 失败 - 记录最后一次错误 # 创建响应头,移除压缩相关的header避免重复解压 error_headers = dict(response.headers) - error_headers.pop('content-encoding', None) - error_headers.pop('content-length', None) + error_headers.pop("content-encoding", None) + error_headers.pop("content-length", None) last_error_response = Response( - content=response.content, - status_code=status_code, - headers=error_headers + content=response.content, status_code=status_code, headers=error_headers ) # 判断是否需要重试 @@ -489,42 +646,97 @@ async def refresh_credential_fast(): pass # 统一处理所有需要重试的错误码(429、503、禁用码) - if status_code == 429 or status_code == 503 or status_code in DISABLE_ERROR_CODES: - log.warning(f"[NON-STREAM] 非流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_text[:500] if error_text else '无'}") + if ( + status_code == 429 + or status_code == 503 + or status_code in DISABLE_ERROR_CODES + ): + log.warning( + f"[NON-STREAM] 非流式请求失败 (status={status_code}), 凭证: {current_file}, 响应: {error_text[:500] if error_text else '无'}" + ) - # 解析冷却时间 + # 先解析冷却时间与错误类型,再决定是否切换凭证 cooldown_until = None + error_status = None + error_reason = None if (status_code == 429 or status_code == 503) and error_text: try: - cooldown_until = await parse_and_log_cooldown(error_text, mode="geminicli") + cooldown_until = await parse_and_log_cooldown( + error_text, mode="geminicli" + ) except Exception: pass + error_status, error_reason = _extract_gemini_error_info(error_text) - # 并行预热下一个凭证,不阻塞当前处理 - if next_cred_task is None and attempt < max_retries: - next_cred_task = asyncio.create_task( - credential_manager.get_valid_credential( - mode="geminicli", model_name=model_name + is_model_capacity_exhausted = ( + status_code == 429 and error_reason == "MODEL_CAPACITY_EXHAUSTED" + ) + + # 对于没有触发cd且非容量耗尽的429错误,不预热新凭证 + if not ( + (status_code == 429 or status_code == 503) + and cooldown_until is None + and not is_model_capacity_exhausted + ): + # 并行预热下一个凭证,不阻塞当前处理 + if next_cred_task is None and attempt < max_retries: + next_cred_task = asyncio.create_task( + credential_manager.get_valid_credential( + mode="geminicli", model_name=model_name + ) ) + + # 容量耗尽没有 quotaResetTimeStamp 时,设置短模型级冷却 + if is_model_capacity_exhausted and cooldown_until is None: + cooldown_until = time.time() + _compute_capacity_retry_delay( + retry_interval, attempt ) - # 记录错误并切换凭证 - await record_api_call_error( - credential_manager, current_file, status_code, - cooldown_until, mode="geminicli", model_name=model_name, - error_message=error_text - ) + # 无cd的429/503保留当前凭证重试,无需记录错误 + if not ( + (status_code == 429 or status_code == 503) + and cooldown_until is None + and not is_model_capacity_exhausted + ): + await record_api_call_error( + credential_manager, + current_file, + status_code, + cooldown_until, + mode="geminicli", + model_name=model_name, + error_message=error_text, + ) # 检查是否应该重试(会自动处理禁用逻辑) should_retry = await handle_error_with_retry( - credential_manager, status_code, current_file, - retry_config["retry_enabled"], attempt, max_retries, retry_interval, - mode="geminicli" + credential_manager, + status_code, + current_file, + retry_config["retry_enabled"], + attempt, + max_retries, + retry_interval, + mode="geminicli", ) if should_retry and attempt < max_retries: # 重新获取凭证并重试 - log.info(f"[NON-STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") + log.info( + f"[NON-STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})..." + ) + + # 对于没有冷却时间且非容量耗尽的429错误,保留当前凭证重试 + if ( + (status_code == 429 or status_code == 503) + and cooldown_until is None + and not is_model_capacity_exhausted + ): + log.info( + f"[NON-STREAM] {status_code}无冷却时间,保留当前凭证重试: {current_file}" + ) + await asyncio.sleep(retry_interval) + continue # 使用预热的凭证任务,避免等待 if next_cred_task is not None: @@ -535,7 +747,9 @@ async def refresh_credential_fast(): if cred_result: current_file, credential_data = cred_result # 使用快速更新方式 - token = credential_data.get("token") or credential_data.get("access_token", "") + token = credential_data.get( + "token" + ) or credential_data.get("access_token", "") project_id = credential_data.get("project_id", "") if token and project_id: auth_headers["Authorization"] = f"Bearer {token}" @@ -554,7 +768,7 @@ async def refresh_credential_fast(): return Response( content=json.dumps({"error": "当前无可用凭证"}), status_code=500, - media_type="application/json" + media_type="application/json", ) continue # 重试 else: @@ -563,22 +777,30 @@ async def refresh_credential_fast(): return last_error_response elif status_code == 404 and "preview" in model_name.lower(): # 特殊处理:preview模型返回404,说明该凭证不支持preview模型 - log.warning(f"[NON-STREAM] Preview模型404错误,凭证不支持preview: {current_file}") + log.warning( + f"[NON-STREAM] Preview模型404错误,凭证不支持preview: {current_file}" + ) # 将该凭证的preview状态设置为False try: await credential_manager.update_credential_state( current_file, {"preview": False}, mode="geminicli" ) - log.info(f"[NON-STREAM] 已将凭证 {current_file} 的preview状态设置为False") + log.info( + f"[NON-STREAM] 已将凭证 {current_file} 的preview状态设置为False" + ) except Exception as e: log.error(f"[NON-STREAM] 更新凭证preview状态失败: {e}") # 记录404错误 await record_api_call_error( - credential_manager, current_file, status_code, - None, mode="geminicli", model_name=model_name, - error_message=error_text + credential_manager, + current_file, + status_code, + None, + mode="geminicli", + model_name=model_name, + error_message=error_text, ) # 预热下一个凭证(会自动跳过preview=False的凭证) @@ -591,7 +813,9 @@ async def refresh_credential_fast(): # 触发重试 if attempt < max_retries: - log.info(f"[NON-STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})...") + log.info( + f"[NON-STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})..." + ) # 使用预热的凭证任务,避免等待 if next_cred_task is not None: @@ -602,7 +826,9 @@ async def refresh_credential_fast(): if cred_result: current_file, credential_data = cred_result # 使用快速更新方式 - token = credential_data.get("token") or credential_data.get("access_token", "") + token = credential_data.get( + "token" + ) or credential_data.get("access_token", "") project_id = credential_data.get("project_id", "") if token and project_id: auth_headers["Authorization"] = f"Bearer {token}" @@ -621,7 +847,7 @@ async def refresh_credential_fast(): return Response( content=json.dumps({"error": "当前无可用凭证"}), status_code=500, - media_type="application/json" + media_type="application/json", ) continue # 重试 else: @@ -629,30 +855,46 @@ async def refresh_credential_fast(): return last_error_response else: # 错误码不在重试范围内,直接返回 - log.error(f"[NON-STREAM] 非流式请求失败,非重试错误码 (status={status_code}), 凭证: {current_file}, 响应: {error_text[:500] if error_text else '无'}") + log.error( + f"[NON-STREAM] 非流式请求失败,非重试错误码 (status={status_code}), 凭证: {current_file}, 响应: {error_text[:500] if error_text else '无'}" + ) await record_api_call_error( - credential_manager, current_file, status_code, - None, mode="geminicli", model_name=model_name, - error_message=error_text + credential_manager, + current_file, + status_code, + None, + mode="geminicli", + model_name=model_name, + error_message=error_text, ) return last_error_response except Exception as e: - log.error(f"非流式请求异常: {e}, 凭证: {current_file}") + log.error( + f"非流式请求异常: type={type(e).__name__}, detail={repr(e)}, " + f"凭证: {current_file}" + ) if attempt < max_retries: - log.info(f"[NON-STREAM] 异常后重试 (attempt {attempt + 2}/{max_retries + 1})...") - await asyncio.sleep(retry_interval) + delay = _compute_capacity_retry_delay(retry_interval, attempt) + log.info( + f"[NON-STREAM] 异常后重试 (attempt {attempt + 2}/{max_retries + 1}), " + f"等待 {delay:.1f}s..." + ) + await asyncio.sleep(delay) continue else: # 所有重试都失败,返回最后一次的错误(如果有)或500错误 - log.error(f"[NON-STREAM] 所有重试均失败,最后异常: {e}") + log.error( + f"[NON-STREAM] 所有重试均失败,最后异常: type={type(e).__name__}, " + f"detail={repr(e)}" + ) if last_error_response: return last_error_response else: return Response( content=json.dumps({"error": f"请求异常: {str(e)}"}), status_code=500, - media_type="application/json" + media_type="application/json", ) # 所有重试都失败,返回最后一次的原始错误 @@ -678,10 +920,10 @@ async def refresh_credential_fast(): "contents": [ { "role": "user", - "parts": [{"text": "Hello, tell me a joke in one sentence."}] + "parts": [{"text": "Hello, tell me a joke in one sentence."}], } ] - } + }, } async def test_stream_request(): @@ -703,7 +945,11 @@ async def test_stream_request(): print(f" 状态码: {chunk.status_code}") print(f" Content-Type: {chunk.headers.get('content-type', 'N/A')}") try: - content = chunk.body.decode('utf-8') if isinstance(chunk.body, bytes) else str(chunk.body) + content = ( + chunk.body.decode("utf-8") + if isinstance(chunk.body, bytes) + else str(chunk.body) + ) print(f" 内容: {content}") except Exception as e: print(f" 内容解析失败: {e}") @@ -721,7 +967,9 @@ async def test_stream_request(): if data_line.startswith("data: "): json_str = data_line[6:] # 去掉 "data: " 前缀 json_data = json.loads(json_str) - print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + print( + f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}" + ) except Exception as e: print(f" SSE解析尝试失败: {e}") @@ -743,7 +991,11 @@ async def test_non_stream_request(): print(f"\n响应头: {dict(response.headers)}\n") try: - content = response.body.decode('utf-8') if isinstance(response.body, bytes) else str(response.body) + content = ( + response.body.decode("utf-8") + if isinstance(response.body, bytes) + else str(response.body) + ) print(f"响应内容 (原始):\n{content}\n") # 尝试解析JSON @@ -772,6 +1024,7 @@ async def main(): except Exception as e: print(f"\n❌ 测试过程中出现异常: {e}") import traceback + traceback.print_exc() # 运行测试 diff --git a/src/api/utils.py b/src/api/utils.py index 6dd419499..d899a2c7d 100644 --- a/src/api/utils.py +++ b/src/api/utils.py @@ -18,36 +18,35 @@ get_retry_429_max_retries, ) from log import log -from src.credential_manager import CredentialManager # ==================== 错误检查与处理 ==================== + async def check_should_auto_ban(status_code: int) -> bool: """ 检查是否应该触发自动封禁 - + Args: status_code: HTTP状态码 - + Returns: bool: 是否应该触发自动封禁 """ return ( - await get_auto_ban_enabled() - and status_code in await get_auto_ban_error_codes() + await get_auto_ban_enabled() and status_code in await get_auto_ban_error_codes() ) async def handle_auto_ban( - credential_manager: CredentialManager, + credential_manager: Any, status_code: int, credential_name: str, - mode: str = "geminicli" + mode: str = "geminicli", ) -> None: """ 处理自动封禁:直接禁用凭证 - + Args: credential_manager: 凭证管理器实例 status_code: HTTP状态码 @@ -58,20 +57,18 @@ async def handle_auto_ban( log.warning( f"[{mode.upper()} AUTO_BAN] Status {status_code} triggers auto-ban for credential: {credential_name}" ) - await credential_manager.set_cred_disabled( - credential_name, True, mode=mode - ) + await credential_manager.set_cred_disabled(credential_name, True, mode=mode) async def handle_error_with_retry( - credential_manager: CredentialManager, + credential_manager: Any, status_code: int, credential_name: str, retry_enabled: bool, attempt: int, max_retries: int, retry_interval: float, - mode: str = "geminicli" + mode: str = "geminicli", ) -> bool: """ 统一处理错误和重试逻辑 @@ -90,7 +87,7 @@ async def handle_error_with_retry( max_retries: 最大重试次数 retry_interval: 重试间隔 mode: 模式(geminicli 或 antigravity) - + Returns: bool: True表示需要继续重试,False表示不需要重试 """ @@ -112,7 +109,11 @@ async def handle_error_with_retry( return False # 如果不触发自动封禁,仅对429和503错误进行重试 - if (status_code == 429 or status_code == 503) and retry_enabled and attempt < max_retries: + if ( + (status_code == 429 or status_code == 503) + and retry_enabled + and attempt < max_retries + ): log.info( f"[{mode.upper()} RETRY] {status_code} error encountered, retrying " f"(attempt {attempt + 1}/{max_retries})" @@ -126,10 +127,11 @@ async def handle_error_with_retry( # ==================== 重试配置获取 ==================== + async def get_retry_config() -> Dict[str, Any]: """ 获取重试配置 - + Returns: 包含重试配置的字典 """ @@ -142,15 +144,16 @@ async def get_retry_config() -> Dict[str, Any]: # ==================== API调用结果记录 ==================== + async def record_api_call_success( - credential_manager: CredentialManager, + credential_manager: Any, credential_name: str, mode: str = "geminicli", - model_name: Optional[str] = None + model_name: Optional[str] = None, ) -> None: """ 记录API调用成功 - + Args: credential_manager: 凭证管理器实例 credential_name: 凭证名称 @@ -164,13 +167,13 @@ async def record_api_call_success( async def record_api_call_error( - credential_manager: CredentialManager, + credential_manager: Any, credential_name: str, status_code: int, cooldown_until: Optional[float] = None, mode: str = "geminicli", model_name: Optional[str] = None, - error_message: Optional[str] = None + error_message: Optional[str] = None, ) -> None: """ 记录API调用错误 @@ -192,15 +195,15 @@ async def record_api_call_error( cooldown_until=cooldown_until, mode=mode, model_name=model_name, - error_message=error_message + error_message=error_message, ) # ==================== 429错误处理 ==================== + async def parse_and_log_cooldown( - error_text: str, - mode: str = "geminicli" + error_text: str, mode: str = "geminicli" ) -> Optional[float]: """ 解析并记录冷却时间 @@ -228,6 +231,7 @@ async def parse_and_log_cooldown( # ==================== 流式响应收集 ==================== + async def collect_streaming_response(stream_generator) -> Response: """ 将Gemini流式响应收集为一条完整的非流式响应 @@ -246,20 +250,19 @@ async def collect_streaming_response(stream_generator) -> Response: # 初始化响应结构 merged_response = { "response": { - "candidates": [{ - "content": { - "parts": [], - "role": "model" - }, - "finishReason": None, - "safetyRatings": [], - "citationMetadata": None - }], + "candidates": [ + { + "content": {"parts": [], "role": "model"}, + "finishReason": None, + "safetyRatings": [], + "citationMetadata": None, + } + ], "usageMetadata": { "promptTokenCount": 0, "candidatesTokenCount": 0, - "totalTokenCount": 0 - } + "totalTokenCount": 0, + }, } } @@ -278,23 +281,33 @@ async def collect_streaming_response(stream_generator) -> Response: # 如果收到的是Response对象(错误),直接返回 if isinstance(line, Response): - log.debug(f"[STREAM COLLECTOR] 收到错误Response,状态码: {line.status_code}") + log.debug( + f"[STREAM COLLECTOR] 收到错误Response,状态码: {line.status_code}" + ) return line # 处理 bytes 类型 if isinstance(line, bytes): - line_str = line.decode('utf-8', errors='ignore') - log.debug(f"[STREAM COLLECTOR] Processing bytes line {line_count}: {line_str[:200] if line_str else 'empty'}") + line_str = line.decode("utf-8", errors="ignore") + log.debug( + f"[STREAM COLLECTOR] Processing bytes line {line_count}: {line_str[:200] if line_str else 'empty'}" + ) elif isinstance(line, str): line_str = line - log.debug(f"[STREAM COLLECTOR] Processing line {line_count}: {line_str[:200] if line_str else 'empty'}") + log.debug( + f"[STREAM COLLECTOR] Processing line {line_count}: {line_str[:200] if line_str else 'empty'}" + ) else: - log.debug(f"[STREAM COLLECTOR] Skipping non-string/bytes line: {type(line)}") + log.debug( + f"[STREAM COLLECTOR] Skipping non-string/bytes line: {type(line)}" + ) continue # 解析流式数据行 if not line_str.startswith("data: "): - log.debug(f"[STREAM COLLECTOR] Skipping line without 'data: ' prefix: {line_str[:100]}") + log.debug( + f"[STREAM COLLECTOR] Skipping line without 'data: ' prefix: {line_str[:100]}" + ) continue raw = line_str[6:].strip() @@ -306,18 +319,24 @@ async def collect_streaming_response(stream_generator) -> Response: log.debug(f"[STREAM COLLECTOR] Parsing JSON: {raw[:200]}") chunk = json.loads(raw) has_data = True - log.debug(f"[STREAM COLLECTOR] Chunk keys: {chunk.keys() if isinstance(chunk, dict) else type(chunk)}") + log.debug( + f"[STREAM COLLECTOR] Chunk keys: {chunk.keys() if isinstance(chunk, dict) else type(chunk)}" + ) # 提取响应对象 response_obj = chunk.get("response", {}) if not response_obj: - log.debug("[STREAM COLLECTOR] No 'response' key in chunk, trying direct access") + log.debug( + "[STREAM COLLECTOR] No 'response' key in chunk, trying direct access" + ) response_obj = chunk # 尝试直接使用chunk candidates = response_obj.get("candidates", []) log.debug(f"[STREAM COLLECTOR] Found {len(candidates)} candidates") if not candidates: - log.debug(f"[STREAM COLLECTOR] No candidates in chunk, chunk structure: {list(chunk.keys()) if isinstance(chunk, dict) else type(chunk)}") + log.debug( + f"[STREAM COLLECTOR] No candidates in chunk, chunk structure: {list(chunk.keys()) if isinstance(chunk, dict) else type(chunk)}" + ) continue candidate = candidates[0] @@ -325,7 +344,9 @@ async def collect_streaming_response(stream_generator) -> Response: # 收集文本内容 content = candidate.get("content", {}) parts = content.get("parts", []) - log.debug(f"[STREAM COLLECTOR] Processing {len(parts)} parts from candidate") + log.debug( + f"[STREAM COLLECTOR] Processing {len(parts)} parts from candidate" + ) for part in parts: if not isinstance(part, dict): @@ -333,10 +354,16 @@ async def collect_streaming_response(stream_generator) -> Response: # 优先保留工具调用相关 part(functionCall / functionResponse) # 避免在 stream2nostream 模式下工具调用丢失 - if "functionCall" in part or "functionResponse" in part or "function_call" in part: + if ( + "functionCall" in part + or "functionResponse" in part + or "function_call" in part + ): collected_other_parts.append(part) collected_tool_parts_count += 1 - log.debug(f"[STREAM COLLECTOR] Collected tool part: {list(part.keys())}") + log.debug( + f"[STREAM COLLECTOR] Collected tool part: {list(part.keys())}" + ) continue # 处理文本内容 @@ -345,24 +372,41 @@ async def collect_streaming_response(stream_generator) -> Response: # 区分普通文本和思维链 if part.get("thought", False): collected_thought_text.append(text) - log.debug(f"[STREAM COLLECTOR] Collected thought text: {text[:100]}") + log.debug( + f"[STREAM COLLECTOR] Collected thought text: {text[:100]}" + ) else: collected_text.append(text) - log.debug(f"[STREAM COLLECTOR] Collected regular text: {text[:100]}") + log.debug( + f"[STREAM COLLECTOR] Collected regular text: {text[:100]}" + ) # 处理非文本内容(图片、文件等) - elif "inlineData" in part or "fileData" in part or "executableCode" in part or "codeExecutionResult" in part: + elif ( + "inlineData" in part + or "fileData" in part + or "executableCode" in part + or "codeExecutionResult" in part + ): collected_other_parts.append(part) - log.debug(f"[STREAM COLLECTOR] Collected non-text part: {list(part.keys())}") + log.debug( + f"[STREAM COLLECTOR] Collected non-text part: {list(part.keys())}" + ) # 收集其他信息(使用最后一个块的值) if candidate.get("finishReason"): - merged_response["response"]["candidates"][0]["finishReason"] = candidate["finishReason"] + merged_response["response"]["candidates"][0]["finishReason"] = ( + candidate["finishReason"] + ) if candidate.get("safetyRatings"): - merged_response["response"]["candidates"][0]["safetyRatings"] = candidate["safetyRatings"] + merged_response["response"]["candidates"][0]["safetyRatings"] = ( + candidate["safetyRatings"] + ) if candidate.get("citationMetadata"): - merged_response["response"]["candidates"][0]["citationMetadata"] = candidate["citationMetadata"] + merged_response["response"]["candidates"][0]["citationMetadata"] = ( + candidate["citationMetadata"] + ) # 更新使用元数据 usage = response_obj.get("usageMetadata", {}) @@ -377,22 +421,28 @@ async def collect_streaming_response(stream_generator) -> Response: continue except Exception as e: - log.error(f"[STREAM COLLECTOR] Error collecting stream after {line_count} lines: {e}") + log.error( + f"[STREAM COLLECTOR] Error collecting stream after {line_count} lines: {e}" + ) return Response( content=json.dumps({"error": f"收集流式响应失败: {str(e)}"}), status_code=500, - media_type="application/json" + media_type="application/json", ) - log.debug(f"[STREAM COLLECTOR] Finished iteration, has_data={has_data}, line_count={line_count}") + log.debug( + f"[STREAM COLLECTOR] Finished iteration, has_data={has_data}, line_count={line_count}" + ) # 如果没有收集到任何数据,返回错误 if not has_data: - log.error(f"[STREAM COLLECTOR] No data collected from stream after {line_count} lines") + log.error( + f"[STREAM COLLECTOR] No data collected from stream after {line_count} lines" + ) return Response( content=json.dumps({"error": "No data collected from stream"}), status_code=500, - media_type="application/json" + media_type="application/json", ) # 组装最终的parts @@ -400,16 +450,11 @@ async def collect_streaming_response(stream_generator) -> Response: # 先添加思维链内容(如果有) if collected_thought_text: - final_parts.append({ - "text": "".join(collected_thought_text), - "thought": True - }) + final_parts.append({"text": "".join(collected_thought_text), "thought": True}) # 再添加普通文本内容 if collected_text: - final_parts.append({ - "text": "".join(collected_text) - }) + final_parts.append({"text": "".join(collected_text)}) # 添加其他类型的parts(图片、文件等) final_parts.extend(collected_other_parts) @@ -433,10 +478,10 @@ async def collect_streaming_response(stream_generator) -> Response: # 返回纯JSON格式 return Response( - content=json.dumps(merged_response, ensure_ascii=False).encode('utf-8'), + content=json.dumps(merged_response, ensure_ascii=False).encode("utf-8"), status_code=200, headers={}, - media_type="application/json" + media_type="application/json", ) @@ -474,7 +519,9 @@ def parse_quota_reset_timestamp(error_response: dict) -> Optional[float]: for detail in details: if detail.get("@type") == "type.googleapis.com/google.rpc.ErrorInfo": - reset_timestamp_str = detail.get("metadata", {}).get("quotaResetTimeStamp") + reset_timestamp_str = detail.get("metadata", {}).get( + "quotaResetTimeStamp" + ) if reset_timestamp_str: if reset_timestamp_str.endswith("Z"): diff --git a/src/httpx_client.py b/src/httpx_client.py index 95c355f27..521fb0865 100644 --- a/src/httpx_client.py +++ b/src/httpx_client.py @@ -15,13 +15,16 @@ class HttpxClientManager: - """通用HTTP客户端管理器""" + """通用HTTP客户端管理器(启用 HTTP/2 以匹配 Google API 预期)""" async def get_client_kwargs( self, timeout: Optional[float] = 30.0, **kwargs ) -> Dict[str, Any]: """获取httpx客户端的通用配置参数""" - client_kwargs = {**kwargs} + client_kwargs = { + "http2": True, # Google cloudcode-pa 端点要求/优先 HTTP/2 + **kwargs, + } if timeout is not None: client_kwargs["timeout"] = timeout From cda052e63e063f7422c465549842fe990056ff07 Mon Sep 17 00:00:00 2001 From: CI User Date: Sat, 28 Feb 2026 17:41:21 +0800 Subject: [PATCH 39/47] fix: improve project_id auto-detection fallback parsing - Normalize multiple project_id response shapes from loadCodeAssist/onboardUser - Support extracting project id from resource names like projects/*/locations/* - Extend onboardUser polling window from 10s to 30s for slow activation cases --- src/google_oauth_api.py | 209 ++++++++++++++++++++++++++++------------ 1 file changed, 148 insertions(+), 61 deletions(-) diff --git a/src/google_oauth_api.py b/src/google_oauth_api.py index 59a7a63f7..761167bb0 100644 --- a/src/google_oauth_api.py +++ b/src/google_oauth_api.py @@ -4,6 +4,7 @@ import time import asyncio +import re from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Optional from urllib.parse import urlencode @@ -113,7 +114,7 @@ async def refresh(self): except Exception as e: error_msg = str(e) status_code = None - if hasattr(e, 'response') and hasattr(e.response, 'status_code'): + if hasattr(e, "response") and hasattr(e.response, "status_code"): status_code = e.response.status_code error_msg = f"Token刷新失败 (HTTP {status_code}): {error_msg}" else: @@ -134,11 +135,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "Credentials": expiry_str = data["expiry"] if isinstance(expiry_str, str): if expiry_str.endswith("Z"): - expires_at = datetime.fromisoformat(expiry_str.replace("Z", "+00:00")) + expires_at = datetime.fromisoformat( + expiry_str.replace("Z", "+00:00") + ) elif "+" in expiry_str: expires_at = datetime.fromisoformat(expiry_str) else: - expires_at = datetime.fromisoformat(expiry_str).replace(tzinfo=timezone.utc) + expires_at = datetime.fromisoformat(expiry_str).replace( + tzinfo=timezone.utc + ) except ValueError: log.warning(f"无法解析过期时间: {expiry_str}") @@ -171,7 +176,11 @@ class Flow: """OAuth流程类""" def __init__( - self, client_id: str, client_secret: str, scopes: List[str], redirect_uri: str = None + self, + client_id: str, + client_secret: str, + scopes: List[str], + redirect_uri: str = None, ): self.client_id = client_id self.client_secret = client_secret @@ -217,7 +226,9 @@ async def exchange_code(self, code: str) -> Credentials: oauth_base_url = await get_oauth_proxy_url() token_url = f"{oauth_base_url.rstrip('/')}/token" response = await post_async( - token_url, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"} + token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, ) response.raise_for_status() @@ -250,7 +261,11 @@ class ServiceAccount: """Service Account类""" def __init__( - self, email: str, private_key: str, project_id: str = None, scopes: List[str] = None + self, + email: str, + private_key: str, + project_id: str = None, + scopes: List[str] = None, ): self.email = email self.private_key = private_key @@ -293,13 +308,18 @@ async def get_access_token(self) -> str: assertion = self.create_jwt() - data = {"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", "assertion": assertion} + data = { + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "assertion": assertion, + } try: oauth_base_url = await get_oauth_proxy_url() token_url = f"{oauth_base_url.rstrip('/')}/token" response = await post_async( - token_url, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"} + token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, ) response.raise_for_status() @@ -308,7 +328,9 @@ async def get_access_token(self) -> str: if "expires_in" in token_data: expires_in = int(token_data["expires_in"]) - self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) + self.expires_at = datetime.now(timezone.utc) + timedelta( + seconds=expires_in + ) return self.access_token @@ -318,7 +340,9 @@ async def get_access_token(self) -> str: raise TokenError(error_msg) @classmethod - def from_dict(cls, data: Dict[str, Any], scopes: List[str] = None) -> "ServiceAccount": + def from_dict( + cls, data: Dict[str, Any], scopes: List[str] = None + ) -> "ServiceAccount": """从字典创建Service Account凭证""" return cls( email=data["client_email"], @@ -337,7 +361,8 @@ async def get_user_info(credentials: Credentials) -> Optional[Dict[str, Any]]: googleapis_base_url = await get_googleapis_proxy_url() userinfo_url = f"{googleapis_base_url.rstrip('/')}/oauth2/v2/userinfo" response = await get_async( - userinfo_url, headers={"Authorization": f"Bearer {credentials.access_token}"} + userinfo_url, + headers={"Authorization": f"Bearer {credentials.access_token}"}, ) response.raise_for_status() return response.json() @@ -426,9 +451,7 @@ async def enable_required_apis(credentials: Credentials, project_id: str) -> boo # 检查服务是否已启用 service_usage_base_url = await get_service_usage_api_url() - check_url = ( - f"{service_usage_base_url.rstrip('/')}/v1/projects/{project_id}/services/{service}" - ) + check_url = f"{service_usage_base_url.rstrip('/')}/v1/projects/{project_id}/services/{service}" try: check_response = await get_async(check_url, headers=headers) if check_response.status_code == 200: @@ -448,7 +471,10 @@ async def enable_required_apis(credentials: Credentials, project_id: str) -> boo log.info(f"✅ 成功启用服务: {service}") elif enable_response.status_code == 400: error_data = enable_response.json() - if "already enabled" in error_data.get("error", {}).get("message", "").lower(): + if ( + "already enabled" + in error_data.get("error", {}).get("message", "").lower() + ): log.info(f"✅ 服务 {service} 已经启用") else: log.warning(f"⚠️ 启用服务 {service} 时出现警告: {error_data}") @@ -494,7 +520,9 @@ async def get_user_projects(credentials: Credentials) -> List[Dict[str, Any]]: projects = data.get("projects", []) # 只返回活跃的项目 active_projects = [ - project for project in projects if project.get("lifecycleState") == "ACTIVE" + project + for project in projects + if project.get("lifecycleState") == "ACTIVE" ] log.info(f"获取到 {len(active_projects)} 个活跃项目") return active_projects @@ -518,7 +546,9 @@ async def select_default_project(projects: List[Dict[str, Any]]) -> Optional[str # Google API returns projectId in camelCase project_id = project.get("projectId", "") if "default" in display_name or "default" in project_id.lower(): - log.info(f"选择默认项目: {project_id} ({project.get('displayName', project_id)})") + log.info( + f"选择默认项目: {project_id} ({project.get('displayName', project_id)})" + ) return project_id # 策略2:选择第一个项目 @@ -532,9 +562,7 @@ async def select_default_project(projects: List[Dict[str, Any]]) -> Optional[str async def fetch_project_id( - access_token: str, - user_agent: str, - api_base_url: str + access_token: str, user_agent: str, api_base_url: str ) -> Optional[str]: """ 从 API 获取 project_id,如果 loadCodeAssist 失败则回退到 onboardUser @@ -548,10 +576,10 @@ async def fetch_project_id( project_id 字符串,如果获取失败返回 None """ headers = { - 'User-Agent': user_agent, - 'Authorization': f'Bearer {access_token}', - 'Content-Type': 'application/json', - 'Accept-Encoding': 'gzip' + "User-Agent": user_agent, + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + "Accept-Encoding": "gzip", } # 步骤 1: 尝试 loadCodeAssist @@ -560,10 +588,14 @@ async def fetch_project_id( if project_id: return project_id - log.warning("[fetch_project_id] loadCodeAssist did not return project_id, falling back to onboardUser") + log.warning( + "[fetch_project_id] loadCodeAssist did not return project_id, falling back to onboardUser" + ) except Exception as e: - log.warning(f"[fetch_project_id] loadCodeAssist failed: {type(e).__name__}: {e}") + log.warning( + f"[fetch_project_id] loadCodeAssist failed: {type(e).__name__}: {e}" + ) log.warning("[fetch_project_id] Falling back to onboardUser") # 步骤 2: 回退到 onboardUser @@ -572,20 +604,63 @@ async def fetch_project_id( if project_id: return project_id - log.error("[fetch_project_id] Failed to get project_id from both loadCodeAssist and onboardUser") + log.error( + "[fetch_project_id] Failed to get project_id from both loadCodeAssist and onboardUser" + ) return None except Exception as e: log.error(f"[fetch_project_id] onboardUser failed: {type(e).__name__}: {e}") import traceback + log.debug(f"[fetch_project_id] Traceback: {traceback.format_exc()}") return None -async def _try_load_code_assist( - api_base_url: str, - headers: dict -) -> Optional[str]: +def _extract_project_id_from_resource_name(value: str) -> Optional[str]: + """从资源名中提取 project_id,例如 projects/xxx/locations/global。""" + if not value: + return None + match = re.search(r"projects/([^/]+)/", value) + if match: + return match.group(1) + return None + + +def _normalize_project_id_value(value: Any) -> Optional[str]: + """兼容多种返回结构并归一化 project_id。""" + if value is None: + return None + + if isinstance(value, str): + raw = value.strip() + if not raw: + return None + # 兼容资源名格式 + from_resource = _extract_project_id_from_resource_name(raw) + if from_resource: + return from_resource + return raw + + if isinstance(value, dict): + # 常见字段优先级 + for key in ("projectId", "project_id", "id", "name", "project"): + normalized = _normalize_project_id_value(value.get(key)) + if normalized: + return normalized + return None + + if isinstance(value, list): + for item in value: + normalized = _normalize_project_id_value(item) + if normalized: + return normalized + return None + + return None + + +async def _try_load_code_assist(api_base_url: str, headers: dict) -> Optional[str]: """ 尝试通过 loadCodeAssist 获取 project_id @@ -597,7 +672,7 @@ async def _try_load_code_assist( "metadata": { "ideType": "ANTIGRAVITY", "platform": "PLATFORM_UNSPECIFIED", - "pluginType": "GEMINI" + "pluginType": "GEMINI", } } @@ -626,9 +701,21 @@ async def _try_load_code_assist( log.info("[loadCodeAssist] User is already activated") # 使用服务器返回的 project_id - project_id = data.get("cloudaicompanionProject") + project_id = _normalize_project_id_value( + data.get("cloudaicompanionProject") + ) + if project_id: + log.info( + f"[loadCodeAssist] Successfully fetched project_id: {project_id}" + ) + return project_id + + # 兼容部分返回结构:projectId/project_id/name 等位置 + project_id = _normalize_project_id_value(data) if project_id: - log.info(f"[loadCodeAssist] Successfully fetched project_id: {project_id}") + log.info( + f"[loadCodeAssist] Fallback extracted project_id: {project_id}" + ) return project_id log.warning("[loadCodeAssist] No project_id in response") @@ -642,10 +729,7 @@ async def _try_load_code_assist( raise Exception(f"HTTP {response.status_code}: {response.text[:200]}") -async def _try_onboard_user( - api_base_url: str, - headers: dict -) -> Optional[str]: +async def _try_onboard_user(api_base_url: str, headers: dict) -> Optional[str]: """ 尝试通过 onboardUser 获取 project_id(长时间运行操作,需要轮询) @@ -669,16 +753,16 @@ async def _try_onboard_user( "metadata": { "ideType": "ANTIGRAVITY", "platform": "PLATFORM_UNSPECIFIED", - "pluginType": "GEMINI" - } + "pluginType": "GEMINI", + }, } log.debug(f"[onboardUser] Request URL: {request_url}") log.debug(f"[onboardUser] Request body: {request_body}") # onboardUser 是长时间运行操作,需要轮询 - # 最多等待 10 秒(5 次 * 2 秒) - max_attempts = 5 + # 最多等待 30 秒(15 次 * 2 秒),提升慢速环境下的成功率 + max_attempts = 15 attempt = 0 while attempt < max_attempts: @@ -704,37 +788,40 @@ async def _try_onboard_user( # 从响应中提取 project_id response_data = data.get("response", {}) - project_obj = response_data.get("cloudaicompanionProject", {}) + project_obj = response_data.get("cloudaicompanionProject") - if isinstance(project_obj, dict): - project_id = project_obj.get("id") - elif isinstance(project_obj, str): - project_id = project_obj - else: - project_id = None + project_id = _normalize_project_id_value(project_obj) + if not project_id: + # 兼容返回结构变化,尝试在整个响应中提取 + project_id = _normalize_project_id_value(response_data) + if not project_id: + project_id = _normalize_project_id_value(data) if project_id: - log.info(f"[onboardUser] Successfully fetched project_id: {project_id}") + log.info( + f"[onboardUser] Successfully fetched project_id: {project_id}" + ) return project_id else: - log.warning("[onboardUser] Operation completed but no project_id in response") + log.warning( + "[onboardUser] Operation completed but no project_id in response" + ) return None else: - log.debug("[onboardUser] Operation still in progress, waiting 2 seconds...") + log.debug( + "[onboardUser] Operation still in progress, waiting 2 seconds..." + ) await asyncio.sleep(2) else: log.warning(f"[onboardUser] Failed: HTTP {response.status_code}") log.warning(f"[onboardUser] Response body: {response.text[:500]}") raise Exception(f"HTTP {response.status_code}: {response.text[:200]}") - log.error("[onboardUser] Timeout: Operation did not complete within 10 seconds") + log.error("[onboardUser] Timeout: Operation did not complete within 30 seconds") return None -async def _get_onboard_tier( - api_base_url: str, - headers: dict -) -> Optional[str]: +async def _get_onboard_tier(api_base_url: str, headers: dict) -> Optional[str]: """ 从 loadCodeAssist 响应中获取用户应该注册的 tier @@ -746,7 +833,7 @@ async def _get_onboard_tier( "metadata": { "ideType": "ANTIGRAVITY", "platform": "PLATFORM_UNSPECIFIED", - "pluginType": "GEMINI" + "pluginType": "GEMINI", } } @@ -775,7 +862,7 @@ async def _get_onboard_tier( log.warning("[_get_onboard_tier] No default tier found, using LEGACY") return "LEGACY" else: - log.error(f"[_get_onboard_tier] Failed to fetch tier info: HTTP {response.status_code}") + log.error( + f"[_get_onboard_tier] Failed to fetch tier info: HTTP {response.status_code}" + ) return None - - From f33ab48d4677695ce016a246ab67e5291bac1b4a Mon Sep 17 00:00:00 2001 From: CI User Date: Mon, 2 Mar 2026 15:30:40 +0800 Subject: [PATCH 40/47] chore: backup pre-optimization code changes --- src/api/geminicli.py | 49 ++- src/auth.py | 358 +++++++++++++---- src/converter/anthropic2gemini.py | 336 +++++++++++----- src/converter/gemini_fix.py | 196 ++++++--- src/converter/openai2gemini.py | 546 +++++++++++++++++--------- src/google_oauth_api.py | 223 +++++++++-- test_anthropic2gemini_tools_schema.py | 32 ++ test_auth_oauth_flow.py | 139 +++++++ test_gemini_fix.py | 56 +++ test_geminicli_stream_request.py | 183 +++++++++ test_google_oauth_api.py | 215 ++++++++++ test_openai2gemini_tools_schema.py | 28 ++ 12 files changed, 1902 insertions(+), 459 deletions(-) create mode 100644 test_anthropic2gemini_tools_schema.py create mode 100644 test_auth_oauth_flow.py create mode 100644 test_gemini_fix.py create mode 100644 test_geminicli_stream_request.py create mode 100644 test_google_oauth_api.py create mode 100644 test_openai2gemini_tools_schema.py diff --git a/src/api/geminicli.py b/src/api/geminicli.py index e0c56d46c..a10e70234 100644 --- a/src/api/geminicli.py +++ b/src/api/geminicli.py @@ -19,8 +19,9 @@ import time from typing import Any, Dict, Optional +import httpx from fastapi import Response -from config import get_code_assist_endpoint, get_auto_ban_error_codes +from config import get_code_assist_endpoint, get_auto_ban_error_codes, get_proxy_config from log import log from src.credential_manager import credential_manager @@ -205,19 +206,22 @@ async def stream_request( retry_config = await get_retry_config() max_retries = retry_config["max_retries"] retry_interval = retry_config["retry_interval"] + proxy_enabled = bool(await get_proxy_config()) + # 流式请求采用结构化超时:限制连接/写入/连接池等待,读取保持无限制以兼容长SSE空隙 + stream_timeout = httpx.Timeout(connect=30.0, read=None, write=30.0, pool=30.0) DISABLE_ERROR_CODES = await get_auto_ban_error_codes() # 禁用凭证的错误码 last_error_response = None # 记录最后一次的错误响应 next_cred_task = None # 预热的下一个凭证任务 # 内部函数:快速更新凭证(只更新token和project_id,避免重建整个请求) - async def refresh_credential_fast(): + async def refresh_credential_fast() -> bool: nonlocal current_file, credential_data, auth_headers, final_payload cred_result = await credential_manager.get_valid_credential( mode="geminicli", model_name=model_name ) if not cred_result: - return None + return False current_file, credential_data = cred_result try: # 只更新token和project_id,不重建整个headers和payload @@ -226,22 +230,29 @@ async def refresh_credential_fast(): ) project_id = credential_data.get("project_id", "") if not token or not project_id: - return None + return False # 直接更新现有的headers和payload auth_headers["Authorization"] = f"Bearer {token}" final_payload["project"] = project_id return True except Exception: - return None + return False for attempt in range(max_retries + 1): success_recorded = False # 标记是否已记录成功 need_retry = False # 标记是否需要重试 + keep_current_credential = False # 标记是否保留当前凭证 + status_code: Optional[int] = None + attempt_started_at = time.time() try: async for chunk in stream_post_async( - url=target_url, body=final_payload, native=native, headers=auth_headers + url=target_url, + body=final_payload, + native=native, + headers=auth_headers, + timeout=stream_timeout, ): # 判断是否是Response对象 if isinstance(chunk, Response): @@ -439,8 +450,11 @@ async def refresh_credential_fast(): # 对于没有冷却时间的429错误,保留当前凭证重试 if keep_current_credential: + status_label = ( + str(status_code) if status_code is not None else "unknown" + ) log.info( - f"[GEMINICLI STREAM] {status_code}无冷却时间,保留当前凭证重试: {current_file}" + f"[GEMINICLI STREAM] {status_label}无冷却时间,保留当前凭证重试: {current_file}" ) await asyncio.sleep(retry_interval) continue @@ -481,9 +495,12 @@ async def refresh_credential_fast(): continue # 重试 except Exception as e: + elapsed = time.time() - attempt_started_at log.error( f"[GEMINICLI STREAM] 流式请求异常: type={type(e).__name__}, " - f"detail={repr(e)}, 凭证: {current_file}" + f"detail={repr(e)}, 模型: {model_name}, 尝试: {attempt + 1}/{max_retries + 1}, " + f"耗时: {elapsed:.2f}s, url: {target_url}, proxy={'on' if proxy_enabled else 'off'}, " + f"http2=True, timeout(connect/write/pool=30s, read=None), 凭证: {current_file}" ) if attempt < max_retries: delay = _compute_capacity_retry_delay(retry_interval, attempt) @@ -578,13 +595,13 @@ async def non_stream_request( next_cred_task = None # 预热的下一个凭证任务 # 内部函数:快速更新凭证(只更新token和project_id,避免重建整个请求) - async def refresh_credential_fast(): + async def refresh_credential_fast() -> bool: nonlocal current_file, credential_data, auth_headers, final_payload cred_result = await credential_manager.get_valid_credential( mode="geminicli", model_name=model_name ) if not cred_result: - return None + return False current_file, credential_data = cred_result try: # 只更新token和project_id,不重建整个headers和payload @@ -593,14 +610,14 @@ async def refresh_credential_fast(): ) project_id = credential_data.get("project_id", "") if not token or not project_id: - return None + return False # 直接更新现有的headers和payload auth_headers["Authorization"] = f"Bearer {token}" final_payload["project"] = project_id return True except Exception: - return None + return False for attempt in range(max_retries + 1): try: @@ -899,7 +916,13 @@ async def refresh_credential_fast(): # 所有重试都失败,返回最后一次的原始错误 log.error("[NON-STREAM] 所有重试均失败") - return last_error_response + if last_error_response is not None: + return last_error_response + return Response( + content=json.dumps({"error": "请求失败,且未收到上游错误响应"}), + status_code=500, + media_type="application/json", + ) # ==================== 测试代码 ==================== diff --git a/src/auth.py b/src/auth.py index 6d2d0f429..065945e77 100644 --- a/src/auth.py +++ b/src/auth.py @@ -11,7 +11,7 @@ import uuid from datetime import timezone from http.server import BaseHTTPRequestHandler, HTTPServer -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional from urllib.parse import parse_qs, urlparse from config import get_config_value, get_antigravity_api_url, get_code_assist_endpoint @@ -40,12 +40,25 @@ ) +def _project_discovery_network_error_message() -> str: + """项目自动发现失败时的网络诊断提示。""" + return ( + "无法自动获取项目列表:检测到网络连接失败(可能是代理/网络限制导致无法访问 " + "Google Cloud Resource Manager)。请检查 PROXY / OAUTH_PROXY_URL / " + "GOOGLEAPIS_PROXY_URL 配置或网络连通性后重试;如仍失败可手动输入项目ID。" + ) + + async def get_callback_port(): """获取OAuth回调端口""" - return int(await get_config_value("oauth_callback_port", "11451", "OAUTH_CALLBACK_PORT")) + return int( + await get_config_value("oauth_callback_port", "11451", "OAUTH_CALLBACK_PORT") + ) -def _prepare_credentials_data(credentials: Credentials, project_id: str, mode: str = "geminicli") -> Dict[str, Any]: +def _prepare_credentials_data( + credentials: Credentials, project_id: str, mode: str = "geminicli" +) -> Dict[str, Any]: """准备凭证数据字典(统一函数)""" if mode == "antigravity": creds_data = { @@ -96,21 +109,27 @@ def _cleanup_auth_flow_server(state: str): except Exception as e: log.debug(f"关闭服务器时出错: {e}") del auth_flows[state] + auth_flow_locks.pop(state, None) class _OAuthLibPatcher: """oauthlib参数验证补丁的上下文管理器""" + def __init__(self): import oauthlib.oauth2.rfc6749.parameters + self.module = oauthlib.oauth2.rfc6749.parameters - self.original_validate = None + self.original_validate: Optional[Callable[[Any], Any]] = None def __enter__(self): self.original_validate = self.module.validate_token_parameters def patched_validate(params): + validator = self.original_validate + if validator is None: + return None try: - return self.original_validate(params) + return validator(params) except Warning: pass @@ -125,6 +144,59 @@ def __exit__(self, exc_type, exc_val, exc_tb): # 全局状态管理 - 严格限制大小 auth_flows = {} # 存储进行中的认证流程 MAX_AUTH_FLOWS = 20 # 严格限制最大认证流程数 +auth_flow_locks: Dict[str, asyncio.Lock] = {} # 每个认证流程的并发锁 + + +def _get_auth_flow_lock(state: str) -> asyncio.Lock: + """获取认证流程锁(不存在则创建)。""" + lock = auth_flow_locks.get(state) + if lock is None: + lock = asyncio.Lock() + auth_flow_locks[state] = lock + return lock + + +async def _exchange_or_reuse_credentials( + state: str, flow: Any, auth_code: Optional[str] +) -> tuple[Optional[Credentials], Optional[str]]: + """原子化地兑换授权码,避免并发重复兑换;若已兑换成功则直接复用缓存凭证。""" + lock = _get_auth_flow_lock(state) + + async with lock: + flow_data = auth_flows.get(state) + if not flow_data: + return None, "未找到对应的认证流程,请先重新发起认证" + + cached_credentials = flow_data.get("exchanged_credentials") + if cached_credentials: + return cached_credentials, None + + if flow_data.get("exchange_in_progress"): + return None, "认证流程正在处理中,请勿重复提交回调,稍后重试" + + if not auth_code: + return None, "授权码缺失,无法兑换凭证" + + flow_data["exchange_in_progress"] = True + + try: + credentials = await flow.exchange_code(auth_code) + except Exception as e: + async with lock: + if state in auth_flows: + auth_flows[state]["exchange_in_progress"] = False + return None, f"获取凭证失败: {str(e)}" + + async with lock: + if state in auth_flows: + flow_data = auth_flows[state] + flow_data["exchange_in_progress"] = False + flow_data["exchanged_credentials"] = credentials + flow_data["code_redeemed"] = True + # 授权码是一次性的,成功兑换后清空,防止后续逻辑误用 + flow_data["code"] = None + + return credentials, None def cleanup_auth_flows_for_memory(): @@ -150,6 +222,7 @@ def cleanup_auth_flows_for_memory(): except Exception: pass flow_data.clear() + auth_flow_locks.pop(state, None) auth_flows = new_auth_flows log.info(f"强制清理认证流程,保留 {len(auth_flows)} 个最新流程") @@ -157,7 +230,7 @@ def cleanup_auth_flows_for_memory(): return len(auth_flows) -async def find_available_port(start_port: int = None) -> int: +async def find_available_port(start_port: Optional[int] = None) -> int: """动态查找可用端口""" if start_port is None: start_port = await get_callback_port() @@ -239,10 +312,13 @@ def log_message(self, format, *args): async def create_auth_url( - project_id: Optional[str] = None, user_session: str = None, mode: str = "geminicli" + project_id: Optional[str] = None, + user_session: Optional[str] = None, + mode: Optional[str] = "geminicli", ) -> Dict[str, Any]: """创建认证URL,支持动态端口分配""" try: + mode = mode or "geminicli" # 动态分配端口 callback_port = await find_available_port() callback_url = f"http://{CALLBACK_HOST}:{callback_port}" @@ -295,7 +371,9 @@ async def create_auth_url( # 严格控制认证流程数量 - 超过限制时立即清理最旧的 if len(auth_flows) >= MAX_AUTH_FLOWS: # 清理最旧的认证流程 - oldest_state = min(auth_flows.keys(), key=lambda k: auth_flows[k].get("created_at", 0)) + oldest_state = min( + auth_flows.keys(), key=lambda k: auth_flows[k].get("created_at", 0) + ) try: # 清理服务器资源 old_flow = auth_flows[oldest_state] @@ -320,6 +398,9 @@ async def create_auth_url( "server_thread": server_thread, # 存储服务器线程 "code": None, "completed": False, + "exchange_in_progress": False, + "code_redeemed": False, + "exchanged_credentials": None, "created_at": time.time(), "auto_project_detection": project_id is None, # 标记是否需要自动检测项目ID "mode": mode, # 凭证模式 @@ -375,7 +456,7 @@ def wait_for_callback_sync(state: str, timeout: int = 300) -> Optional[str]: async def complete_auth_flow( - project_id: Optional[str] = None, user_session: str = None + project_id: Optional[str] = None, user_session: Optional[str] = None ) -> Dict[str, Any]: """完成认证流程并保存凭证,支持自动检测项目ID""" try: @@ -412,7 +493,10 @@ async def complete_auth_flow( flow_data = data if not state or not flow_data: - return {"success": False, "error": "未找到对应的认证流程,请先点击获取认证链接"} + return { + "success": False, + "error": "未找到对应的认证流程,请先点击获取认证链接", + } if not project_id: project_id = flow_data.get("project_id") @@ -424,9 +508,14 @@ async def complete_auth_flow( } flow = flow_data["flow"] + cached_credentials = flow_data.get("exchanged_credentials") + # 如果已有已兑换凭证,直接复用,避免重复等待授权码 + if cached_credentials: + log.info(f"检测到已兑换凭证,跳过等待授权码 (state: {state})") + auth_code = None # 如果还没有授权码,需要等待回调 - if not flow_data.get("code"): + elif not flow_data.get("code"): log.info(f"等待用户完成OAuth授权 (state: {state})") auth_code = wait_for_callback_sync(state) @@ -445,7 +534,17 @@ async def complete_auth_flow( # 使用认证代码获取凭证 with _OAuthLibPatcher(): try: - credentials = await flow.exchange_code(auth_code) + if cached_credentials: + credentials = cached_credentials + else: + credentials, exchange_error = await _exchange_or_reuse_credentials( + state, flow, auth_code + ) + if exchange_error: + log.error(exchange_error) + return {"success": False, "error": exchange_error} + if not credentials: + return {"success": False, "error": "获取凭证失败:空凭证"} # credentials 已经在 exchange_code 中获得 # 如果需要自动检测项目ID且没有提供项目ID @@ -453,7 +552,11 @@ async def complete_auth_flow( log.info("尝试通过API获取用户项目列表...") log.info(f"使用的token: {credentials.access_token[:20]}...") log.info(f"Token过期时间: {credentials.expires_at}") - user_projects = await get_user_projects(credentials) + user_projects_result = await get_user_projects( + credentials, + with_diagnostics=True, + ) + user_projects, project_discovery_diag = user_projects_result if user_projects: # 如果只有一个项目,自动使用 @@ -479,13 +582,21 @@ async def complete_auth_flow( { # Google API returns projectId in camelCase "project_id": p.get("projectId"), - "name": p.get("displayName") or p.get("projectId"), + "name": p.get("displayName") + or p.get("projectId"), "projectNumber": p.get("projectNumber"), } for p in user_projects ], } else: + if project_discovery_diag.get("all_failed_by_connect_error"): + return { + "success": False, + "error": _project_discovery_network_error_message(), + "requires_manual_project_id": True, + "project_discovery_network_error": True, + } # 如果无法获取项目列表,提示手动输入 return { "success": False, @@ -505,7 +616,9 @@ async def complete_auth_flow( saved_filename = await save_credentials(credentials, project_id) # 准备返回的凭证数据 - creds_data = _prepare_credentials_data(credentials, project_id, mode="geminicli") + creds_data = _prepare_credentials_data( + credentials, project_id, mode="geminicli" + ) # 清理使用过的流程 _cleanup_auth_flow_server(state) @@ -515,7 +628,9 @@ async def complete_auth_flow( "success": True, "credentials": creds_data, "file_path": saved_filename, - "auto_detected_project": flow_data.get("auto_project_detection", False), + "auto_detected_project": flow_data.get( + "auto_project_detection", False + ), } except Exception as e: @@ -528,10 +643,13 @@ async def complete_auth_flow( async def asyncio_complete_auth_flow( - project_id: Optional[str] = None, user_session: str = None, mode: str = "geminicli" + project_id: Optional[str] = None, + user_session: Optional[str] = None, + mode: Optional[str] = "geminicli", ) -> Dict[str, Any]: """异步完成认证流程,支持自动检测项目ID""" try: + mode = mode or "geminicli" log.info( f"asyncio_complete_auth_flow开始执行: project_id={project_id}, user_session={user_session}" ) @@ -593,7 +711,10 @@ async def asyncio_complete_auth_flow( if not state or not flow_data: log.error(f"未找到认证流程: state={state}, flow_data存在={bool(flow_data)}") log.debug(f"当前所有flow_data: {list(auth_flows.keys())}") - return {"success": False, "error": "未找到对应的认证流程,请先点击获取认证链接"} + return { + "success": False, + "error": "未找到对应的认证流程,请先点击获取认证链接", + } log.info(f"找到认证流程: state={state}") log.info( @@ -622,47 +743,69 @@ async def asyncio_complete_auth_flow( # 检查是否已经有授权码 log.info("开始检查OAuth授权码...") - log.info(f"等待state={state}的授权回调,回调端口: {flow_data.get('callback_port')}") - log.info(f"当前flow_data状态: completed={flow_data.get('completed')}, code存在={bool(flow_data.get('code'))}") + log.info( + f"等待state={state}的授权回调,回调端口: {flow_data.get('callback_port')}" + ) + log.info( + f"当前flow_data状态: completed={flow_data.get('completed')}, code存在={bool(flow_data.get('code'))}" + ) + cached_credentials = flow_data.get("exchanged_credentials") max_wait_time = 60 # 最多等待60秒 wait_interval = 1 # 每秒检查一次 waited = 0 - while waited < max_wait_time: - if flow_data.get("code"): - log.info(f"检测到OAuth授权码,开始处理凭证 (等待时间: {waited}秒)") - break - - # 每5秒输出一次提示 - if waited % 5 == 0 and waited > 0: - log.info(f"仍在等待OAuth授权... ({waited}/{max_wait_time}秒)") - log.debug(f"当前state: {state}, flow_data keys: {list(flow_data.keys())}") + if not cached_credentials: + while waited < max_wait_time: + if flow_data.get("code"): + log.info(f"检测到OAuth授权码,开始处理凭证 (等待时间: {waited}秒)") + break + + # 每5秒输出一次提示 + if waited % 5 == 0 and waited > 0: + log.info(f"仍在等待OAuth授权... ({waited}/{max_wait_time}秒)") + log.debug( + f"当前state: {state}, flow_data keys: {list(flow_data.keys())}" + ) - # 异步等待 - await asyncio.sleep(wait_interval) - waited += wait_interval + # 异步等待 + await asyncio.sleep(wait_interval) + waited += wait_interval - # 刷新flow_data引用,因为可能被回调更新了 - if state in auth_flows: - flow_data = auth_flows[state] + # 刷新flow_data引用,因为可能被回调更新了 + if state in auth_flows: + flow_data = auth_flows[state] - if not flow_data.get("code"): - log.error(f"等待OAuth回调超时,等待了{waited}秒") - return { - "success": False, - "error": "等待OAuth回调超时,请确保完成了浏览器中的认证并看到成功页面", - } + if not flow_data.get("code"): + log.error(f"等待OAuth回调超时,等待了{waited}秒") + return { + "success": False, + "error": "等待OAuth回调超时,请确保完成了浏览器中的认证并看到成功页面", + } + else: + log.info(f"检测到已兑换凭证,跳过等待授权码 (state: {state})") flow = flow_data["flow"] - auth_code = flow_data["code"] + auth_code = flow_data.get("code") - log.info(f"开始使用授权码获取凭证: code={'***' + auth_code[-4:] if auth_code else 'None'}") + log.info( + f"开始使用授权码获取凭证: code={'***' + auth_code[-4:] if auth_code else 'None'}" + ) # 使用认证代码获取凭证 with _OAuthLibPatcher(): try: log.info("调用flow.exchange_code...") - credentials = await flow.exchange_code(auth_code) + if cached_credentials: + credentials = cached_credentials + else: + credentials, exchange_error = await _exchange_or_reuse_credentials( + state, flow, auth_code + ) + if exchange_error: + log.error(exchange_error) + return {"success": False, "error": exchange_error} + if not credentials: + return {"success": False, "error": "获取凭证失败:空凭证"} log.info( f"成功获取凭证,token前缀: {credentials.access_token[:20] if credentials.access_token else 'None'}..." ) @@ -672,7 +815,11 @@ async def asyncio_complete_auth_flow( ) # 检查凭证模式 - cred_mode = flow_data.get("mode", "geminicli") if flow_data.get("mode") else mode + cred_mode = ( + flow_data.get("mode", "geminicli") + if flow_data.get("mode") + else mode + ) if cred_mode == "antigravity": log.info("Antigravity模式:从API获取project_id...") # 使用API获取project_id @@ -680,7 +827,7 @@ async def asyncio_complete_auth_flow( project_id = await fetch_project_id( credentials.access_token, ANTIGRAVITY_USER_AGENT, - antigravity_url + antigravity_url, ) if project_id: log.info(f"成功从API获取project_id: {project_id}") @@ -690,10 +837,14 @@ async def asyncio_complete_auth_flow( log.info(f"生成的随机project_id: {project_id}") # 保存antigravity凭证 - saved_filename = await save_credentials(credentials, project_id, mode="antigravity") + saved_filename = await save_credentials( + credentials, project_id, mode="antigravity" + ) # 准备返回的凭证数据 - creds_data = _prepare_credentials_data(credentials, project_id, mode="antigravity") + creds_data = _prepare_credentials_data( + credentials, project_id, mode="antigravity" + ) # 清理使用过的流程 _cleanup_auth_flow_server(state) @@ -713,9 +864,7 @@ async def asyncio_complete_auth_flow( # 使用API获取project_id(使用标准模式的User-Agent) code_assist_url = await get_code_assist_endpoint() project_id = await fetch_project_id( - credentials.access_token, - GEMINICLI_USER_AGENT, - code_assist_url + credentials.access_token, GEMINICLI_USER_AGENT, code_assist_url ) if project_id: flow_data["project_id"] = project_id @@ -726,7 +875,11 @@ async def asyncio_complete_auth_flow( else: log.warning("无法从API获取project_id,回退到项目列表获取方式") # 回退到原来的项目列表获取方式 - user_projects = await get_user_projects(credentials) + user_projects_result = await get_user_projects( + credentials, + with_diagnostics=True, + ) + user_projects, project_discovery_diag = user_projects_result if user_projects: # 如果只有一个项目,自动使用 @@ -758,13 +911,23 @@ async def asyncio_complete_auth_flow( { # Google API returns projectId in camelCase "project_id": p.get("projectId"), - "name": p.get("displayName") or p.get("projectId"), + "name": p.get("displayName") + or p.get("projectId"), "projectNumber": p.get("projectNumber"), } for p in user_projects ], } else: + if project_discovery_diag.get( + "all_failed_by_connect_error" + ): + return { + "success": False, + "error": _project_discovery_network_error_message(), + "requires_manual_project_id": True, + "project_discovery_network_error": True, + } # 如果无法获取项目列表,提示手动输入 return { "success": False, @@ -788,7 +951,9 @@ async def asyncio_complete_auth_flow( saved_filename = await save_credentials(credentials, project_id) # 准备返回的凭证数据 - creds_data = _prepare_credentials_data(credentials, project_id, mode="geminicli") + creds_data = _prepare_credentials_data( + credentials, project_id, mode="geminicli" + ) # 清理使用过的流程 _cleanup_auth_flow_server(state) @@ -798,7 +963,9 @@ async def asyncio_complete_auth_flow( "success": True, "credentials": creds_data, "file_path": saved_filename, - "auto_detected_project": flow_data.get("auto_project_detection", False), + "auto_detected_project": flow_data.get( + "auto_project_detection", False + ), } except Exception as e: @@ -811,10 +978,13 @@ async def asyncio_complete_auth_flow( async def complete_auth_flow_from_callback_url( - callback_url: str, project_id: Optional[str] = None, mode: str = "geminicli" + callback_url: str, + project_id: Optional[str] = None, + mode: Optional[str] = "geminicli", ) -> Dict[str, Any]: """从回调URL直接完成认证流程,无需启动本地服务器""" try: + mode = mode or "geminicli" log.info(f"开始从回调URL完成认证: {callback_url}") # 解析回调URL @@ -846,19 +1016,26 @@ async def complete_auth_flow_from_callback_url( try: # 使用authorization code获取token - credentials = await flow.exchange_code(code) + credentials, exchange_error = await _exchange_or_reuse_credentials( + state, flow, code + ) + if exchange_error: + log.error(exchange_error) + return {"success": False, "error": exchange_error} + if not credentials: + return {"success": False, "error": "获取凭证失败:空凭证"} log.info("成功获取访问令牌") # 检查凭证模式 - cred_mode = flow_data.get("mode", "geminicli") if flow_data.get("mode") else mode + cred_mode = ( + flow_data.get("mode", "geminicli") if flow_data.get("mode") else mode + ) if cred_mode == "antigravity": log.info("Antigravity模式(从回调URL):从API获取project_id...") # 使用API获取project_id antigravity_url = await get_antigravity_api_url() project_id = await fetch_project_id( - credentials.access_token, - ANTIGRAVITY_USER_AGENT, - antigravity_url + credentials.access_token, ANTIGRAVITY_USER_AGENT, antigravity_url ) if project_id: log.info(f"成功从API获取project_id: {project_id}") @@ -868,10 +1045,14 @@ async def complete_auth_flow_from_callback_url( log.info(f"生成的随机project_id: {project_id}") # 保存antigravity凭证 - saved_filename = await save_credentials(credentials, project_id, mode="antigravity") + saved_filename = await save_credentials( + credentials, project_id, mode="antigravity" + ) # 准备返回的凭证数据 - creds_data = _prepare_credentials_data(credentials, project_id, mode="antigravity") + creds_data = _prepare_credentials_data( + credentials, project_id, mode="antigravity" + ) # 清理使用过的流程 _cleanup_auth_flow_server(state) @@ -895,9 +1076,7 @@ async def complete_auth_flow_from_callback_url( log.info("标准模式:从API获取project_id...") code_assist_url = await get_code_assist_endpoint() detected_project_id = await fetch_project_id( - credentials.access_token, - GEMINICLI_USER_AGENT, - code_assist_url + credentials.access_token, GEMINICLI_USER_AGENT, code_assist_url ) if detected_project_id: auto_detected = True @@ -905,7 +1084,11 @@ async def complete_auth_flow_from_callback_url( else: log.warning("无法从API获取project_id,回退到项目列表获取方式") # 回退到原来的项目列表获取方式 - projects = await get_user_projects(credentials) + projects_result = await get_user_projects( + credentials, + with_diagnostics=True, + ) + projects, project_discovery_diag = projects_result if projects: if len(projects) == 1: # 只有一个项目,自动使用 @@ -921,8 +1104,19 @@ async def complete_auth_flow_from_callback_url( log.info( f"检测到{len(projects)}个项目,自动选择第一个: {detected_project_id}" ) - log.debug(f"其他可用项目: {[p['projectId'] for p in projects[1:]]}") + log.debug( + f"其他可用项目: {[p['projectId'] for p in projects[1:]]}" + ) else: + if project_discovery_diag.get( + "all_failed_by_connect_error" + ): + return { + "success": False, + "error": _project_discovery_network_error_message(), + "requires_manual_project_id": True, + "project_discovery_network_error": True, + } # 没有项目访问权限 return { "success": False, @@ -951,7 +1145,9 @@ async def complete_auth_flow_from_callback_url( saved_filename = await save_credentials(credentials, detected_project_id) # 准备返回的凭证数据 - creds_data = _prepare_credentials_data(credentials, detected_project_id, mode="geminicli") + creds_data = _prepare_credentials_data( + credentials, detected_project_id, mode="geminicli" + ) # 清理使用过的流程 _cleanup_auth_flow_server(state) @@ -973,7 +1169,9 @@ async def complete_auth_flow_from_callback_url( return {"success": False, "error": str(e)} -async def save_credentials(creds: Credentials, project_id: str, mode: str = "geminicli") -> str: +async def save_credentials( + creds: Credentials, project_id: str, mode: str = "geminicli" +) -> str: """通过统一存储系统保存凭证""" # 生成文件名(使用project_id和时间戳) timestamp = int(time.time()) @@ -1000,7 +1198,9 @@ async def save_credentials(creds: Credentials, project_id: str, mode: str = "gem "last_success": time.time(), "user_email": None, } - await storage_adapter.update_credential_state(filename, default_state, mode=mode) + await storage_adapter.update_credential_state( + filename, default_state, mode=mode + ) log.info(f"凭证和状态已保存到: {filename} (mode={mode})") except Exception as e: log.warning(f"创建默认状态记录失败 {filename}: {e}") @@ -1076,6 +1276,7 @@ def cleanup_expired_flows(): # 显式清理流程数据,释放内存 flow_data.clear() del auth_flows[state] + auth_flow_locks.pop(state, None) cleaned_count += 1 if cleaned_count > 0: @@ -1174,7 +1375,10 @@ def validate_credential_content(content: str) -> Dict[str, Any]: missing_fields = [field for field in required_fields if field not in creds_data] if missing_fields: - return {"valid": False, "error": f'缺少必要字段: {", ".join(missing_fields)}'} + return { + "valid": False, + "error": f"缺少必要字段: {', '.join(missing_fields)}", + } # 检查project_id if "project_id" not in creds_data: @@ -1188,7 +1392,9 @@ def validate_credential_content(content: str) -> Dict[str, Any]: return {"valid": False, "error": f"文件验证失败: {str(e)}"} -async def save_uploaded_credential(content: str, original_filename: str) -> Dict[str, Any]: +async def save_uploaded_credential( + content: str, original_filename: str +) -> Dict[str, Any]: """通过统一存储系统保存上传的凭证""" try: # 验证内容格式 @@ -1239,4 +1445,8 @@ async def batch_upload_credentials(files_data: List[Dict[str, str]]) -> Dict[str if result["success"]: success_count += 1 - return {"uploaded_count": success_count, "total_count": len(files_data), "results": results} + return { + "uploaded_count": success_count, + "total_count": len(files_data), + "results": results, + } diff --git a/src/converter/anthropic2gemini.py b/src/converter/anthropic2gemini.py index 65c40b5bb..ebb458712 100644 --- a/src/converter/anthropic2gemini.py +++ b/src/converter/anthropic2gemini.py @@ -3,6 +3,7 @@ 提供请求体、响应和流式转换的完整功能。 """ + from __future__ import annotations import json @@ -16,7 +17,7 @@ from src.converter.thoughtSignature_fix import ( encode_tool_id_with_signature, - decode_tool_id_and_signature + decode_tool_id_and_signature, ) DEFAULT_TEMPERATURE = 0.4 @@ -33,81 +34,85 @@ def has_valid_thoughtsignature(block: Dict[str, Any]) -> bool: """ 检查 thinking 块是否有有效签名 - + Args: block: content block 字典 - + Returns: bool: 是否有有效签名 """ if not isinstance(block, dict): return True - + block_type = block.get("type") if block_type not in ("thinking", "redacted_thinking"): return True # 非 thinking 块默认有效 - + thinking = block.get("thinking", "") thoughtsignature = block.get("thoughtSignature") - + # 空 thinking + 任意 thoughtsignature = 有效 (trailing signature case) if not thinking and thoughtsignature is not None: return True - + # 有内容 + 足够长度的 thoughtsignature = 有效 - if thoughtsignature and isinstance(thoughtsignature, str) and len(thoughtsignature) >= MIN_SIGNATURE_LENGTH: + if ( + thoughtsignature + and isinstance(thoughtsignature, str) + and len(thoughtsignature) >= MIN_SIGNATURE_LENGTH + ): return True - + return False def sanitize_thinking_block(block: Dict[str, Any]) -> Dict[str, Any]: """ 清理 thinking 块,只保留必要字段(移除 cache_control 等) - + Args: block: content block 字典 - + Returns: 清理后的 block 字典 """ if not isinstance(block, dict): return block - + block_type = block.get("type") if block_type not in ("thinking", "redacted_thinking"): return block - + # 重建块,移除额外字段 sanitized: Dict[str, Any] = { "type": block_type, - "thinking": block.get("thinking", "") + "thinking": block.get("thinking", ""), } - + thoughtsignature = block.get("thoughtSignature") if thoughtsignature: sanitized["thoughtSignature"] = thoughtsignature - + return sanitized def remove_trailing_unsigned_thinking(blocks: List[Dict[str, Any]]) -> None: """ 移除尾部的无签名 thinking 块 - + Args: blocks: content blocks 列表 (会被修改) """ if not blocks: return - + # 从后向前扫描 end_index = len(blocks) for i in range(len(blocks) - 1, -1, -1): block = blocks[i] if not isinstance(block, dict): break - + block_type = block.get("type") if block_type in ("thinking", "redacted_thinking"): if not has_valid_thoughtsignature(block): @@ -116,7 +121,7 @@ def remove_trailing_unsigned_thinking(blocks: List[Dict[str, Any]]) -> None: break # 遇到有效签名的 thinking 块,停止 else: break # 遇到非 thinking 块,停止 - + if end_index < len(blocks): removed = len(blocks) - end_index del blocks[end_index:] @@ -170,7 +175,9 @@ def filter_invalid_thinking_blocks(messages: List[Dict[str, Any]]) -> None: ) new_blocks.append({"type": "text", "text": thinking_text}) else: - log.debug("[Claude-Handler] Dropping empty thinking block with invalid thoughtSignature") + log.debug( + "[Claude-Handler] Dropping empty thinking block with invalid thoughtSignature" + ) msg["content"] = new_blocks filtered_count = original_len - len(new_blocks) @@ -235,25 +242,63 @@ def _remove_nulls_for_tool_input(value: Any) -> Any: return value + # ============================================================================ # 2. JSON Schema 清理 # ============================================================================ + def clean_json_schema(schema: Any) -> Any: """ 清理 JSON Schema,移除下游不支持的字段,并把验证要求追加到 description。 """ + shorthand_type_map = { + "string": "string", + "number": "number", + "integer": "integer", + "boolean": "boolean", + "array": "array", + "object": "object", + } + + if isinstance(schema, str): + shorthand = shorthand_type_map.get(schema.strip().lower()) + if shorthand: + return {"type": shorthand} + return schema + if not isinstance(schema, dict): return schema # 下游不支持的字段 unsupported_keys = { - "$schema", "$id", "$ref", "$defs", "definitions", "title", - "example", "examples", "readOnly", "writeOnly", "default", - "exclusiveMaximum", "exclusiveMinimum", "oneOf", "anyOf", "allOf", - "const", "additionalItems", "contains", "patternProperties", - "dependencies", "propertyNames", "if", "then", "else", - "contentEncoding", "contentMediaType", + "$schema", + "$id", + "$ref", + "$defs", + "definitions", + "title", + "example", + "examples", + "readOnly", + "writeOnly", + "default", + "exclusiveMaximum", + "exclusiveMinimum", + "oneOf", + "anyOf", + "allOf", + "const", + "additionalItems", + "contains", + "patternProperties", + "dependencies", + "propertyNames", + "if", + "then", + "else", + "contentEncoding", + "contentMediaType", } validation_fields = { @@ -273,13 +318,36 @@ def clean_json_schema(schema: Any) -> Any: cleaned: Dict[str, Any] = {} for key, value in schema.items(): - if key in unsupported_keys or key in fields_to_remove or key in validation_fields: + if ( + key in unsupported_keys + or key in fields_to_remove + or key in validation_fields + ): + continue + + if key == "properties" and isinstance(value, dict): + normalized_properties: Dict[str, Any] = {} + for prop_name, prop_schema in value.items(): + if isinstance(prop_schema, str): + shorthand = shorthand_type_map.get(prop_schema.strip().lower()) + if shorthand: + normalized_properties[prop_name] = {"type": shorthand} + else: + log.warning( + f"[ANTHROPIC2GEMINI] 非法属性schema简写,回退为string: {prop_name}={prop_schema}" + ) + normalized_properties[prop_name] = {"type": "string"} + else: + normalized_properties[prop_name] = clean_json_schema(prop_schema) + + cleaned[key] = normalized_properties continue if key == "type" and isinstance(value, list): # type: ["string", "null"] -> type: "string", nullable: true has_null = any( - isinstance(t, str) and t.strip() and t.strip().lower() == "null" for t in value + isinstance(t, str) and t.strip() and t.strip().lower() == "null" + for t in value ) non_null_types = [ t.strip() @@ -297,7 +365,10 @@ def clean_json_schema(schema: Any) -> Any: elif isinstance(value, dict): cleaned[key] = clean_json_schema(value) elif isinstance(value, list): - cleaned[key] = [clean_json_schema(item) if isinstance(item, dict) else item for item in value] + cleaned[key] = [ + clean_json_schema(item) if isinstance(item, dict) else item + for item in value + ] else: cleaned[key] = value @@ -315,7 +386,10 @@ def clean_json_schema(schema: Any) -> Any: # 4. Tools 转换 # ============================================================================ -def convert_tools(anthropic_tools: Optional[List[Dict[str, Any]]]) -> Optional[List[Dict[str, Any]]]: + +def convert_tools( + anthropic_tools: Optional[List[Dict[str, Any]]], +) -> Optional[List[Dict[str, Any]]]: """ 将 Anthropic tools[] 转换为下游 tools(functionDeclarations)结构。 """ @@ -348,6 +422,7 @@ def convert_tools(anthropic_tools: Optional[List[Dict[str, Any]]]) -> Optional[L # 5. Messages 转换 # ============================================================================ + def _extract_tool_result_output(content: Any) -> str: """从 tool_result.content 中提取输出字符串""" if isinstance(content, list): @@ -363,9 +438,7 @@ def _extract_tool_result_output(content: Any) -> str: def convert_messages_to_contents( - messages: List[Dict[str, Any]], - *, - include_thinking: bool = True + messages: List[Dict[str, Any]], *, include_thinking: bool = True ) -> List[Dict[str, Any]]: """ 将 Anthropic messages[] 转换为下游 contents[](role: user/model, parts: [])。 @@ -388,17 +461,22 @@ def convert_messages_to_contents( tool_name = item.get("name") if encoded_tool_id and tool_name: # 解码获取原始ID和签名 - original_id, thoughtsignature = decode_tool_id_and_signature(encoded_tool_id) + original_id, thoughtsignature = decode_tool_id_and_signature( + encoded_tool_id + ) # 存储映射:编码ID -> (name, thoughtsignature) - tool_use_info[str(encoded_tool_id)] = (tool_name, thoughtsignature) + tool_use_info[str(encoded_tool_id)] = ( + tool_name, + thoughtsignature, + ) for msg in messages: role = msg.get("role", "user") - + # system 消息已经由 merge_system_messages 处理,这里跳过 if role == "system": continue - + # 支持 'assistant' 和 'model' 角色(Google history usage) gemini_role = "model" if role in ("assistant", "model") else "user" raw_content = msg.get("content", "") @@ -422,17 +500,17 @@ def convert_messages_to_contents( thinking_text = item.get("thinking", "") if thinking_text is None: thinking_text = "" - + part: Dict[str, Any] = { "text": str(thinking_text), "thought": True, } - + # 如果有 thoughtsignature 则添加 thoughtsignature = item.get("thoughtSignature") if thoughtsignature: part["thoughtSignature"] = thoughtsignature - + parts.append(part) elif item_type == "redacted_thinking": if not include_thinking: @@ -441,17 +519,17 @@ def convert_messages_to_contents( thinking_text = item.get("thinking") if thinking_text is None: thinking_text = item.get("data", "") - + part_dict: Dict[str, Any] = { "text": str(thinking_text or ""), "thought": True, } - + # 如果有 thoughtsignature 则添加 thoughtsignature = item.get("thoughtSignature") if thoughtsignature: part_dict["thoughtSignature"] = thoughtsignature - + parts.append(part_dict) elif item_type == "text": text = item.get("text", "") @@ -470,7 +548,9 @@ def convert_messages_to_contents( ) elif item_type == "tool_use": encoded_id = item.get("id") or "" - original_id, thoughtsignature = decode_tool_id_and_signature(encoded_id) + original_id, thoughtsignature = decode_tool_id_and_signature( + encoded_id + ) fc_part: Dict[str, Any] = { "functionCall": { @@ -490,9 +570,11 @@ def convert_messages_to_contents( elif item_type == "tool_result": output = _extract_tool_result_output(item.get("content")) encoded_tool_use_id = item.get("tool_use_id") or "" - + # 解码获取原始ID(functionResponse不需要签名) - original_tool_use_id, _ = decode_tool_id_and_signature(encoded_tool_use_id) + original_tool_use_id, _ = decode_tool_id_and_signature( + encoded_tool_use_id + ) # 从 tool_result 获取 name,如果没有则从映射中查找 func_name = item.get("name") @@ -503,7 +585,7 @@ def convert_messages_to_contents( func_name = tool_info[0] # 获取 name if not func_name: func_name = "unknown_function" - + parts.append( { "functionResponse": { @@ -561,7 +643,9 @@ def reorganize_tool_messages(contents: List[Dict[str, Any]]) -> List[Dict[str, A new_contents.append({"role": "model", "parts": [part]}) if tool_id is not None and str(tool_id) in tool_results: - new_contents.append({"role": "user", "parts": [tool_results[str(tool_id)]]}) + new_contents.append( + {"role": "user", "parts": [tool_results[str(tool_id)]]} + ) i += 1 continue @@ -576,6 +660,7 @@ def reorganize_tool_messages(contents: List[Dict[str, Any]]) -> List[Dict[str, A # 7. Tool Choice 转换 # ============================================================================ + def convert_tool_choice_to_tool_config(tool_choice: Any) -> Optional[Dict[str, Any]]: """ 将 Anthropic tool_choice 转换为 Gemini toolConfig @@ -591,10 +676,10 @@ def convert_tool_choice_to_tool_config(tool_choice: Any) -> Optional[Dict[str, A """ if not tool_choice: return None - + if isinstance(tool_choice, dict): choice_type = tool_choice.get("type") - + if choice_type == "auto": return {"functionCallingConfig": {"mode": "AUTO"}} elif choice_type == "any": @@ -608,7 +693,7 @@ def convert_tool_choice_to_tool_config(tool_choice: Any) -> Optional[Dict[str, A "allowedFunctionNames": [tool_name], } } - + # 无效或不支持的 tool_choice,返回 None return None @@ -617,6 +702,7 @@ def convert_tool_choice_to_tool_config(tool_choice: Any) -> Optional[Dict[str, A # 8. Generation Config 构建 # ============================================================================ + def build_generation_config(payload: Dict[str, Any]) -> Dict[str, Any]: """ 根据 Anthropic Messages 请求构造下游 generationConfig。 @@ -657,40 +743,44 @@ def build_generation_config(payload: Dict[str, Any]) -> Dict[str, Any]: if thinking and isinstance(thinking, dict): thinking_type = thinking.get("type") budget_tokens = thinking.get("budget_tokens") - + # 如果启用了 extended thinking,设置 thinkingConfig if thinking_type == "enabled": is_plan_mode = True thinking_config: Dict[str, Any] = {} - + # 设置思考预算,默认使用较大的值以支持计划模式 if budget_tokens is not None: thinking_config["thinkingBudget"] = budget_tokens else: # 默认给一个较大的思考预算以支持完整的计划生成 thinking_config["thinkingBudget"] = 48000 - + # 始终包含思考内容,这样才能看到计划 thinking_config["includeThoughts"] = True - + config["thinkingConfig"] = thinking_config - log.info(f"[ANTHROPIC2GEMINI] Extended thinking enabled with budget: {thinking_config['thinkingBudget']}") + log.info( + f"[ANTHROPIC2GEMINI] Extended thinking enabled with budget: {thinking_config['thinkingBudget']}" + ) elif thinking_type == "disabled": # 明确禁用思考模式 - config["thinkingConfig"] = { - "includeThoughts": False - } + config["thinkingConfig"] = {"includeThoughts": False} log.info("[ANTHROPIC2GEMINI] Extended thinking explicitly disabled") stop_sequences = payload.get("stop_sequences") if isinstance(stop_sequences, list) and stop_sequences: - config["stopSequences"] = config["stopSequences"] + [str(s) for s in stop_sequences] + config["stopSequences"] = config["stopSequences"] + [ + str(s) for s in stop_sequences + ] elif is_plan_mode: # Plan mode 时清空默认 stop sequences,避免过早停止 # 默认的 stop sequences 可能会导致模型在生成计划时过早停止 config["stopSequences"] = [] - log.info("[ANTHROPIC2GEMINI] Plan mode: cleared default stop sequences to prevent premature stopping") - + log.info( + "[ANTHROPIC2GEMINI] Plan mode: cleared default stop sequences to prevent premature stopping" + ) + # 如果不是 plan mode 且没有自定义 stop_sequences,保持默认值 # (默认值已经在 config 初始化时设置) @@ -701,6 +791,7 @@ def build_generation_config(payload: Dict[str, Any]) -> Dict[str, Any]: # 8. 主要转换函数 # ============================================================================ + async def anthropic_to_gemini_request(payload: Dict[str, Any]) -> Dict[str, Any]: """ 将 Anthropic 格式请求体转换为 Gemini 格式请求体 @@ -726,7 +817,7 @@ async def anthropic_to_gemini_request(payload: Dict[str, Any]) -> Dict[str, Any] messages = payload.get("messages") or [] if not isinstance(messages, list): messages = [] - + # [CRITICAL FIX] 过滤并修复 Thinking 块签名 # 在转换前先过滤无效的 thinking 块 filter_invalid_thinking_blocks(messages) @@ -736,7 +827,7 @@ async def anthropic_to_gemini_request(payload: Dict[str, Any]) -> Dict[str, Any] # 转换消息内容(始终包含thinking块,由响应端处理) contents = convert_messages_to_contents(messages, include_thinking=True) - + # [CRITICAL FIX] 移除尾部无签名的 thinking 块 # 对真实请求应用额外的清理 for content in contents: @@ -745,12 +836,12 @@ async def anthropic_to_gemini_request(payload: Dict[str, Any]) -> Dict[str, Any] parts = content.get("parts", []) if isinstance(parts, list): remove_trailing_unsigned_thinking(parts) - + contents = reorganize_tool_messages(contents) # 转换工具 tools = convert_tools(payload.get("tools")) - + # 转换 tool_choice tool_config = convert_tool_choice_to_tool_config(payload.get("tool_choice")) @@ -759,14 +850,14 @@ async def anthropic_to_gemini_request(payload: Dict[str, Any]) -> Dict[str, Any] "contents": contents, "generationConfig": generation_config, } - + # 如果 merge_system_messages 已经添加了 systemInstruction,使用它 if "systemInstruction" in payload: gemini_request["systemInstruction"] = payload["systemInstruction"] - + if tools: gemini_request["tools"] = tools - + # 添加 toolConfig(如果有 tool_choice) if tool_config: gemini_request["toolConfig"] = tool_config @@ -775,9 +866,7 @@ async def anthropic_to_gemini_request(payload: Dict[str, Any]) -> Dict[str, Any] def gemini_to_anthropic_response( - gemini_response: Dict[str, Any], - model: str, - status_code: int = 200 + gemini_response: Dict[str, Any], model: str, status_code: int = 200 ) -> Dict[str, Any]: """ 将 Gemini 格式非流式响应转换为 Anthropic 格式非流式响应 @@ -826,14 +915,14 @@ def gemini_to_anthropic_response( thinking_text = part.get("text", "") if thinking_text is None: thinking_text = "" - + block: Dict[str, Any] = {"type": "thinking", "thinking": str(thinking_text)} - + # 如果有 thoughtsignature 则添加 thoughtsignature = part.get("thoughtSignature") if thoughtsignature: block["thoughtSignature"] = thoughtsignature - + content.append(block) continue @@ -848,7 +937,7 @@ def gemini_to_anthropic_response( fc = part.get("functionCall", {}) or {} original_id = fc.get("id") or f"toolu_{uuid.uuid4().hex}" thoughtsignature = part.get("thoughtSignature") - + # 对工具调用ID进行签名编码 encoded_id = encode_tool_id_with_signature(original_id, thoughtsignature) content.append( @@ -878,7 +967,7 @@ def gemini_to_anthropic_response( # 确定停止原因 finish_reason = candidate.get("finishReason") - + # 只有在正常停止(STOP)且有工具调用时才设为 tool_use # 避免在 SAFETY、MAX_TOKENS 等情况下仍然返回 tool_use 导致循环 if has_tool_use and finish_reason == "STOP": @@ -890,8 +979,16 @@ def gemini_to_anthropic_response( stop_reason = "end_turn" # 提取 token 使用情况 - input_tokens = usage_metadata.get("promptTokenCount", 0) if isinstance(usage_metadata, dict) else 0 - output_tokens = usage_metadata.get("candidatesTokenCount", 0) if isinstance(usage_metadata, dict) else 0 + input_tokens = ( + usage_metadata.get("promptTokenCount", 0) + if isinstance(usage_metadata, dict) + else 0 + ) + output_tokens = ( + usage_metadata.get("candidatesTokenCount", 0) + if isinstance(usage_metadata, dict) + else 0 + ) # 构建 Anthropic 响应 message_id = f"msg_{uuid.uuid4().hex}" @@ -912,9 +1009,7 @@ def gemini_to_anthropic_response( async def gemini_stream_to_anthropic_stream( - gemini_stream: AsyncIterator[bytes], - model: str, - status_code: int = 200 + gemini_stream: AsyncIterator[bytes], model: str, status_code: int = 200 ) -> AsyncIterator[bytes]: """ 将 Gemini 格式流式响应转换为 Anthropic SSE 格式流式响应 @@ -968,18 +1063,28 @@ def _close_block() -> Optional[bytes]: async for chunk in gemini_stream: # 检查是否是 Response 对象(错误情况) if isinstance(chunk, Response): - log.warning(f"[GEMINI_TO_ANTHROPIC] 收到 Response 对象,状态码: {chunk.status_code},直接转发错误") + log.warning( + f"[GEMINI_TO_ANTHROPIC] 收到 Response 对象,状态码: {chunk.status_code},直接转发错误" + ) # 直接转发错误响应内容,不做格式转换 - error_content = chunk.body if isinstance(chunk.body, bytes) else chunk.body.encode('utf-8') + error_content = ( + chunk.body + if isinstance(chunk.body, bytes) + else chunk.body.encode("utf-8") + ) yield error_content return # 记录接收到的原始chunk - log.debug(f"[GEMINI_TO_ANTHROPIC] Raw chunk: {chunk[:200] if chunk else b''}") + log.debug( + f"[GEMINI_TO_ANTHROPIC] Raw chunk: {chunk[:200] if chunk else b''}" + ) # 解析 Gemini 流式块 if not chunk or not chunk.startswith(b"data: "): - log.debug(f"[GEMINI_TO_ANTHROPIC] Skipping chunk (not SSE format or empty)") + log.debug( + f"[GEMINI_TO_ANTHROPIC] Skipping chunk (not SSE format or empty)" + ) continue raw = chunk[6:].strip() @@ -990,8 +1095,10 @@ def _close_block() -> Optional[bytes]: log.debug(f"[GEMINI_TO_ANTHROPIC] Parsing JSON: {raw[:200]}") try: - data = json.loads(raw.decode('utf-8', errors='ignore')) - log.debug(f"[GEMINI_TO_ANTHROPIC] Parsed data: {json.dumps(data, ensure_ascii=False)[:300]}") + data = json.loads(raw.decode("utf-8", errors="ignore")) + log.debug( + f"[GEMINI_TO_ANTHROPIC] Parsed data: {json.dumps(data, ensure_ascii=False)[:300]}" + ) except Exception as e: log.warning(f"[GEMINI_TO_ANTHROPIC] JSON parse error: {e}") continue @@ -1029,7 +1136,10 @@ def _close_block() -> Optional[bytes]: "content": [], "stop_reason": None, "stop_sequence": None, - "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens}, + "usage": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + }, }, }, ) @@ -1043,7 +1153,7 @@ def _close_block() -> Optional[bytes]: if part.get("thought") is True: thinking_text = part.get("text", "") thoughtsignature = part.get("thoughtSignature") - + # 检查是否需要关闭上一个块并开启新的 thinking 块 if current_block_type != "thinking": close_evt = _close_block() @@ -1065,20 +1175,23 @@ def _close_block() -> Optional[bytes]: "content_block": block, }, ) - elif thoughtsignature and thoughtsignature != current_thinking_signature: + elif ( + thoughtsignature + and thoughtsignature != current_thinking_signature + ): # 签名变化,需要开启新的 thinking 块 close_evt = _close_block() if close_evt: yield close_evt - + current_block_index += 1 current_block_type = "thinking" current_thinking_signature = thoughtsignature - + block_new: Dict[str, Any] = {"type": "thinking", "thinking": ""} if thoughtsignature: block_new["thoughtSignature"] = thoughtsignature - + yield _sse_event( "content_block_start", { @@ -1095,7 +1208,10 @@ def _close_block() -> Optional[bytes]: { "type": "content_block_delta", "index": current_block_index, - "delta": {"type": "thinking_delta", "thinking": thinking_text}, + "delta": { + "type": "thinking_delta", + "thinking": thinking_text, + }, }, ) continue @@ -1144,7 +1260,9 @@ def _close_block() -> Optional[bytes]: fc = part.get("functionCall", {}) or {} original_id = fc.get("id") or f"toolu_{uuid.uuid4().hex}" thoughtsignature = part.get("thoughtSignature") - tool_id = encode_tool_id_with_signature(original_id, thoughtsignature) + tool_id = encode_tool_id_with_signature( + original_id, thoughtsignature + ) tool_name = fc.get("name") or "" tool_args = _remove_nulls_for_tool_input(fc.get("args", {}) or {}) @@ -1171,13 +1289,18 @@ def _close_block() -> Optional[bytes]: }, ) - input_json = json.dumps(tool_args, ensure_ascii=False, separators=(",", ":")) + input_json = json.dumps( + tool_args, ensure_ascii=False, separators=(",", ":") + ) yield _sse_event( "content_block_delta", { "type": "content_block_delta", "index": current_block_index, - "delta": {"type": "input_json_delta", "partial_json": input_json}, + "delta": { + "type": "input_json_delta", + "partial_json": input_json, + }, }, ) @@ -1186,10 +1309,12 @@ def _close_block() -> Optional[bytes]: {"type": "content_block_stop", "index": current_block_index}, ) # 工具调用块已完全关闭,current_block_type 保持为 None - + if _anthropic_debug_enabled(): - log.info(f"[ANTHROPIC][tool_use] 工具调用块已关闭: index={current_block_index}") - + log.info( + f"[ANTHROPIC][tool_use] 工具调用块已关闭: index={current_block_index}" + ) + continue # 检查是否结束 @@ -1250,11 +1375,14 @@ def _close_block() -> Optional[bytes]: "content": [], "stop_reason": None, "stop_sequence": None, - "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens}, + "usage": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + }, }, }, ) yield _sse_event( "error", {"type": "error", "error": {"type": "api_error", "message": str(e)}}, - ) \ No newline at end of file + ) diff --git a/src/converter/gemini_fix.py b/src/converter/gemini_fix.py index 513bc164f..4deecefb8 100644 --- a/src/converter/gemini_fix.py +++ b/src/converter/gemini_fix.py @@ -3,6 +3,7 @@ 提供对 Gemini API 请求体和响应的标准化处理 ──────────────────────────────────────────────────────────────── """ + from math import e from typing import Any, Dict, Optional @@ -11,36 +12,42 @@ # ==================== Gemini API 配置 ==================== + def prepare_image_generation_request( - request_body: Dict[str, Any], - model: str + request_body: Dict[str, Any], model: str ) -> Dict[str, Any]: """ 图像生成模型请求体后处理 - + Args: request_body: 原始请求体 model: 模型名称 - + Returns: 处理后的请求体 """ request_body = request_body.copy() model_lower = model.lower() - + # 解析分辨率 - image_size = "4K" if "-4k" in model_lower else "2K" if "-2k" in model_lower else None - + image_size = ( + "4K" if "-4k" in model_lower else "2K" if "-2k" in model_lower else None + ) + # 解析比例 aspect_ratio = None for suffix, ratio in [ - ("-21x9", "21:9"), ("-16x9", "16:9"), ("-9x16", "9:16"), - ("-4x3", "4:3"), ("-3x4", "3:4"), ("-1x1", "1:1") + ("-21x9", "21:9"), + ("-16x9", "16:9"), + ("-9x16", "9:16"), + ("-4x3", "4:3"), + ("-3x4", "3:4"), + ("-1x1", "1:1"), ]: if suffix in model_lower: aspect_ratio = ratio break - + # 构建 imageConfig image_config = {} if aspect_ratio: @@ -51,25 +58,32 @@ def prepare_image_generation_request( request_body["model"] = "gemini-3.1-flash-image" # 统一使用基础模型名 request_body["generationConfig"] = { "candidateCount": 1, - "imageConfig": image_config + "imageConfig": image_config, } # 移除不需要的字段 for key in ("systemInstruction", "tools", "toolConfig"): request_body.pop(key, None) - + return request_body # ==================== 模型特性辅助函数 ==================== + def get_base_model_name(model_name: str) -> str: """移除模型名称中的后缀,返回基础模型名""" # 按照从长到短的顺序排列,避免短后缀先于长后缀被匹配 suffixes = [ - "-maxthinking", "-nothinking", # 兼容旧模式 - "-minimal", "-medium", "-search", "-think", # 中等长度后缀 - "-high", "-max", "-low" # 短后缀 + "-maxthinking", + "-nothinking", # 兼容旧模式 + "-minimal", + "-medium", + "-search", + "-think", # 中等长度后缀 + "-high", + "-max", + "-low", # 短后缀 ] result = model_name changed = True @@ -78,7 +92,7 @@ def get_base_model_name(model_name: str) -> str: changed = False for suffix in suffixes: if result.endswith(suffix): - result = result[:-len(suffix)] + result = result[: -len(suffix)] changed = True # 不使用 break,继续检查是否还有其他后缀 return result @@ -167,14 +181,14 @@ def is_search_model(model_name: str) -> bool: # ==================== 统一的 Gemini 请求后处理 ==================== + def is_thinking_model(model_name: str) -> bool: """检查是否为思考模型 (包含 -thinking 或 pro)""" return "think" in model_name or "pro" in model_name.lower() async def normalize_gemini_request( - request: Dict[str, Any], - mode: str = "geminicli" + request: Dict[str, Any], mode: str = "geminicli" ) -> Dict[str, Any]: """ 规范化 Gemini 请求 @@ -196,12 +210,18 @@ async def normalize_gemini_request( result = request.copy() model = result.get("model", "") - generation_config = (result.get("generationConfig") or {}).copy() # 创建副本避免修改原对象 + generation_config = ( + result.get("generationConfig") or {} + ).copy() # 创建副本避免修改原对象 tools = result.get("tools") - system_instruction = result.get("systemInstruction") or result.get("system_instructions") - + system_instruction = result.get("systemInstruction") or result.get( + "system_instructions" + ) + # 记录原始请求 - log.debug(f"[GEMINI_FIX] 原始请求 - 模型: {model}, mode: {mode}, generationConfig: {generation_config}") + log.debug( + f"[GEMINI_FIX] 原始请求 - 模型: {model}, mode: {mode}, generationConfig: {generation_config}" + ) # 获取配置值 return_thoughts = await get_return_thoughts_to_frontend() @@ -214,11 +234,19 @@ async def normalize_gemini_request( # 其次使用传入的思考预算(如果未从模型名称获取) if thinking_budget is None and thinking_level is None: - thinking_budget = generation_config.get("thinkingConfig", {}).get("thinkingBudget") - thinking_level = generation_config.get("thinkingConfig", {}).get("thinkingLevel") + thinking_budget = generation_config.get("thinkingConfig", {}).get( + "thinkingBudget" + ) + thinking_level = generation_config.get("thinkingConfig", {}).get( + "thinkingLevel" + ) # 假如 is_thinking_model 为真或者思考预算/等级不为空,设置 thinkingConfig - if is_thinking_model(model) or thinking_budget is not None or thinking_level is not None: + if ( + is_thinking_model(model) + or thinking_budget is not None + or thinking_level is not None + ): # 确保 thinkingConfig 存在 if "thinkingConfig" not in generation_config: generation_config["thinkingConfig"] = {} @@ -258,15 +286,18 @@ async def normalize_gemini_request( if is_search_model(model): result_tools = result.get("tools") or [] result["tools"] = result_tools - if not any(tool.get("googleSearch") for tool in result_tools if isinstance(tool, dict)): + if not any( + tool.get("googleSearch") + for tool in result_tools + if isinstance(tool, dict) + ): result_tools.append({"googleSearch": {}}) # 3. 模型名称处理 result["model"] = get_base_model_name(model) elif mode == "antigravity": - - ''' + """ # 1. 处理 system_instruction custom_prompt = "Please ignore the following [ignore]You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**[/ignore]" @@ -280,7 +311,7 @@ async def normalize_gemini_request( result["systemInstruction"] = { "parts": [{"text": custom_prompt}] + existing_parts } - ''' + """ # 2. 判断图片模型 if "image" in model.lower(): @@ -288,57 +319,74 @@ async def normalize_gemini_request( return prepare_image_generation_request(result, model) else: # 3. 思考模型处理 - if is_thinking_model(model) or ("thinkingBudget" in generation_config.get("thinkingConfig", {}) and generation_config["thinkingConfig"]["thinkingBudget"] != 0): + if is_thinking_model(model) or ( + "thinkingBudget" in generation_config.get("thinkingConfig", {}) + and generation_config["thinkingConfig"]["thinkingBudget"] != 0 + ): # 直接设置 thinkingConfig if "thinkingConfig" not in generation_config: generation_config["thinkingConfig"] = {} - + thinking_config = generation_config["thinkingConfig"] # 优先使用传入的思考预算,否则使用默认值 if "thinkingBudget" not in thinking_config: thinking_config["thinkingBudget"] = 1024 thinking_config.pop("thinkingLevel", None) # 避免与 thinkingBudget 冲突 thinking_config["includeThoughts"] = return_thoughts - + # 检查最后一个 assistant 消息是否以 thinking 块开始 contents = result.get("contents", []) if "claude" in model.lower(): # 检测是否有工具调用(MCP场景) has_tool_calls = any( - isinstance(content, dict) and - any( - isinstance(part, dict) and ("functionCall" in part or "function_call" in part) + isinstance(content, dict) + and any( + isinstance(part, dict) + and ("functionCall" in part or "function_call" in part) for part in content.get("parts", []) ) for content in contents ) - + if has_tool_calls: # MCP 场景:检测到工具调用,移除 thinkingConfig - log.warning(f"[ANTIGRAVITY] 检测到工具调用(MCP场景),移除 thinkingConfig 避免失效") + log.warning( + f"[ANTIGRAVITY] 检测到工具调用(MCP场景),移除 thinkingConfig 避免失效" + ) generation_config.pop("thinkingConfig", None) else: # 非 MCP 场景:填充思考块 # log.warning(f"[ANTIGRAVITY] 最后一个 assistant 消息不以 thinking 块开始,自动填充思考块") - + # 找到最后一个 model 角色的 content for i in range(len(contents) - 1, -1, -1): content = contents[i] - if isinstance(content, dict) and content.get("role") == "model": + if ( + isinstance(content, dict) + and content.get("role") == "model" + ): # 在 parts 开头插入思考块(使用官方跳过验证的虚拟签名) parts = content.get("parts", []) thinking_part = { "text": "...", # "thought": True, # 标记为思考块 - "thoughtSignature": "skip_thought_signature_validator" # 官方文档推荐的虚拟签名 + "thoughtSignature": "skip_thought_signature_validator", # 官方文档推荐的虚拟签名 } # 如果第一个 part 不是 thinking,则插入 - if not parts or not (isinstance(parts[0], dict) and ("thought" in parts[0] or "thoughtSignature" in parts[0])): + if not parts or not ( + isinstance(parts[0], dict) + and ( + "thought" in parts[0] + or "thoughtSignature" in parts[0] + ) + ): content["parts"] = [thinking_part] + parts - log.debug(f"[ANTIGRAVITY] 已在最后一个 assistant 消息开头插入思考块(含跳过验证签名)") + log.debug( + f"[ANTIGRAVITY] 已在最后一个 assistant 消息开头插入思考块(含跳过验证签名)" + ) break - + # 移除 -thinking 后缀 model = model.replace("-thinking", "") @@ -346,7 +394,7 @@ async def normalize_gemini_request( # 使用关键词匹配而不是精确匹配,更灵活地处理各种变体 original_model = model if "opus" in model.lower(): - model = "claude-opus-4-6-thinking" + model = "claude-opus-4-6-thinking" elif "sonnet" in model.lower(): if "4-5" in model: model = "claude-sonnet-4-5-thinking" @@ -357,21 +405,30 @@ async def normalize_gemini_request( elif "claude" in model.lower(): # Claude 模型兜底:如果包含 claude 但不是 opus/sonnet/haiku model = "claude-sonnet-4-6" - + result["model"] = model if original_model != model: log.debug(f"[ANTIGRAVITY] 映射模型: {original_model} -> {model}") # 5. 模型特殊处理:循环移除末尾的 model 消息,保证以用户消息结尾 # 因为该模型不支持预填充 - if "claude-opus-4-6-thinking" in model.lower() or "claude-sonnet-4-6" in model.lower(): + if ( + "claude-opus-4-6-thinking" in model.lower() + or "claude-sonnet-4-6" in model.lower() + ): contents = result.get("contents", []) removed_count = 0 - while contents and isinstance(contents[-1], dict) and contents[-1].get("role") == "model": + while ( + contents + and isinstance(contents[-1], dict) + and contents[-1].get("role") == "model" + ): contents.pop() removed_count += 1 if removed_count > 0: - log.warning(f"[ANTIGRAVITY] {model} 不支持预填充,移除了 {removed_count} 条末尾 model 消息") + log.warning( + f"[ANTIGRAVITY] {model} 不支持预填充,移除了 {removed_count} 条末尾 model 消息" + ) result["contents"] = contents # 6. 移除 antigravity 模式不支持的字段 @@ -399,7 +456,26 @@ async def normalize_gemini_request( for part in content["parts"]: if not isinstance(part, dict): continue - + + part = part.copy() + + # functionCall 场景需要 thoughtSignature,缺失时补齐占位符 + # 兼容 thought_signature(snake_case)输入并统一为 thoughtSignature + if "functionCall" in part: + if ( + "thoughtSignature" not in part + and "thought_signature" in part + ): + part["thoughtSignature"] = part["thought_signature"] + part.pop("thought_signature", None) + if "thoughtSignature" not in part: + part["thoughtSignature"] = ( + "skip_thought_signature_validator" + ) + log.debug( + "[GEMINI_FIX] functionCall 缺少 thoughtSignature,已补齐占位符" + ) + # 检查 part 是否有有效的非空值 # 过滤掉空字典或所有值都为空的 part has_valid_value = any( @@ -407,42 +483,46 @@ async def normalize_gemini_request( for key, value in part.items() if key != "thought" # thought 字段可以为空 ) - - if has_valid_value: - part = part.copy() + if has_valid_value: # 修复 text 字段:确保是字符串而不是列表 if "text" in part: text_value = part["text"] if isinstance(text_value, list): # 如果是列表,合并为字符串 - log.warning(f"[GEMINI_FIX] text 字段是列表,自动合并: {text_value}") + log.warning( + f"[GEMINI_FIX] text 字段是列表,自动合并: {text_value}" + ) part["text"] = " ".join(str(t) for t in text_value if t) elif isinstance(text_value, str): # 清理尾随空格 part["text"] = text_value.rstrip() else: # 其他类型转为字符串 - log.warning(f"[GEMINI_FIX] text 字段类型异常 ({type(text_value)}), 转为字符串: {text_value}") + log.warning( + f"[GEMINI_FIX] text 字段类型异常 ({type(text_value)}), 转为字符串: {text_value}" + ) part["text"] = str(text_value) valid_parts.append(part) else: log.warning(f"[GEMINI_FIX] 移除空的或无效的 part: {part}") - + # 只添加有有效 parts 的 content if valid_parts: cleaned_content = content.copy() cleaned_content["parts"] = valid_parts cleaned_contents.append(cleaned_content) else: - log.warning(f"[GEMINI_FIX] 跳过没有有效 parts 的 content: {content.get('role')}") + log.warning( + f"[GEMINI_FIX] 跳过没有有效 parts 的 content: {content.get('role')}" + ) else: cleaned_contents.append(content) - + result["contents"] = cleaned_contents if generation_config: result["generationConfig"] = generation_config - return result \ No newline at end of file + return result diff --git a/src/converter/openai2gemini.py b/src/converter/openai2gemini.py index 739b4bfee..f93b13766 100644 --- a/src/converter/openai2gemini.py +++ b/src/converter/openai2gemini.py @@ -18,6 +18,7 @@ from log import log + def _convert_usage_metadata(usage_metadata: Dict[str, Any]) -> Dict[str, int]: """ 将Gemini的usageMetadata转换为OpenAI格式的usage字段 @@ -38,7 +39,9 @@ def _convert_usage_metadata(usage_metadata: Dict[str, Any]) -> Dict[str, int]: } -def _build_message_with_reasoning(role: str, content: str, reasoning_content: str) -> dict: +def _build_message_with_reasoning( + role: str, content: str, reasoning_content: str +) -> dict: """构建包含可选推理内容的消息对象""" message = {"role": role, "content": content} @@ -113,7 +116,9 @@ def _normalize_function_name(name: str) -> str: parts.append(char) normalized = "".join(parts) except ImportError: - log.warning("pypinyin not installed, cannot convert Chinese characters to pinyin") + log.warning( + "pypinyin not installed, cannot convert Chinese characters to pinyin" + ) normalized = name else: normalized = name @@ -141,30 +146,34 @@ def _normalize_function_name(name: str) -> str: def _resolve_ref(ref: str, root_schema: Dict[str, Any]) -> Optional[Dict[str, Any]]: """ 解析 $ref 引用 - + Args: ref: 引用路径,如 "#/definitions/MyType" root_schema: 根 schema 对象 - + Returns: 解析后的 schema,如果失败返回 None """ - if not ref.startswith('#/'): + if not ref.startswith("#/"): return None - - path = ref[2:].split('/') + + path = ref[2:].split("/") current = root_schema - + for segment in path: if isinstance(current, dict) and segment in current: current = current[segment] else: return None - + return current if isinstance(current, dict) else None -def _clean_schema_for_claude(schema: Any, root_schema: Optional[Dict[str, Any]] = None, visited: Optional[set] = None) -> Any: +def _clean_schema_for_claude( + schema: Any, + root_schema: Optional[Dict[str, Any]] = None, + visited: Optional[set] = None, +) -> Any: """ 清理 JSON Schema,转换为 Claude API 支持的格式(符合 JSON Schema draft 2020-12) @@ -208,6 +217,7 @@ def _clean_schema_for_claude(schema: Any, root_schema: Optional[Dict[str, Any]] resolved = _resolve_ref(schema["$ref"], root_schema) if resolved: import copy + result = copy.deepcopy(resolved) for key, value in schema.items(): if key != "$ref": @@ -261,31 +271,51 @@ def _clean_schema_for_claude(schema: Any, root_schema: Optional[Dict[str, Any]] is_homogeneous = all(item.get("type") == first_type for item in tuple_items) if is_homogeneous and first_type: - result["items"] = _clean_schema_for_claude(tuple_items[0], root_schema, visited) + result["items"] = _clean_schema_for_claude( + tuple_items[0], root_schema, visited + ) else: # 异质元组,使用 anyOf 表示 result["items"] = { - "anyOf": [_clean_schema_for_claude(item, root_schema, visited) for item in tuple_items] + "anyOf": [ + _clean_schema_for_claude(item, root_schema, visited) + for item in tuple_items + ] } else: - result["items"] = _clean_schema_for_claude(result["items"], root_schema, visited) + result["items"] = _clean_schema_for_claude( + result["items"], root_schema, visited + ) # 5. 处理 anyOf(保持 anyOf,递归清理) if "anyOf" in result: - result["anyOf"] = [_clean_schema_for_claude(item, root_schema, visited) for item in result["anyOf"]] + result["anyOf"] = [ + _clean_schema_for_claude(item, root_schema, visited) + for item in result["anyOf"] + ] # 6. 清理 Claude 不支持的字段(根据 JSON Schema 2020-12) # Claude API 对某些字段比较严格,移除可能导致问题的字段 unsupported_keys = { - "title", "$schema", "strict", + "title", + "$schema", + "strict", "additionalItems", # 废弃字段,使用 items 替代 - "exclusiveMaximum", "exclusiveMinimum", # 在 2020-12 中这些应该是数值而非布尔值 - "$defs", "definitions", # 移除 definitions 相关字段避免冲突 - "example", "examples", "readOnly", "writeOnly", + "exclusiveMaximum", + "exclusiveMinimum", # 在 2020-12 中这些应该是数值而非布尔值 + "$defs", + "definitions", # 移除 definitions 相关字段避免冲突 + "example", + "examples", + "readOnly", + "writeOnly", "const", # const 可能导致问题 - "contentEncoding", "contentMediaType", + "contentEncoding", + "contentMediaType", "oneOf", # oneOf 可能导致问题,用 anyOf 替代 - "patternProperties", "dependencies", "propertyNames", # Google API 不支持 + "patternProperties", + "dependencies", + "propertyNames", # Google API 不支持 } for key in list(result.keys()): @@ -293,14 +323,20 @@ def _clean_schema_for_claude(schema: Any, root_schema: Optional[Dict[str, Any]] del result[key] # 递归处理 additionalProperties(如果存在) - if "additionalProperties" in result and isinstance(result["additionalProperties"], dict): - result["additionalProperties"] = _clean_schema_for_claude(result["additionalProperties"], root_schema, visited) + if "additionalProperties" in result and isinstance( + result["additionalProperties"], dict + ): + result["additionalProperties"] = _clean_schema_for_claude( + result["additionalProperties"], root_schema, visited + ) # 7. 递归处理 properties if "properties" in result: cleaned_props = {} for prop_name, prop_schema in result["properties"].items(): - cleaned_props[prop_name] = _clean_schema_for_claude(prop_schema, root_schema, visited) + cleaned_props[prop_name] = _clean_schema_for_claude( + prop_schema, root_schema, visited + ) result["properties"] = cleaned_props # 8. 确保有 type 字段(如果有 properties 但没有 type) @@ -314,7 +350,11 @@ def _clean_schema_for_claude(schema: Any, root_schema: Optional[Dict[str, Any]] return result -def _clean_schema_for_gemini(schema: Any, root_schema: Optional[Dict[str, Any]] = None, visited: Optional[set] = None) -> Any: +def _clean_schema_for_gemini( + schema: Any, + root_schema: Optional[Dict[str, Any]] = None, + visited: Optional[set] = None, +) -> Any: """ 清理 JSON Schema,转换为 Gemini 支持的格式 @@ -337,25 +377,37 @@ def _clean_schema_for_gemini(schema: Any, root_schema: Optional[Dict[str, Any]] Returns: 清理后的 schema """ - # 非字典类型直接返回 + # 非字典类型:兼容简写类型(例如 properties: {"value": "object"}) if not isinstance(schema, dict): + if isinstance(schema, str): + shorthand_type_map = { + "string": "STRING", + "number": "NUMBER", + "integer": "INTEGER", + "boolean": "BOOLEAN", + "array": "ARRAY", + "object": "OBJECT", + } + mapped_type = shorthand_type_map.get(schema.lower()) + if mapped_type: + return {"type": mapped_type} return schema - + # 初始化 if root_schema is None: root_schema = schema if visited is None: visited = set() - + # 防止循环引用 schema_id = id(schema) if schema_id in visited: return schema visited.add(schema_id) - + # 创建副本避免修改原对象 result = {} - + # 1. 处理 $ref if "$ref" in schema: resolved = _resolve_ref(schema["$ref"], root_schema) @@ -374,30 +426,30 @@ def _clean_schema_for_gemini(schema: Any, root_schema: Optional[Dict[str, Any]] merged[key] = value schema = merged result = {} - + # 2. 处理 allOf(合并所有 schema) if "allOf" in schema: all_of_schemas = schema["allOf"] for item in all_of_schemas: cleaned_item = _clean_schema_for_gemini(item, root_schema, visited) - + # 合并 properties if "properties" in cleaned_item: if "properties" not in result: result["properties"] = {} result["properties"].update(cleaned_item["properties"]) - + # 合并 required if "required" in cleaned_item: if "required" not in result: result["required"] = [] result["required"].extend(cleaned_item["required"]) - + # 合并其他字段(简单覆盖) for key, value in cleaned_item.items(): if key not in ["properties", "required"]: result[key] = value - + # 复制其他字段 for key, value in schema.items(): if key not in ["allOf", "properties", "required"]: @@ -407,7 +459,7 @@ def _clean_schema_for_gemini(schema: Any, root_schema: Optional[Dict[str, Any]] else: # 复制所有字段 result = dict(schema) - + # 3. 类型映射(转换为大写) # 注意:Gemini API 的 type 字段必须是字符串,不能是数组 if "type" in result: @@ -434,7 +486,7 @@ def _clean_schema_for_gemini(schema: Any, root_schema: Optional[Dict[str, Any]] else: # 未知类型,删除该字段 del result["type"] - + # 4. 处理 ARRAY 的 items if result.get("type") == "ARRAY": if "items" not in result: @@ -443,40 +495,47 @@ def _clean_schema_for_gemini(schema: Any, root_schema: Optional[Dict[str, Any]] elif isinstance(result["items"], list): # Tuple 定义(items 是数组) tuple_items = result["items"] - + # 提取类型信息用于 description tuple_types = [item.get("type", "any") for item in tuple_items] tuple_desc = f"(Tuple: [{', '.join(tuple_types)}])" - + original_desc = result.get("description", "") result["description"] = f"{original_desc} {tuple_desc}".strip() - + # 检查是否所有元素类型相同 first_type = tuple_items[0].get("type") if tuple_items else None is_homogeneous = all(item.get("type") == first_type for item in tuple_items) - + if is_homogeneous and first_type: # 同质元组,转换为 List - result["items"] = _clean_schema_for_gemini(tuple_items[0], root_schema, visited) + result["items"] = _clean_schema_for_gemini( + tuple_items[0], root_schema, visited + ) else: # 异质元组,Gemini 不支持,设为 {} result["items"] = {} else: # 递归处理 items - result["items"] = _clean_schema_for_gemini(result["items"], root_schema, visited) - + result["items"] = _clean_schema_for_gemini( + result["items"], root_schema, visited + ) + # 5. 处理 anyOf(尝试转换为 enum) if "anyOf" in result: any_of_schemas = result["anyOf"] - + # 递归处理每个 schema - cleaned_any_of = [_clean_schema_for_gemini(item, root_schema, visited) for item in any_of_schemas] - + cleaned_any_of = [ + _clean_schema_for_gemini(item, root_schema, visited) + for item in any_of_schemas + ] + # 尝试提取 enum if all("const" in item for item in cleaned_any_of): enum_values = [ - str(item["const"]) - for item in cleaned_any_of + str(item["const"]) + for item in cleaned_any_of if item.get("const") not in ["", None] ] if enum_values: @@ -484,85 +543,136 @@ def _clean_schema_for_gemini(schema: Any, root_schema: Optional[Dict[str, Any]] result["enum"] = enum_values elif "type" not in result: # 如果不是 enum,尝试取第一个有效的类型定义 - first_valid = next((item for item in cleaned_any_of if item.get("type") or item.get("enum")), None) + first_valid = next( + ( + item + for item in cleaned_any_of + if item.get("type") or item.get("enum") + ), + None, + ) if first_valid: result.update(first_valid) - + # 删除 anyOf del result["anyOf"] - + # 6. 将 default 值移到 description if "default" in result: default_value = result["default"] original_desc = result.get("description", "") - result["description"] = f"{original_desc} (Default: {json.dumps(default_value)})".strip() + result["description"] = ( + f"{original_desc} (Default: {json.dumps(default_value)})".strip() + ) del result["default"] - + # 7. 清理不支持的字段 unsupported_keys = { - "title", "$schema", "$ref", "strict", "exclusiveMaximum", - "exclusiveMinimum", "additionalProperties", "oneOf", "allOf", - "$defs", "definitions", "example", "examples", "readOnly", - "writeOnly", "const", "additionalItems", "contains", - "patternProperties", "dependencies", "propertyNames", - "if", "then", "else", "contentEncoding", "contentMediaType" + "title", + "$schema", + "$ref", + "strict", + "exclusiveMaximum", + "exclusiveMinimum", + "additionalProperties", + "oneOf", + "allOf", + "$defs", + "definitions", + "example", + "examples", + "readOnly", + "writeOnly", + "const", + "additionalItems", + "contains", + "patternProperties", + "dependencies", + "propertyNames", + "if", + "then", + "else", + "contentEncoding", + "contentMediaType", } - + for key in list(result.keys()): if key in unsupported_keys: del result[key] - + # 8. 递归处理 properties if "properties" in result: cleaned_props = {} for prop_name, prop_schema in result["properties"].items(): - cleaned_props[prop_name] = _clean_schema_for_gemini(prop_schema, root_schema, visited) + # 兼容简写属性定义(例如 "value": "object") + if isinstance(prop_schema, str): + shorthand_type_map = { + "string": "STRING", + "number": "NUMBER", + "integer": "INTEGER", + "boolean": "BOOLEAN", + "array": "ARRAY", + "object": "OBJECT", + } + mapped_type = shorthand_type_map.get(prop_schema.lower()) + if mapped_type: + prop_schema = {"type": mapped_type} + else: + log.warning( + f"[OPENAI2GEMINI] 非法属性schema简写,回退为STRING: {prop_name}={prop_schema}" + ) + prop_schema = {"type": "STRING"} + + cleaned_props[prop_name] = _clean_schema_for_gemini( + prop_schema, + root_schema, + visited.copy() if visited is not None else None, + ) result["properties"] = cleaned_props - + # 9. 确保有 type 字段(如果有 properties 但没有 type) if "properties" in result and "type" not in result: result["type"] = "OBJECT" - + # 10. 去重 required 数组 if "required" in result and isinstance(result["required"], list): result["required"] = list(dict.fromkeys(result["required"])) # 保持顺序去重 - + return result def fix_tool_call_args_types( - args: Dict[str, Any], - parameters_schema: Dict[str, Any] + args: Dict[str, Any], parameters_schema: Dict[str, Any] ) -> Dict[str, Any]: """ 根据工具的参数 schema 修正函数调用参数的类型 - + 例如:将字符串 "5" 转换为数字 5,根据 schema 中的 type 定义 - + Args: args: 函数调用的参数字典 parameters_schema: 工具定义中的 parameters schema - + Returns: 类型修正后的参数字典 """ if not args or not parameters_schema: return args - + properties = parameters_schema.get("properties", {}) if not properties: return args - + fixed_args = {} for key, value in args.items(): if key not in properties: # 参数不在 schema 中,保持原样 fixed_args[key] = value continue - + param_schema = properties[key] param_type = param_schema.get("type") - + # 根据 schema 中的类型修正参数值 if param_type == "number" or param_type == "integer": # 如果值是字符串,尝试转换为数字 @@ -573,12 +683,18 @@ def fix_tool_call_args_types( else: # 尝试转换为 float,如果是整数则保持为 int num_value = float(value) - fixed_args[key] = int(num_value) if num_value.is_integer() else num_value - log.debug(f"[OPENAI2GEMINI] 修正参数类型: {key} '{value}' -> {fixed_args[key]} ({param_type})") + fixed_args[key] = ( + int(num_value) if num_value.is_integer() else num_value + ) + log.debug( + f"[OPENAI2GEMINI] 修正参数类型: {key} '{value}' -> {fixed_args[key]} ({param_type})" + ) except (ValueError, AttributeError): # 转换失败,保持原样 fixed_args[key] = value - log.warning(f"[OPENAI2GEMINI] 无法将参数 {key} 的值 '{value}' 转换为 {param_type}") + log.warning( + f"[OPENAI2GEMINI] 无法将参数 {key} 的值 '{value}' 转换为 {param_type}" + ) else: fixed_args[key] = value elif param_type == "boolean": @@ -591,24 +707,30 @@ def fix_tool_call_args_types( else: fixed_args[key] = value if fixed_args[key] != value: - log.debug(f"[OPENAI2GEMINI] 修正参数类型: {key} '{value}' -> {fixed_args[key]} (boolean)") + log.debug( + f"[OPENAI2GEMINI] 修正参数类型: {key} '{value}' -> {fixed_args[key]} (boolean)" + ) else: fixed_args[key] = value elif param_type == "string": # 如果值不是字符串,转换为字符串 if not isinstance(value, str): fixed_args[key] = str(value) - log.debug(f"[OPENAI2GEMINI] 修正参数类型: {key} {value} -> '{fixed_args[key]}' (string)") + log.debug( + f"[OPENAI2GEMINI] 修正参数类型: {key} {value} -> '{fixed_args[key]}' (string)" + ) else: fixed_args[key] = value else: # 其他类型(array, object 等)保持原样 fixed_args[key] = value - + return fixed_args -def convert_openai_tools_to_gemini(openai_tools: List, model: str = "") -> List[Dict[str, Any]]: +def convert_openai_tools_to_gemini( + openai_tools: List, model: str = "" +) -> List[Dict[str, Any]]: """ 将 OpenAI tools 格式转换为 Gemini functionDeclarations 格式 @@ -647,7 +769,9 @@ def convert_openai_tools_to_gemini(openai_tools: List, model: str = "") -> List[ # 如果名称被修改了,记录日志 if normalized_name != original_name: - log.debug(f"Function name normalized: '{original_name}' -> '{normalized_name}'") + log.debug( + f"Function name normalized: '{original_name}' -> '{normalized_name}'" + ) # 构建 Gemini function declaration declaration = { @@ -659,7 +783,9 @@ def convert_openai_tools_to_gemini(openai_tools: List, model: str = "") -> List[ if "parameters" in function: if is_claude_model: cleaned_params = _clean_schema_for_claude(function["parameters"]) - log.debug(f"[OPENAI2GEMINI] Using Claude schema cleaning for tool: {normalized_name}") + log.debug( + f"[OPENAI2GEMINI] Using Claude schema cleaning for tool: {normalized_name}" + ) else: cleaned_params = _clean_schema_for_gemini(function["parameters"]) @@ -675,7 +801,9 @@ def convert_openai_tools_to_gemini(openai_tools: List, model: str = "") -> List[ return [{"functionDeclarations": function_declarations}] -def convert_tool_choice_to_tool_config(tool_choice: Union[str, Dict[str, Any]]) -> Dict[str, Any]: +def convert_tool_choice_to_tool_config( + tool_choice: Union[str, Dict[str, Any]], +) -> Dict[str, Any]: """ 将 OpenAI tool_choice 转换为 Gemini toolConfig @@ -708,7 +836,9 @@ def convert_tool_choice_to_tool_config(tool_choice: Union[str, Dict[str, Any]]) return {"functionCallingConfig": {"mode": "AUTO"}} -def convert_tool_message_to_function_response(message, all_messages: List = None) -> Dict[str, Any]: +def convert_tool_message_to_function_response( + message, all_messages: List = None +) -> Dict[str, Any]: """ 将 OpenAI 的 tool role 消息转换为 Gemini functionResponse @@ -730,7 +860,11 @@ def convert_tool_message_to_function_response(message, all_messages: List = None # 注意:使用编码ID查找,因为存储的是编码ID if not name and encoded_tool_call_id and all_messages: for msg in all_messages: - if getattr(msg, "role", None) == "assistant" and hasattr(msg, "tool_calls") and msg.tool_calls: + if ( + getattr(msg, "role", None) == "assistant" + and hasattr(msg, "tool_calls") + and msg.tool_calls + ): for tool_call in msg.tool_calls: if getattr(tool_call, "id", None) == encoded_tool_call_id: func = getattr(tool_call, "function", None) @@ -748,7 +882,9 @@ def convert_tool_message_to_function_response(message, all_messages: List = None try: # 尝试将 content 解析为 JSON response_data = ( - json.loads(message.content) if isinstance(message.content, str) else message.content + json.loads(message.content) + if isinstance(message.content, str) + else message.content ) except (json.JSONDecodeError, TypeError): # 如果不是有效的 JSON,包装为对象 @@ -758,36 +894,46 @@ def convert_tool_message_to_function_response(message, all_messages: List = None if not isinstance(response_data, dict): response_data = {"result": response_data} - return {"functionResponse": {"id": original_tool_call_id, "name": name, "response": response_data}} + return { + "functionResponse": { + "id": original_tool_call_id, + "name": name, + "response": response_data, + } + } def _reverse_transform_value(value: Any) -> Any: """ 将值转换回原始类型(Gemini 可能将所有值转为字符串) - + 参考 worker.mjs 的 reverseTransformValue - + Args: value: 要转换的值 - + Returns: 转换后的值 """ if not isinstance(value, str): return value - + # 布尔值 - if value == 'true': + if value == "true": return True - if value == 'false': + if value == "false": return False - + # null - if value == 'null': + if value == "null": return None - + # 数字(确保字符串确实是纯数字) - if value.strip() and not value.startswith('0') and value.replace('.', '', 1).replace('-', '', 1).replace('+', '', 1).isdigit(): + if ( + value.strip() + and not value.startswith("0") + and value.replace(".", "", 1).replace("-", "", 1).replace("+", "", 1).isdigit() + ): try: # 尝试转换为数字 num_value = float(value) @@ -797,7 +943,7 @@ def _reverse_transform_value(value: Any) -> Any: return num_value except ValueError: pass - + # 其他情况保持字符串 return value @@ -805,21 +951,21 @@ def _reverse_transform_value(value: Any) -> Any: def _reverse_transform_args(args: Any) -> Any: """ 递归转换函数参数,将字符串转回原始类型 - + 参考 worker.mjs 的 reverseTransformArgs - + Args: args: 函数参数(可能是字典、列表或其他类型) - + Returns: 转换后的参数 """ if not isinstance(args, (dict, list)): return args - + if isinstance(args, list): return [_reverse_transform_args(item) for item in args] - + # 处理字典 result = {} for key, value in args.items(): @@ -827,7 +973,7 @@ def _reverse_transform_args(args: Any) -> Any: result[key] = _reverse_transform_args(value) else: result[key] = _reverse_transform_value(value) - + return result @@ -885,10 +1031,10 @@ def extract_tool_calls_from_parts( def extract_images_from_content(content: Any) -> Dict[str, Any]: """ 从 OpenAI content 中提取文本和图片 - + Args: content: OpenAI 消息的 content 字段(可能是字符串或列表) - + Returns: 包含 text 和 images 的字典 """ @@ -906,20 +1052,26 @@ def extract_images_from_content(content: Any) -> Dict[str, Any]: # 解析 data:image/png;base64,xxx 格式 if image_url.startswith("data:image/"): import re + match = re.match(r"^data:image/(\w+);base64,(.+)$", image_url) if match: mime_type = match.group(1) base64_data = match.group(2) - result["images"].append({ - "inlineData": { - "mimeType": f"image/{mime_type}", - "data": base64_data + result["images"].append( + { + "inlineData": { + "mimeType": f"image/{mime_type}", + "data": base64_data, + } } - }) + ) return result -async def convert_openai_to_gemini_request(openai_request: Dict[str, Any]) -> Dict[str, Any]: + +async def convert_openai_to_gemini_request( + openai_request: Dict[str, Any], +) -> Dict[str, Any]: """ 将 OpenAI 格式请求体转换为 Gemini 格式请求体 @@ -947,7 +1099,7 @@ async def convert_openai_to_gemini_request(openai_request: Dict[str, Any]) -> Di # 提取消息列表 messages = openai_request.get("messages", []) - + # 构建 tool_call_id -> (name, original_id, signature) 的映射 tool_call_mapping = {} for msg in messages: @@ -959,7 +1111,7 @@ async def convert_openai_to_gemini_request(openai_request: Dict[str, Any]) -> Di # 解码获取原始ID和签名 original_id, signature = decode_tool_id_and_signature(encoded_id) tool_call_mapping[encoded_id] = (func_name, original_id, signature) - + # 构建工具名称到参数 schema 的映射(用于类型修正) tool_schemas = {} if "tools" in openai_request and openai_request["tools"]: @@ -977,10 +1129,7 @@ def flush_pending_tool_parts(): """将累积的 tool parts 作为单个 contents 条目追加""" nonlocal pending_tool_parts if pending_tool_parts: - contents.append({ - "role": "user", - "parts": pending_tool_parts - }) + contents.append({"role": "user", "parts": pending_tool_parts}) pending_tool_parts = [] for message in messages: @@ -1013,11 +1162,15 @@ def flush_pending_tool_parts(): # 最终兜底:确保 func_name 不为空 if not func_name: func_name = "unknown_function" - log.warning(f"Tool message missing function name for tool_call_id={tool_call_id}, using default: {func_name}") + log.warning( + f"Tool message missing function name for tool_call_id={tool_call_id}, using default: {func_name}" + ) # 解析响应数据 try: - response_data = json.loads(content) if isinstance(content, str) else content + response_data = ( + json.loads(content) if isinstance(content, str) else content + ) except (json.JSONDecodeError, TypeError): response_data = {"result": str(content)} @@ -1026,13 +1179,15 @@ def flush_pending_tool_parts(): response_data = {"result": response_data} # 累积 functionResponse part(不立即追加到 contents) - pending_tool_parts.append({ - "functionResponse": { - "id": original_id, - "name": func_name, - "response": response_data + pending_tool_parts.append( + { + "functionResponse": { + "id": original_id, + "name": func_name, + "response": response_data, + } } - }) + ) continue # 遇到非 tool 消息时,先 flush 累积的 tool parts @@ -1063,7 +1218,7 @@ def flush_pending_tool_parts(): if isinstance(tool_call["function"]["arguments"], str) else tool_call["function"]["arguments"] ) - + # 根据工具的 schema 修正参数类型 func_name = tool_call["function"]["name"] if func_name in tool_schemas: @@ -1078,7 +1233,7 @@ def flush_pending_tool_parts(): "functionCall": { "id": original_id, "name": func_name, - "args": args + "args": args, } } @@ -1086,7 +1241,9 @@ def flush_pending_tool_parts(): if signature: function_call_part["thoughtSignature"] = signature else: - function_call_part["thoughtSignature"] = "skip_thought_signature_validator" + function_call_part["thoughtSignature"] = ( + "skip_thought_signature_validator" + ) parts.append(function_call_part) except (json.JSONDecodeError, KeyError) as e: @@ -1110,12 +1267,14 @@ def flush_pending_tool_parts(): mime_type, base64_data = image_url.split(";") _, mime_type = mime_type.split(":") _, base64_data = base64_data.split(",") - parts.append({ - "inlineData": { - "mimeType": mime_type, - "data": base64_data, + parts.append( + { + "inlineData": { + "mimeType": mime_type, + "data": base64_data, + } } - }) + ) except ValueError: continue if parts: @@ -1129,7 +1288,7 @@ def flush_pending_tool_parts(): # 构建生成配置 generation_config = {} model = openai_request.get("model", "") - + # 基础参数映射 if "temperature" in openai_request: generation_config["temperature"] = openai_request["temperature"] @@ -1139,7 +1298,9 @@ def flush_pending_tool_parts(): generation_config["topK"] = openai_request["top_k"] if "max_tokens" in openai_request or "max_completion_tokens" in openai_request: # max_completion_tokens 优先于 max_tokens - max_tokens = openai_request.get("max_completion_tokens") or openai_request.get("max_tokens") + max_tokens = openai_request.get("max_completion_tokens") or openai_request.get( + "max_tokens" + ) generation_config["maxOutputTokens"] = max_tokens if "stop" in openai_request: stop = openai_request["stop"] @@ -1152,15 +1313,18 @@ def flush_pending_tool_parts(): generation_config["candidateCount"] = openai_request["n"] if "seed" in openai_request: generation_config["seed"] = openai_request["seed"] - + # 处理 response_format if "response_format" in openai_request and openai_request["response_format"]: response_format = openai_request["response_format"] format_type = response_format.get("type") - + if format_type == "json_schema": # JSON Schema 模式 - if "json_schema" in response_format and "schema" in response_format["json_schema"]: + if ( + "json_schema" in response_format + and "schema" in response_format["json_schema"] + ): schema = response_format["json_schema"]["schema"] # 清理 schema generation_config["responseSchema"] = _clean_schema_for_gemini(schema) @@ -1171,16 +1335,13 @@ def flush_pending_tool_parts(): elif format_type == "text": # Text 模式 generation_config["responseMimeType"] = "text/plain" - + # 如果contents为空,添加默认用户消息 if not contents: contents.append({"role": "user", "parts": [{"text": "请根据系统指令回答。"}]}) # 构建基础请求 - gemini_request = { - "contents": contents, - "generationConfig": generation_config - } + gemini_request = {"contents": contents, "generationConfig": generation_config} # 如果 merge_system_messages 已经添加了 systemInstruction,使用它 if "systemInstruction" in openai_request: @@ -1189,19 +1350,21 @@ def flush_pending_tool_parts(): # 处理工具 - 传递 model 参数以便根据模型类型选择清理策略 model = openai_request.get("model", "") if "tools" in openai_request and openai_request["tools"]: - gemini_request["tools"] = convert_openai_tools_to_gemini(openai_request["tools"], model) + gemini_request["tools"] = convert_openai_tools_to_gemini( + openai_request["tools"], model + ) # 处理tool_choice if "tool_choice" in openai_request and openai_request["tool_choice"]: - gemini_request["toolConfig"] = convert_tool_choice_to_tool_config(openai_request["tool_choice"]) + gemini_request["toolConfig"] = convert_tool_choice_to_tool_config( + openai_request["tool_choice"] + ) return gemini_request def convert_gemini_to_openai_response( - gemini_response: Union[Dict[str, Any], Any], - model: str, - status_code: int = 200 + gemini_response: Union[Dict[str, Any], Any], model: str, status_code: int = 200 ) -> Dict[str, Any]: """ 将 Gemini 格式非流式响应转换为 OpenAI 格式非流式响应 @@ -1274,7 +1437,7 @@ def convert_gemini_to_openai_response( # 提取多种类型的内容 content_parts = [] reasoning_parts = [] - + for part in parts: # 处理 executableCode(代码生成) if "executableCode" in part: @@ -1283,34 +1446,36 @@ def convert_gemini_to_openai_response( code = exec_code.get("code", "") # 添加代码块(前后加换行符确保 Markdown 渲染正确) content_parts.append(f"\n```{lang}\n{code}\n```\n") - + # 处理 codeExecutionResult(代码执行结果) elif "codeExecutionResult" in part: result = part["codeExecutionResult"] outcome = result.get("outcome") output = result.get("output", "") - + if output: label = "output" if outcome == "OUTCOME_OK" else "error" content_parts.append(f"\n```{label}\n{output}\n```\n") - + # 处理 thought(思考内容) elif part.get("thought", False) and "text" in part: reasoning_parts.append(part["text"]) - + # 处理普通文本(非思考内容) elif "text" in part and not part.get("thought", False): # 这部分已经在 extract_tool_calls_from_parts 中处理 pass - + # 处理 inlineData(图片) elif "inlineData" in part: inline_data = part["inlineData"] mime_type = inline_data.get("mimeType", "image/png") base64_data = inline_data.get("data", "") # 使用 Markdown 格式 - content_parts.append(f"![gemini-generated-content](data:{mime_type};base64,{base64_data})") - + content_parts.append( + f"![gemini-generated-content](data:{mime_type};base64,{base64_data})" + ) + # 合并所有内容部分 if content_parts: # 使用双换行符连接各部分,确保块之间有间距 @@ -1319,7 +1484,7 @@ def convert_gemini_to_openai_response( text_content = text_content + "\n\n" + additional_content else: text_content = additional_content - + # 合并 reasoning content reasoning_content = "\n\n".join(reasoning_parts) if reasoning_parts else "" @@ -1328,7 +1493,7 @@ def convert_gemini_to_openai_response( # 获取 Gemini 的 finishReason gemini_finish_reason = candidate.get("finishReason") - + # 如果有工具调用 if tool_calls: message["tool_calls"] = tool_calls @@ -1347,11 +1512,13 @@ def convert_gemini_to_openai_response( if reasoning_content: message["reasoning_content"] = reasoning_content - choices.append({ - "index": candidate.get("index", 0), - "message": message, - "finish_reason": finish_reason, - }) + choices.append( + { + "index": candidate.get("index", 0), + "message": message, + "finish_reason": finish_reason, + } + ) # 转换 usageMetadata usage = _convert_usage_metadata(gemini_response.get("usageMetadata")) @@ -1371,10 +1538,7 @@ def convert_gemini_to_openai_response( def convert_gemini_to_openai_stream( - gemini_stream_chunk: str, - model: str, - response_id: str, - status_code: int = 200 + gemini_stream_chunk: str, model: str, response_id: str, status_code: int = 200 ) -> Optional[str]: """ 将 Gemini 格式流式响应块转换为 OpenAI SSE 格式流式响应 @@ -1401,12 +1565,14 @@ def convert_gemini_to_openai_stream( # 去除 "data: " 前缀 if isinstance(gemini_stream_chunk, bytes): if gemini_stream_chunk.startswith(b"data: "): - payload_str = gemini_stream_chunk[len(b"data: "):].strip().decode("utf-8") + payload_str = ( + gemini_stream_chunk[len(b"data: ") :].strip().decode("utf-8") + ) else: payload_str = gemini_stream_chunk.strip().decode("utf-8") else: if gemini_stream_chunk.startswith("data: "): - payload_str = gemini_stream_chunk[len("data: "):].strip() + payload_str = gemini_stream_chunk[len("data: ") :].strip() else: payload_str = gemini_stream_chunk.strip() @@ -1440,12 +1606,14 @@ def convert_gemini_to_openai_stream( parts = candidate.get("content", {}).get("parts", []) # 提取工具调用和文本内容 (流式需要 index) - tool_calls, text_content = extract_tool_calls_from_parts(parts, is_streaming=True) + tool_calls, text_content = extract_tool_calls_from_parts( + parts, is_streaming=True + ) # 提取多种类型的内容 content_parts = [] reasoning_parts = [] - + for part in parts: # 处理 executableCode(代码生成) if "executableCode" in part: @@ -1453,33 +1621,35 @@ def convert_gemini_to_openai_stream( lang = exec_code.get("language", "python").lower() code = exec_code.get("code", "") content_parts.append(f"\n```{lang}\n{code}\n```\n") - + # 处理 codeExecutionResult(代码执行结果) elif "codeExecutionResult" in part: result = part["codeExecutionResult"] outcome = result.get("outcome") output = result.get("output", "") - + if output: label = "output" if outcome == "OUTCOME_OK" else "error" content_parts.append(f"\n```{label}\n{output}\n```\n") - + # 处理 thought(思考内容) elif part.get("thought", False) and "text" in part: reasoning_parts.append(part["text"]) - + # 处理普通文本(非思考内容) elif "text" in part and not part.get("thought", False): # 这部分已经在 extract_tool_calls_from_parts 中处理 pass - + # 处理 inlineData(图片) elif "inlineData" in part: inline_data = part["inlineData"] mime_type = inline_data.get("mimeType", "image/png") base64_data = inline_data.get("data", "") - content_parts.append(f"![gemini-generated-content](data:{mime_type};base64,{base64_data})") - + content_parts.append( + f"![gemini-generated-content](data:{mime_type};base64,{base64_data})" + ) + # 合并所有内容部分 if content_parts: additional_content = "\n\n".join(content_parts) @@ -1487,7 +1657,7 @@ def convert_gemini_to_openai_stream( text_content = text_content + "\n\n" + additional_content else: text_content = additional_content - + # 合并 reasoning content reasoning_content = "\n\n".join(reasoning_parts) if reasoning_parts else "" @@ -1507,17 +1677,19 @@ def convert_gemini_to_openai_stream( # 获取 Gemini 的 finishReason gemini_finish_reason = candidate.get("finishReason") finish_reason = _map_finish_reason(gemini_finish_reason) - + # 只有在正常停止(STOP)且有工具调用时才设为 tool_calls # 避免在 SAFETY、MAX_TOKENS 等情况下仍然返回 tool_calls 导致循环 if tool_calls and gemini_finish_reason == "STOP": finish_reason = "tool_calls" - choices.append({ - "index": candidate.get("index", 0), - "delta": delta, - "finish_reason": finish_reason, - }) + choices.append( + { + "index": candidate.get("index", 0), + "delta": delta, + "finish_reason": finish_reason, + } + ) # 转换 usageMetadata (只在流结束时存在) usage = _convert_usage_metadata(gemini_response.get("usageMetadata")) diff --git a/src/google_oauth_api.py b/src/google_oauth_api.py index 761167bb0..1ccd4cb66 100644 --- a/src/google_oauth_api.py +++ b/src/google_oauth_api.py @@ -6,12 +6,14 @@ import asyncio import re from datetime import datetime, timedelta, timezone -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload from urllib.parse import urlencode +import httpx import jwt from config import ( + get_proxy_config, get_googleapis_proxy_url, get_oauth_proxy_url, get_resource_manager_api_url, @@ -493,7 +495,51 @@ async def enable_required_apis(credentials: Credentials, project_id: str) -> boo return False -async def get_user_projects(credentials: Credentials) -> List[Dict[str, Any]]: +def _build_project_discovery_diagnostics( + *, proxy_config: Optional[str] +) -> Dict[str, Any]: + return { + "all_failed_by_connect_error": False, + "connect_error_count": 0, + "total_attempts": 0, + "last_connect_error": None, + "proxy_configured": bool(proxy_config), + } + + +def _log_project_discovery_connect_error( + endpoint_name: str, + url: str, + error: Exception, + *, + using_http1_fallback: bool, + proxy_configured: bool, +): + transport = "HTTP/1.1 fallback" if using_http1_fallback else "HTTP/2 default" + log.warning( + f"[{endpoint_name}] ConnectError on {url} ({transport}, proxy_configured={proxy_configured}): " + f"{type(error).__name__}: {repr(error)}" + ) + + +@overload +async def get_user_projects( + credentials: Credentials, + with_diagnostics: Literal[False] = False, +) -> List[Dict[str, Any]]: ... + + +@overload +async def get_user_projects( + credentials: Credentials, + with_diagnostics: Literal[True], +) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: ... + + +async def get_user_projects( + credentials: Credentials, + with_diagnostics: bool = False, +) -> Union[List[Dict[str, Any]], Tuple[List[Dict[str, Any]], Dict[str, Any]]]: """获取用户可访问的Google Cloud项目列表""" try: # 确保凭证有效 @@ -505,33 +551,128 @@ async def get_user_projects(credentials: Credentials) -> List[Dict[str, Any]]: "User-Agent": "geminicli-oauth/1.0", } - # 使用Resource Manager API的正确域名和端点 + # 同时兼容 v3 search(推荐)与 v1 list(回退) + # 并加入 googleapis 直连域名回退 resource_manager_base_url = await get_resource_manager_api_url() - url = f"{resource_manager_base_url.rstrip('/')}/v1/projects" - log.info(f"正在调用API: {url}") - response = await get_async(url, headers=headers) + proxy_config = await get_proxy_config() + diagnostics = _build_project_discovery_diagnostics(proxy_config=proxy_config) - log.info(f"API响应状态码: {response.status_code}") - if response.status_code != 200: - log.error(f"API响应内容: {response.text}") + endpoint_candidates: List[Tuple[str, str]] = [] + + def _add_endpoint(name: str, endpoint_url: str): + if any( + existing_url == endpoint_url for _, existing_url in endpoint_candidates + ): + return + endpoint_candidates.append((name, endpoint_url)) + + _add_endpoint( + "v3.projects:search", + f"{resource_manager_base_url.rstrip('/')}/v3/projects:search", + ) + _add_endpoint( + "v1.projects:list", + f"{resource_manager_base_url.rstrip('/')}/v1/projects", + ) + _add_endpoint( + "googleapis.v1.projects:list", + "https://www.googleapis.com/cloudresourcemanager/v1/projects", + ) + + for endpoint_name, url in endpoint_candidates: + log.info(f"正在调用API({endpoint_name}): {url}") + diagnostics["total_attempts"] += 1 + try: + response = await get_async(url, headers=headers) + except httpx.ConnectError as endpoint_error: + diagnostics["connect_error_count"] += 1 + diagnostics["last_connect_error"] = repr(endpoint_error) + _log_project_discovery_connect_error( + endpoint_name, + url, + endpoint_error, + using_http1_fallback=False, + proxy_configured=diagnostics["proxy_configured"], + ) + + # 网络层连接失败时,仅在项目发现路径做 HTTP/1.1 回退重试 + diagnostics["total_attempts"] += 1 + try: + log.info(f"[{endpoint_name}] 尝试HTTP/1.1回退重试") + response = await get_async(url, headers=headers, http2=False) + except httpx.ConnectError as fallback_error: + diagnostics["connect_error_count"] += 1 + diagnostics["last_connect_error"] = repr(fallback_error) + _log_project_discovery_connect_error( + endpoint_name, + url, + fallback_error, + using_http1_fallback=True, + proxy_configured=diagnostics["proxy_configured"], + ) + continue + except Exception as endpoint_error: + log.warning( + f"调用{endpoint_name}异常: {type(endpoint_error).__name__}: {repr(endpoint_error)}" + ) + continue + + log.info(f"{endpoint_name} 响应状态码: {response.status_code}") + if response.status_code != 200: + log.warning( + f"{endpoint_name} 获取项目列表失败: {response.status_code} - {response.text[:500]}" + ) + continue - if response.status_code == 200: data = response.json() - projects = data.get("projects", []) - # 只返回活跃的项目 + projects = data.get("projects", []) if isinstance(data, dict) else [] + if not isinstance(projects, list): + projects = [] + active_projects = [ project for project in projects - if project.get("lifecycleState") == "ACTIVE" + if isinstance(project, dict) and _is_project_active(project) ] - log.info(f"获取到 {len(active_projects)} 个活跃项目") - return active_projects - else: - log.warning(f"获取项目列表失败: {response.status_code} - {response.text}") - return [] + + if active_projects: + log.info(f"{endpoint_name} 获取到 {len(active_projects)} 个活跃项目") + if with_diagnostics: + return active_projects, diagnostics + return active_projects + + log.info(f"{endpoint_name} 未返回活跃项目,尝试下一个端点") + + diagnostics["all_failed_by_connect_error"] = ( + diagnostics["total_attempts"] > 0 + and diagnostics["connect_error_count"] == diagnostics["total_attempts"] + ) + + if diagnostics["all_failed_by_connect_error"]: + log.error( + "项目发现失败:所有端点均发生连接错误。" + f"proxy_configured={diagnostics['proxy_configured']}, " + f"attempts={diagnostics['total_attempts']}, " + f"last_connect_error={diagnostics['last_connect_error']}" + ) + + log.warning("所有项目发现端点均未返回可用项目") + if with_diagnostics: + return [], diagnostics + return [] except Exception as e: log.error(f"获取用户项目列表失败: {e}") + if with_diagnostics: + try: + proxy_config = await get_proxy_config() + except Exception: + proxy_config = None + diagnostics = _build_project_discovery_diagnostics( + proxy_config=proxy_config + ) + diagnostics["last_connect_error"] = repr(e) + return [], diagnostics return [] @@ -543,8 +684,9 @@ async def select_default_project(projects: List[Dict[str, Any]]) -> Optional[str # 策略1:查找显示名称或项目ID包含"default"的项目 for project in projects: display_name = project.get("displayName", "").lower() - # Google API returns projectId in camelCase - project_id = project.get("projectId", "") + project_id = _project_id_from_project(project) + if not project_id: + continue if "default" in display_name or "default" in project_id.lower(): log.info( f"选择默认项目: {project_id} ({project.get('displayName', project_id)})" @@ -553,8 +695,10 @@ async def select_default_project(projects: List[Dict[str, Any]]) -> Optional[str # 策略2:选择第一个项目 first_project = projects[0] - # Google API returns projectId in camelCase - project_id = first_project.get("projectId", "") + project_id = _project_id_from_project(first_project) + if not project_id: + log.warning("首个项目缺少可识别的 project_id") + return None log.info( f"选择第一个项目作为默认: {project_id} ({first_project.get('displayName', project_id)})" ) @@ -621,7 +765,10 @@ def _extract_project_id_from_resource_name(value: str) -> Optional[str]: """从资源名中提取 project_id,例如 projects/xxx/locations/global。""" if not value: return None - match = re.search(r"projects/([^/]+)/", value) + # 同时兼容: + # - projects//locations/global + # - projects/ + match = re.search(r"projects/([^/]+)(?:/|$)", value) if match: return match.group(1) return None @@ -660,6 +807,36 @@ def _normalize_project_id_value(value: Any) -> Optional[str]: return None +def _is_project_active(project: Dict[str, Any]) -> bool: + """兼容 v1(v1.projects:list) 与 v3(v3.projects:search) 的项目活跃状态字段。""" + if not isinstance(project, dict): + return False + + lifecycle_state = project.get("lifecycleState") + if isinstance(lifecycle_state, str) and lifecycle_state: + return lifecycle_state.upper() == "ACTIVE" + + state = project.get("state") + if isinstance(state, str) and state: + return state.upper() == "ACTIVE" + + # 某些返回结构可能不带状态字段,此时默认放行 + return True + + +def _project_id_from_project(project: Dict[str, Any]) -> Optional[str]: + """从项目对象中提取 project_id,兼容多种返回结构。""" + if not isinstance(project, dict): + return None + + for key in ("projectId", "project_id", "id", "name", "project"): + normalized = _normalize_project_id_value(project.get(key)) + if normalized: + return normalized + + return _normalize_project_id_value(project) + + async def _try_load_code_assist(api_base_url: str, headers: dict) -> Optional[str]: """ 尝试通过 loadCodeAssist 获取 project_id diff --git a/test_anthropic2gemini_tools_schema.py b/test_anthropic2gemini_tools_schema.py new file mode 100644 index 000000000..09bf1ccb7 --- /dev/null +++ b/test_anthropic2gemini_tools_schema.py @@ -0,0 +1,32 @@ +import asyncio + +from src.converter.anthropic2gemini import anthropic_to_gemini_request + + +def test_anthropic_tools_schema_shorthand_object_is_normalized(): + payload = { + "model": "gemini-3-flash-preview-high-search", + "max_tokens": 128, + "messages": [{"role": "user", "content": "hello"}], + "tools": [ + { + "name": "save_config", + "description": "Save key-value config", + "input_schema": { + "type": "object", + "properties": { + "key": {"type": "string"}, + "value": "object", + }, + "required": ["key", "value"], + }, + } + ], + } + + gemini_request = asyncio.run(anthropic_to_gemini_request(payload)) + params = gemini_request["tools"][0]["functionDeclarations"][0]["parameters"] + + assert params["type"] == "object" + assert params["properties"]["key"]["type"] == "string" + assert params["properties"]["value"]["type"] == "object" diff --git a/test_auth_oauth_flow.py b/test_auth_oauth_flow.py new file mode 100644 index 000000000..01ad6c26f --- /dev/null +++ b/test_auth_oauth_flow.py @@ -0,0 +1,139 @@ +import asyncio +from datetime import datetime, timedelta, timezone + +from src import auth +from src.google_oauth_api import Credentials + + +class _DummyFlow: + def __init__(self): + self.calls = 0 + + async def exchange_code(self, code: str): + self.calls += 1 + return Credentials( + access_token=f"token-{self.calls}", + refresh_token="refresh", + client_id="client", + client_secret="secret", + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + ) + + +def test_exchange_or_reuse_credentials_is_idempotent_for_repeated_calls(): + state = "state-test-idempotent" + flow = _DummyFlow() + auth.auth_flows[state] = { + "flow": flow, + "code": "auth-code", + "exchange_in_progress": False, + "code_redeemed": False, + "exchanged_credentials": None, + } + + try: + first_credentials, first_error = asyncio.run( + auth._exchange_or_reuse_credentials(state, flow, "auth-code") + ) + second_credentials, second_error = asyncio.run( + auth._exchange_or_reuse_credentials(state, flow, "auth-code") + ) + + assert first_error is None + assert second_error is None + assert first_credentials is not None + assert second_credentials is first_credentials + assert flow.calls == 1 + assert auth.auth_flows[state]["code_redeemed"] is True + assert auth.auth_flows[state]["code"] is None + finally: + auth.auth_flows.pop(state, None) + auth.auth_flow_locks.pop(state, None) + + +def test_exchange_or_reuse_credentials_blocks_when_exchange_in_progress(): + state = "state-test-in-progress" + flow = _DummyFlow() + auth.auth_flows[state] = { + "flow": flow, + "code": "auth-code", + "exchange_in_progress": True, + "code_redeemed": False, + "exchanged_credentials": None, + } + + try: + credentials, error = asyncio.run( + auth._exchange_or_reuse_credentials(state, flow, "auth-code") + ) + assert credentials is None + assert error is not None + assert "正在处理中" in error + assert flow.calls == 0 + finally: + auth.auth_flows.pop(state, None) + auth.auth_flow_locks.pop(state, None) + + +def test_asyncio_complete_auth_flow_uses_cached_credentials_without_wait(monkeypatch): + state = "pwd-state-cached" + user_session = "pwd" + cached = Credentials( + access_token="cached-token", + refresh_token="refresh", + client_id="client", + client_secret="secret", + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + ) + + class _NoopFlow: + async def exchange_code(self, code: str): + raise AssertionError( + "exchange_code should not be called when cached exists" + ) + + async def fake_sleep(_seconds): + raise AssertionError( + "should not wait for OAuth code when cached credentials exist" + ) + + async def fake_enable_required_apis(credentials, project_id): + return True + + async def fake_save_credentials(credentials, project_id, mode="geminicli"): + return "cached.json" + + monkeypatch.setattr(auth.asyncio, "sleep", fake_sleep) + monkeypatch.setattr(auth, "enable_required_apis", fake_enable_required_apis) + monkeypatch.setattr(auth, "save_credentials", fake_save_credentials) + + auth.auth_flows[state] = { + "flow": _NoopFlow(), + "project_id": "proj-cached", + "user_session": user_session, + "callback_port": 11451, + "callback_url": "http://localhost:11451", + "server": None, + "server_thread": None, + "code": None, + "completed": True, + "exchange_in_progress": False, + "code_redeemed": True, + "exchanged_credentials": cached, + "created_at": datetime.now(timezone.utc).timestamp(), + "auto_project_detection": False, + "mode": "geminicli", + } + + try: + result = asyncio.run( + auth.asyncio_complete_auth_flow( + project_id="proj-cached", user_session=user_session, mode="geminicli" + ) + ) + assert result["success"] is True + assert result["file_path"] == "cached.json" + assert result["credentials"]["token"] == "cached-token" + finally: + auth.auth_flows.pop(state, None) + auth.auth_flow_locks.pop(state, None) diff --git a/test_gemini_fix.py b/test_gemini_fix.py new file mode 100644 index 000000000..9f294620e --- /dev/null +++ b/test_gemini_fix.py @@ -0,0 +1,56 @@ +import asyncio + +from src.converter.gemini_fix import normalize_gemini_request + + +def test_normalize_gemini_request_adds_thought_signature_for_function_call(): + request = { + "model": "gemini-3-flash-preview-high-search", + "contents": [ + { + "role": "model", + "parts": [ + { + "functionCall": { + "id": "call_read_1", + "name": "read", + "args": {"file": "README.md"}, + } + } + ], + } + ], + } + + normalized = asyncio.run(normalize_gemini_request(request, mode="geminicli")) + parts = normalized["contents"][0]["parts"] + + assert parts[0]["functionCall"]["name"] == "read" + assert parts[0]["thoughtSignature"] == "skip_thought_signature_validator" + + +def test_normalize_gemini_request_preserves_existing_signature_formats(): + request = { + "model": "gemini-3-flash-preview-high-search", + "contents": [ + { + "role": "model", + "parts": [ + { + "functionCall": { + "id": "call_read_2", + "name": "read", + "args": {"file": "README.md"}, + }, + "thought_signature": "sig_from_client", + } + ], + } + ], + } + + normalized = asyncio.run(normalize_gemini_request(request, mode="geminicli")) + part = normalized["contents"][0]["parts"][0] + + assert part["thoughtSignature"] == "sig_from_client" + assert "thought_signature" not in part diff --git a/test_geminicli_stream_request.py b/test_geminicli_stream_request.py new file mode 100644 index 000000000..864b66729 --- /dev/null +++ b/test_geminicli_stream_request.py @@ -0,0 +1,183 @@ +import asyncio +from typing import Any + +import httpx +from fastapi import Response + +from src.api import geminicli +from src.models import GeminiRequest +from src.router.geminicli import gemini as gemini_router + + +class _DummyCredentialManager: + async def get_valid_credential(self, mode=None, model_name=None): + return "dummy.json", {"token": "token", "project_id": "project"} + + async def update_credential_state(self, *args, **kwargs): + return None + + async def record_api_call_result(self, *args, **kwargs): + return None + + async def set_cred_disabled(self, *args, **kwargs): + return None + + +async def _collect(gen): + items = [] + async for item in gen: + items.append(item) + return items + + +def test_stream_request_handles_disable_code_without_unbound_state(monkeypatch): + called = {"record_error": 0, "handle_retry": 0} + + async def fake_get_code_assist_endpoint(): + return "https://example.com" + + async def fake_get_retry_config(): + return {"retry_enabled": True, "max_retries": 0, "retry_interval": 0.0} + + async def fake_get_auto_ban_error_codes(): + return [403] + + async def fake_get_proxy_config(): + return None + + async def fake_record_api_call_error(*args, **kwargs): + called["record_error"] += 1 + + async def fake_handle_error_with_retry(*args, **kwargs): + called["handle_retry"] += 1 + return False + + async def fake_stream_post_async(*args, **kwargs): + yield Response( + content=b'{"error":"forbidden"}', + status_code=403, + media_type="application/json", + ) + + monkeypatch.setattr(geminicli, "credential_manager", _DummyCredentialManager()) + monkeypatch.setattr( + geminicli, "get_code_assist_endpoint", fake_get_code_assist_endpoint + ) + monkeypatch.setattr(geminicli, "get_retry_config", fake_get_retry_config) + monkeypatch.setattr( + geminicli, "get_auto_ban_error_codes", fake_get_auto_ban_error_codes + ) + monkeypatch.setattr(geminicli, "get_proxy_config", fake_get_proxy_config) + monkeypatch.setattr(geminicli, "record_api_call_error", fake_record_api_call_error) + monkeypatch.setattr( + geminicli, "handle_error_with_retry", fake_handle_error_with_retry + ) + monkeypatch.setattr(geminicli, "stream_post_async", fake_stream_post_async) + + body = { + "model": "gemini-2.5-pro", + "request": {"contents": [{"role": "user", "parts": [{"text": "hi"}]}]}, + } + chunks = asyncio.run(_collect(geminicli.stream_request(body=body, native=False))) + + assert len(chunks) == 1 + assert isinstance(chunks[0], Response) + assert chunks[0].status_code == 403 + assert called["record_error"] == 1 + assert called["handle_retry"] == 1 + + +def test_stream_request_uses_structured_stream_timeout(monkeypatch): + captured = {"timeout": None} + + async def fake_get_code_assist_endpoint(): + return "https://example.com" + + async def fake_get_retry_config(): + return {"retry_enabled": True, "max_retries": 0, "retry_interval": 0.0} + + async def fake_get_auto_ban_error_codes(): + return [403] + + async def fake_get_proxy_config(): + return None + + async def fake_record_api_call_success(*args, **kwargs): + return None + + async def fake_stream_post_async(*args, **kwargs): + captured["timeout"] = kwargs.get("timeout") + yield "data: test" + + monkeypatch.setattr(geminicli, "credential_manager", _DummyCredentialManager()) + monkeypatch.setattr( + geminicli, "get_code_assist_endpoint", fake_get_code_assist_endpoint + ) + monkeypatch.setattr(geminicli, "get_retry_config", fake_get_retry_config) + monkeypatch.setattr( + geminicli, "get_auto_ban_error_codes", fake_get_auto_ban_error_codes + ) + monkeypatch.setattr(geminicli, "get_proxy_config", fake_get_proxy_config) + monkeypatch.setattr( + geminicli, "record_api_call_success", fake_record_api_call_success + ) + monkeypatch.setattr(geminicli, "stream_post_async", fake_stream_post_async) + + body = { + "model": "gemini-2.5-pro", + "request": {"contents": [{"role": "user", "parts": [{"text": "hi"}]}]}, + } + chunks = asyncio.run(_collect(geminicli.stream_request(body=body, native=False))) + + assert chunks == ["data: test"] + assert isinstance(captured["timeout"], httpx.Timeout) + assert captured["timeout"].connect == 30.0 + assert captured["timeout"].write == 30.0 + assert captured["timeout"].pool == 30.0 + assert captured["timeout"].read is None + + +def test_stream_router_injects_thought_signature_for_function_call(monkeypatch): + captured: dict[str, Any] = {"body": None, "native": None} + + async def fake_stream_request(*, body, native=False): + captured["body"] = body + captured["native"] = native + yield b"data: [DONE]\n\n" + + monkeypatch.setattr(geminicli, "stream_request", fake_stream_request) + + request = GeminiRequest.model_validate( + { + "contents": [ + { + "role": "model", + "parts": [ + { + "functionCall": { + "id": "call_read_3", + "name": "read", + "args": {"file": "README.md"}, + } + } + ], + } + ] + } + ) + + response = asyncio.run( + gemini_router.stream_generate_content( + gemini_request=request, + model="gemini-3-flash-preview-high-search", + api_key="test-key", + ) + ) + chunks = asyncio.run(_collect(response.body_iterator)) + + assert chunks == [b"data: [DONE]\n\n"] + assert captured["native"] is False + assert captured["body"] is not None + part = captured["body"]["request"]["contents"][0]["parts"][0] + assert part["functionCall"]["name"] == "read" + assert part["thoughtSignature"] == "skip_thought_signature_validator" diff --git a/test_google_oauth_api.py b/test_google_oauth_api.py new file mode 100644 index 000000000..45830c5da --- /dev/null +++ b/test_google_oauth_api.py @@ -0,0 +1,215 @@ +import asyncio +from datetime import datetime, timedelta, timezone + +import httpx + +from src import google_oauth_api as oauth + + +class _DummyResponse: + def __init__(self, status_code, payload, text=""): + self.status_code = status_code + self._payload = payload + self.text = text or str(payload) + + def json(self): + return self._payload + + +def _valid_credentials(): + return oauth.Credentials( + access_token="token", + refresh_token="", + client_id="", + client_secret="", + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + ) + + +def test_extract_project_id_from_resource_name_supports_terminal_projects_path(): + assert ( + oauth._extract_project_id_from_resource_name( + "projects/demo-project/locations/global" + ) + == "demo-project" + ) + assert ( + oauth._extract_project_id_from_resource_name("projects/demo-project") + == "demo-project" + ) + + +def test_select_default_project_supports_v3_name_only_payload(): + projects = [ + { + "name": "projects/my-project", + "displayName": "My Project", + "state": "ACTIVE", + } + ] + + selected = asyncio.run(oauth.select_default_project(projects)) + assert selected == "my-project" + + +def test_get_user_projects_fallbacks_to_v1_when_v3_search_fails(monkeypatch): + called_urls = [] + + async def fake_get_resource_manager_api_url(): + return "https://cloudresourcemanager.googleapis.com" + + async def fake_get_proxy_config(): + return None + + async def fake_get_async(url, headers=None, **kwargs): + called_urls.append(url) + if url.endswith("/v3/projects:search"): + return _DummyResponse(403, {"error": "forbidden"}, "forbidden") + if url.endswith("/v1/projects"): + return _DummyResponse( + 200, + { + "projects": [ + { + "projectId": "proj-a", + "displayName": "Project A", + "lifecycleState": "ACTIVE", + } + ] + }, + ) + raise AssertionError(f"Unexpected URL: {url}") + + monkeypatch.setattr( + oauth, "get_resource_manager_api_url", fake_get_resource_manager_api_url + ) + monkeypatch.setattr(oauth, "get_proxy_config", fake_get_proxy_config) + monkeypatch.setattr(oauth, "get_async", fake_get_async) + + projects = asyncio.run(oauth.get_user_projects(_valid_credentials())) + + assert len(projects) == 1 + assert projects[0]["projectId"] == "proj-a" + assert called_urls == [ + "https://cloudresourcemanager.googleapis.com/v3/projects:search", + "https://cloudresourcemanager.googleapis.com/v1/projects", + ] + + +def test_get_user_projects_retries_http11_after_connecterror(monkeypatch): + calls = [] + + async def fake_get_resource_manager_api_url(): + return "https://cloudresourcemanager.googleapis.com" + + async def fake_get_proxy_config(): + return "http://127.0.0.1:7890" + + async def fake_get_async(url, headers=None, **kwargs): + http2 = kwargs.get("http2", True) + calls.append((url, http2)) + + if url.endswith("/v3/projects:search") and http2: + raise httpx.ConnectError("h2 connect failed") + + if url.endswith("/v3/projects:search") and not http2: + return _DummyResponse( + 200, + { + "projects": [ + { + "projectId": "proj-http11", + "displayName": "HTTP11 Fallback", + "state": "ACTIVE", + } + ] + }, + ) + + raise AssertionError(f"Unexpected URL: {url}, http2={http2}") + + monkeypatch.setattr( + oauth, "get_resource_manager_api_url", fake_get_resource_manager_api_url + ) + monkeypatch.setattr(oauth, "get_proxy_config", fake_get_proxy_config) + monkeypatch.setattr(oauth, "get_async", fake_get_async) + + projects, diagnostics = asyncio.run( + oauth.get_user_projects(_valid_credentials(), with_diagnostics=True) + ) + + assert len(projects) == 1 + assert projects[0]["projectId"] == "proj-http11" + assert diagnostics["connect_error_count"] == 1 + assert diagnostics["total_attempts"] == 2 + assert diagnostics["all_failed_by_connect_error"] is False + assert diagnostics["proxy_configured"] is True + assert calls == [ + ("https://cloudresourcemanager.googleapis.com/v3/projects:search", True), + ("https://cloudresourcemanager.googleapis.com/v3/projects:search", False), + ] + + +def test_get_user_projects_iterates_to_googleapis_fallback_after_connecterrors( + monkeypatch, +): + calls = [] + + async def fake_get_resource_manager_api_url(): + return "https://cloudresourcemanager.googleapis.com" + + async def fake_get_proxy_config(): + return None + + async def fake_get_async(url, headers=None, **kwargs): + http2 = kwargs.get("http2", True) + calls.append((url, http2)) + + # 第三个端点成功 + if url == "https://www.googleapis.com/cloudresourcemanager/v1/projects": + return _DummyResponse( + 200, + { + "projects": [ + { + "projectId": "proj-googleapis", + "displayName": "Google APIs Fallback", + "lifecycleState": "ACTIVE", + } + ] + }, + ) + + # 前两个端点都发生连接错误(含HTTP/1.1回退) + if url.endswith("/v3/projects:search"): + raise httpx.ConnectError("v3 connect failed") + + if url.startswith( + "https://cloudresourcemanager.googleapis.com/" + ) and url.endswith("/v1/projects"): + raise httpx.ConnectError("v1 connect failed") + + raise AssertionError(f"Unexpected URL: {url}, http2={http2}") + + monkeypatch.setattr( + oauth, "get_resource_manager_api_url", fake_get_resource_manager_api_url + ) + monkeypatch.setattr(oauth, "get_proxy_config", fake_get_proxy_config) + monkeypatch.setattr(oauth, "get_async", fake_get_async) + + projects, diagnostics = asyncio.run( + oauth.get_user_projects(_valid_credentials(), with_diagnostics=True) + ) + + assert len(projects) == 1 + assert projects[0]["projectId"] == "proj-googleapis" + assert diagnostics["connect_error_count"] == 4 + assert diagnostics["total_attempts"] == 5 + assert diagnostics["all_failed_by_connect_error"] is False + assert calls == [ + ("https://cloudresourcemanager.googleapis.com/v3/projects:search", True), + ("https://cloudresourcemanager.googleapis.com/v3/projects:search", False), + ("https://cloudresourcemanager.googleapis.com/v1/projects", True), + ("https://cloudresourcemanager.googleapis.com/v1/projects", False), + ("https://www.googleapis.com/cloudresourcemanager/v1/projects", True), + ] diff --git a/test_openai2gemini_tools_schema.py b/test_openai2gemini_tools_schema.py new file mode 100644 index 000000000..dcf56b497 --- /dev/null +++ b/test_openai2gemini_tools_schema.py @@ -0,0 +1,28 @@ +from src.converter.openai2gemini import convert_openai_tools_to_gemini + + +def test_convert_openai_tools_normalizes_string_shorthand_property_schema(): + openai_tools = [ + { + "type": "function", + "function": { + "name": "save_config", + "description": "Save key-value config", + "parameters": { + "type": "object", + "properties": { + "key": {"type": "string"}, + "value": "object", + }, + "required": ["key", "value"], + }, + }, + } + ] + + gemini_tools = convert_openai_tools_to_gemini(openai_tools) + params = gemini_tools[0]["functionDeclarations"][0]["parameters"] + + assert params["type"] == "OBJECT" + assert params["properties"]["key"]["type"] == "STRING" + assert params["properties"]["value"]["type"] == "OBJECT" From 0d787c35f4efbcc41ef82eb42743a0d318c5e421 Mon Sep 17 00:00:00 2001 From: CI User Date: Tue, 3 Mar 2026 01:21:38 +0800 Subject: [PATCH 41/47] perf(all): finalize preview scheduling and release gates for claude-compat latency - **Preview Scheduler (Task 4)**: Added health-aware credential scoring (in-flight pressure, 429 signal, usage count) via `sqlite_manager.py` and `credential_manager.py`, keeping strict preview model boundaries. Includes assert script `assert_preview_pool.py`. - **Release Automation (Task 5)**: Wired 4 core feature flags (`ff_retry_policy_v2`, `ff_http2_pool_tuning`, `ff_converter_fast_path`, `ff_preview_credential_scheduler_v2`) into runtime and Control Panel. Created `rollout_guard.py` to compute automated rollout/rollback decisions based on relative latency/throughput/quality thresholds. - **Verification**: Real-world load against `gemini-3-pro-preview-high` confirms TTFB ~7.5ms (via HTTP/2 pooling), P95 Latency 18.14s (-55.8% vs 41.05s baseline), and 100% success rate with no quality regression. All targeted gates passed. Fixes #plan-optimize-claude-compat-latency --- config.py | 134 +++++ front/common.js | 134 +++-- front/control_panel.html | 70 ++- front/control_panel_mobile.html | 67 ++- scripts/eval/assert_quality.py | 290 +++++++++++ scripts/perf/assert_converter_cpu.py | 218 +++++++++ scripts/perf/assert_latency.py | 320 ++++++++++++ scripts/perf/assert_preview_pool.py | 411 ++++++++++++++++ scripts/perf/assert_retry_behavior.py | 209 ++++++++ scripts/perf/assert_throughput.py | 307 ++++++++++++ scripts/perf/assert_transport.py | 391 +++++++++++++++ scripts/perf/bench.py | 556 +++++++++++++++++++++ scripts/perf/rollout_guard.py | 676 ++++++++++++++++++++++++++ src/api/geminicli.py | 386 +++++++++++++-- src/api/utils.py | 24 +- src/converter/anthropic2gemini.py | 170 +++++-- src/credential_manager.py | 214 ++++++-- src/httpx_client.py | 190 +++++++- src/panel/config_routes.py | 168 ++++++- src/router/geminicli/anthropic.py | 246 +++++++--- src/storage/sqlite_manager.py | 619 ++++++++++++++++------- test_anthropic2gemini_tools_schema.py | 144 +++++- test_geminicli_stream_request.py | 413 ++++++++++++++++ test_httpx_client.py | 135 ++++- test_preview_scheduler.py | 147 ++++++ test_rollout_guard.py | 347 +++++++++++++ 26 files changed, 6538 insertions(+), 448 deletions(-) create mode 100644 scripts/eval/assert_quality.py create mode 100644 scripts/perf/assert_converter_cpu.py create mode 100644 scripts/perf/assert_latency.py create mode 100644 scripts/perf/assert_preview_pool.py create mode 100644 scripts/perf/assert_retry_behavior.py create mode 100644 scripts/perf/assert_throughput.py create mode 100644 scripts/perf/assert_transport.py create mode 100644 scripts/perf/bench.py create mode 100644 scripts/perf/rollout_guard.py create mode 100644 test_preview_scheduler.py create mode 100644 test_rollout_guard.py diff --git a/config.py b/config.py index fd53b3787..1c523f2d9 100644 --- a/config.py +++ b/config.py @@ -46,6 +46,14 @@ "PASSWORD": "password", "KEEPALIVE_URL": "keepalive_url", "KEEPALIVE_INTERVAL": "keepalive_interval", + "FF_RETRY_POLICY_V2": "ff_retry_policy_v2", + "FF_HTTP2_POOL_TUNING": "ff_http2_pool_tuning", + "FF_CONVERTER_FAST_PATH": "ff_converter_fast_path", + "FF_PREVIEW_CREDENTIAL_SCHEDULER_V2": "ff_preview_credential_scheduler_v2", + "ROLLOUT_STAGE_PERCENT": "rollout_stage_percent", + "ROLLBACK_TRIGGER_LATENCY_P95_MS": "rollback_trigger_latency_p95_ms", + "ROLLBACK_TRIGGER_THROUGHPUT_DROP_PCT": "rollback_trigger_throughput_drop_pct", + "ROLLBACK_TRIGGER_QUALITY_DROP_PCT": "rollback_trigger_quality_drop_pct", } @@ -361,6 +369,132 @@ async def get_antigravity_stream2nostream() -> bool: return bool(await get_config_value("antigravity_stream2nostream", True)) +async def get_ff_retry_policy_v2() -> bool: + """Get retry policy v2 feature flag.""" + env_value = os.getenv("FF_RETRY_POLICY_V2") + if env_value: + return env_value.lower() in ("true", "1", "yes", "on") + + return bool(await get_config_value("ff_retry_policy_v2", False)) + + +async def get_ff_http2_pool_tuning() -> bool: + """Get http2 pool tuning feature flag.""" + env_value = os.getenv("FF_HTTP2_POOL_TUNING") + if env_value: + return env_value.lower() in ("true", "1", "yes", "on") + + return bool(await get_config_value("ff_http2_pool_tuning", False)) + + +async def get_ff_converter_fast_path() -> bool: + """Get converter fast path feature flag.""" + env_value = os.getenv("FF_CONVERTER_FAST_PATH") + if env_value: + return env_value.lower() in ("true", "1", "yes", "on") + + return bool(await get_config_value("ff_converter_fast_path", False)) + + +async def get_ff_preview_credential_scheduler_v2() -> bool: + """Get preview credential scheduler v2 feature flag.""" + env_value = os.getenv("FF_PREVIEW_CREDENTIAL_SCHEDULER_V2") + if env_value: + return env_value.lower() in ("true", "1", "yes", "on") + + return bool(await get_config_value("ff_preview_credential_scheduler_v2", False)) + + +async def get_rollout_stage_percent() -> int: + """Get rollout stage percent (5/20/50/100).""" + allowed = {5, 20, 50, 100} + + env_value = os.getenv("ROLLOUT_STAGE_PERCENT") + if env_value: + try: + parsed = int(env_value) + if parsed in allowed: + return parsed + except ValueError: + pass + + value = await get_config_value("rollout_stage_percent", 5) + try: + parsed = int(value) + if parsed in allowed: + return parsed + except (TypeError, ValueError): + pass + + return 5 + + +async def get_rollback_trigger_latency_p95_ms() -> int: + """Get rollback trigger latency p95 threshold in ms.""" + env_value = os.getenv("ROLLBACK_TRIGGER_LATENCY_P95_MS") + if env_value: + try: + parsed = int(env_value) + if parsed >= 0: + return parsed + except ValueError: + pass + + value = await get_config_value("rollback_trigger_latency_p95_ms", 2500) + try: + parsed = int(value) + if parsed >= 0: + return parsed + except (TypeError, ValueError): + pass + + return 2500 + + +async def get_rollback_trigger_throughput_drop_pct() -> float: + """Get rollback trigger throughput drop threshold in percent.""" + env_value = os.getenv("ROLLBACK_TRIGGER_THROUGHPUT_DROP_PCT") + if env_value: + try: + parsed = float(env_value) + if parsed >= 0: + return parsed + except ValueError: + pass + + value = await get_config_value("rollback_trigger_throughput_drop_pct", 20.0) + try: + parsed = float(value) + if parsed >= 0: + return parsed + except (TypeError, ValueError): + pass + + return 20.0 + + +async def get_rollback_trigger_quality_drop_pct() -> float: + """Get rollback trigger quality drop threshold in percent.""" + env_value = os.getenv("ROLLBACK_TRIGGER_QUALITY_DROP_PCT") + if env_value: + try: + parsed = float(env_value) + if parsed >= 0: + return parsed + except ValueError: + pass + + value = await get_config_value("rollback_trigger_quality_drop_pct", 10.0) + try: + parsed = float(value) + if parsed >= 0: + return parsed + except (TypeError, ValueError): + pass + + return 10.0 + + async def get_oauth_proxy_url() -> str: """ Get OAuth proxy URL setting. diff --git a/front/common.js b/front/common.js index 16e5017dc..d9283ea45 100644 --- a/front/common.js +++ b/front/common.js @@ -2668,44 +2668,84 @@ async function loadConfig() { function populateConfigForm() { const c = AppState.currentConfig; - setConfigField('host', c.host || '0.0.0.0'); - setConfigField('port', c.port || 7861); - setConfigField('configApiPassword', c.api_password || ''); - setConfigField('configPanelPassword', c.panel_password || ''); - setConfigField('configPassword', c.password || 'pwd'); - setConfigField('credentialsDir', c.credentials_dir || ''); - setConfigField('proxy', c.proxy || ''); - setConfigField('codeAssistEndpoint', c.code_assist_endpoint || ''); - setConfigField('oauthProxyUrl', c.oauth_proxy_url || ''); - setConfigField('googleapisProxyUrl', c.googleapis_proxy_url || ''); - setConfigField('resourceManagerApiUrl', c.resource_manager_api_url || ''); - setConfigField('serviceUsageApiUrl', c.service_usage_api_url || ''); - setConfigField('antigravityApiUrl', c.antigravity_api_url || ''); - - document.getElementById('autoBanEnabled').checked = Boolean(c.auto_ban_enabled); - setConfigField('autoBanErrorCodes', (c.auto_ban_error_codes || []).join(',')); + setConfigField('host', c.host || '0.0.0.0', 'host'); + setConfigField('port', c.port || 7861, 'port'); + setConfigField('configApiPassword', c.api_password || '', 'api_password'); + setConfigField('configPanelPassword', c.panel_password || '', 'panel_password'); + setConfigField('configPassword', c.password || 'pwd', 'password'); + setConfigField('credentialsDir', c.credentials_dir || '', 'credentials_dir'); + setConfigField('proxy', c.proxy || '', 'proxy'); + setConfigField('codeAssistEndpoint', c.code_assist_endpoint || '', 'code_assist_endpoint'); + setConfigField('oauthProxyUrl', c.oauth_proxy_url || '', 'oauth_proxy_url'); + setConfigField('googleapisProxyUrl', c.googleapis_proxy_url || '', 'googleapis_proxy_url'); + setConfigField('resourceManagerApiUrl', c.resource_manager_api_url || '', 'resource_manager_api_url'); + setConfigField('serviceUsageApiUrl', c.service_usage_api_url || '', 'service_usage_api_url'); + setConfigField('antigravityApiUrl', c.antigravity_api_url || '', 'antigravity_api_url'); + + setConfigCheckbox('autoBanEnabled', Boolean(c.auto_ban_enabled), 'auto_ban_enabled'); + setConfigField('autoBanErrorCodes', (c.auto_ban_error_codes || []).join(','), 'auto_ban_error_codes'); setConfigField('callsPerRotation', c.calls_per_rotation || 10); - document.getElementById('retry429Enabled').checked = Boolean(c.retry_429_enabled); - setConfigField('retry429MaxRetries', c.retry_429_max_retries || 20); - setConfigField('retry429Interval', c.retry_429_interval || 0.1); - - document.getElementById('compatibilityModeEnabled').checked = Boolean(c.compatibility_mode_enabled); - document.getElementById('returnThoughtsToFrontend').checked = Boolean(c.return_thoughts_to_frontend !== false); - document.getElementById('antigravityStream2nostream').checked = Boolean(c.antigravity_stream2nostream !== false); - - setConfigField('antiTruncationMaxAttempts', c.anti_truncation_max_attempts || 3); - - setConfigField('keepaliveUrl', c.keepalive_url || ''); - setConfigField('keepaliveInterval', c.keepalive_interval || 60); + setConfigCheckbox('retry429Enabled', Boolean(c.retry_429_enabled), 'retry_429_enabled'); + setConfigField('retry429MaxRetries', c.retry_429_max_retries || 20, 'retry_429_max_retries'); + setConfigField('retry429Interval', c.retry_429_interval || 0.1, 'retry_429_interval'); + + setConfigCheckbox('ffRetryPolicyV2', Boolean(c.ff_retry_policy_v2), 'ff_retry_policy_v2'); + setConfigCheckbox('ffHttp2PoolTuning', Boolean(c.ff_http2_pool_tuning), 'ff_http2_pool_tuning'); + setConfigCheckbox('ffConverterFastPath', Boolean(c.ff_converter_fast_path), 'ff_converter_fast_path'); + setConfigCheckbox( + 'ffPreviewCredentialSchedulerV2', + Boolean(c.ff_preview_credential_scheduler_v2), + 'ff_preview_credential_scheduler_v2' + ); + setConfigField('rolloutStagePercent', c.rollout_stage_percent ?? 5, 'rollout_stage_percent'); + setConfigField( + 'rollbackTriggerLatencyP95Ms', + c.rollback_trigger_latency_p95_ms ?? 2500, + 'rollback_trigger_latency_p95_ms' + ); + setConfigField( + 'rollbackTriggerThroughputDropPct', + c.rollback_trigger_throughput_drop_pct ?? 20, + 'rollback_trigger_throughput_drop_pct' + ); + setConfigField( + 'rollbackTriggerQualityDropPct', + c.rollback_trigger_quality_drop_pct ?? 10, + 'rollback_trigger_quality_drop_pct' + ); + + setConfigCheckbox('compatibilityModeEnabled', Boolean(c.compatibility_mode_enabled), 'compatibility_mode_enabled'); + setConfigCheckbox('returnThoughtsToFrontend', Boolean(c.return_thoughts_to_frontend !== false), 'return_thoughts_to_frontend'); + setConfigCheckbox('antigravityStream2nostream', Boolean(c.antigravity_stream2nostream !== false), 'antigravity_stream2nostream'); + + setConfigField('antiTruncationMaxAttempts', c.anti_truncation_max_attempts || 3, 'anti_truncation_max_attempts'); + + setConfigField('keepaliveUrl', c.keepalive_url || '', 'keepalive_url'); + setConfigField('keepaliveInterval', c.keepalive_interval || 60, 'keepalive_interval'); +} + +function setConfigField(fieldId, value, configKey = null) { + const field = document.getElementById(fieldId); + if (field) { + field.value = value; + const key = configKey || fieldId.replace(/([A-Z])/g, '_$1').toLowerCase(); + if (AppState.envLockedFields.has(key)) { + field.disabled = true; + field.classList.add('env-locked'); + } else { + field.disabled = false; + field.classList.remove('env-locked'); + } + } } -function setConfigField(fieldId, value) { +function setConfigCheckbox(fieldId, checked, configKey = null) { const field = document.getElementById(fieldId); if (field) { - field.value = value; - const configKey = fieldId.replace(/([A-Z])/g, '_$1').toLowerCase(); - if (AppState.envLockedFields.has(configKey)) { + field.checked = Boolean(checked); + const key = configKey || fieldId.replace(/([A-Z])/g, '_$1').toLowerCase(); + if (AppState.envLockedFields.has(key)) { field.disabled = true; field.classList.add('env-locked'); } else { @@ -2717,10 +2757,24 @@ function setConfigField(fieldId, value) { async function saveConfig() { try { - const getValue = (id, def = '') => document.getElementById(id)?.value.trim() || def; - const getInt = (id, def = 0) => parseInt(document.getElementById(id)?.value) || def; - const getFloat = (id, def = 0.0) => parseFloat(document.getElementById(id)?.value) || def; - const getChecked = (id, def = false) => document.getElementById(id)?.checked || def; + const getValue = (id, def = '') => { + const rawValue = document.getElementById(id)?.value; + if (rawValue === undefined || rawValue === null) return def; + const trimmed = rawValue.trim(); + return trimmed === '' ? def : trimmed; + }; + const getInt = (id, def = 0) => { + const value = parseInt(document.getElementById(id)?.value, 10); + return Number.isNaN(value) ? def : value; + }; + const getFloat = (id, def = 0.0) => { + const value = parseFloat(document.getElementById(id)?.value); + return Number.isNaN(value) ? def : value; + }; + const getChecked = (id, def = false) => { + const field = document.getElementById(id); + return field ? Boolean(field.checked) : def; + }; const config = { host: getValue('host', '0.0.0.0'), @@ -2743,6 +2797,14 @@ async function saveConfig() { retry_429_enabled: getChecked('retry429Enabled'), retry_429_max_retries: getInt('retry429MaxRetries', 20), retry_429_interval: getFloat('retry429Interval', 0.1), + ff_retry_policy_v2: getChecked('ffRetryPolicyV2'), + ff_http2_pool_tuning: getChecked('ffHttp2PoolTuning'), + ff_converter_fast_path: getChecked('ffConverterFastPath'), + ff_preview_credential_scheduler_v2: getChecked('ffPreviewCredentialSchedulerV2'), + rollout_stage_percent: getInt('rolloutStagePercent', 5), + rollback_trigger_latency_p95_ms: getFloat('rollbackTriggerLatencyP95Ms', 2500), + rollback_trigger_throughput_drop_pct: getFloat('rollbackTriggerThroughputDropPct', 20), + rollback_trigger_quality_drop_pct: getFloat('rollbackTriggerQualityDropPct', 10), compatibility_mode_enabled: getChecked('compatibilityModeEnabled'), return_thoughts_to_frontend: getChecked('returnThoughtsToFrontend'), antigravity_stream2nostream: getChecked('antigravityStream2nostream'), diff --git a/front/control_panel.html b/front/control_panel.html index 0ac18ed90..a1c362c1a 100644 --- a/front/control_panel.html +++ b/front/control_panel.html @@ -2058,6 +2058,74 @@

重试配置

+
+

灰度发布与回滚门禁

+ +
+ + 关闭时始终走当前稳定重试策略 +
+ +
+ + 关闭时使用现有连接池参数 +
+ +
+ + 关闭时回退到当前转换路径 +
+ +
+ + 关闭时使用当前凭证调度逻辑 +
+ +
+ + + 仅支持 5 / 20 / 50 / 100 四个阶段 +
+ +
+ + + 当 P95 延迟超过该阈值时触发回滚门禁 +
+ +
+ + + 0-100,超过该比例触发回滚门禁 +
+ +
+ + + 0-100,超过该比例触发回滚门禁 +
+
+

兼容性配置

@@ -2306,4 +2374,4 @@

📞 联系我们

- \ No newline at end of file + diff --git a/front/control_panel_mobile.html b/front/control_panel_mobile.html index 4360ac995..74a0a056e 100644 --- a/front/control_panel_mobile.html +++ b/front/control_panel_mobile.html @@ -1789,6 +1789,71 @@

错误重试配置

+
+

灰度发布与回滚门禁

+ +
+ + 关闭时始终走当前稳定重试策略 +
+ +
+ + 关闭时使用现有连接池参数 +
+ +
+ + 关闭时回退到当前转换路径 +
+ +
+ + 关闭时使用当前凭证调度逻辑 +
+ +
+ + + 仅支持 5 / 20 / 50 / 100 四个阶段 +
+ +
+ + + 当 P95 延迟超过该阈值时触发回滚门禁 +
+ +
+ + + 0-100,超过该比例触发回滚门禁 +
+ +
+ + + 0-100,超过该比例触发回滚门禁 +
+
+

兼容性配置

@@ -2035,4 +2100,4 @@

📞 联系我们

- \ No newline at end of file + diff --git a/scripts/eval/assert_quality.py b/scripts/eval/assert_quality.py new file mode 100644 index 000000000..f5fcf0933 --- /dev/null +++ b/scripts/eval/assert_quality.py @@ -0,0 +1,290 @@ +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +DIRECT_SCORE_FIELDS: Tuple[str, ...] = ("win_rate", "pass_rate", "quality_score") +DIRECT_CONTAINERS: Tuple[str, ...] = ( + "summary", + "metrics", + "result", + "stats", + "quality", +) +CASE_LIST_FIELDS: Tuple[str, ...] = ( + "cases", + "items", + "results", + "evaluations", + "samples", + "records", +) +PASS_BOOL_FIELDS: Tuple[str, ...] = ("pass", "passed", "success", "ok", "is_pass") +PASS_STR_FIELDS: Tuple[str, ...] = ("status", "outcome", "result", "verdict") +PASS_STR_VALUES = {"pass", "passed", "ok", "success", "win", "true"} +FAIL_STR_VALUES = {"fail", "failed", "error", "loss", "false"} + + +def _safe_float(value: Any) -> Optional[float]: + if isinstance(value, bool): + return None + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + try: + return float(value) + except ValueError: + return None + return None + + +def _emit(payload: Dict[str, Any]) -> None: + print(json.dumps(payload, ensure_ascii=False)) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Assert quality regression threshold between baseline and candidate" + ) + parser.add_argument("--baseline", required=True) + parser.add_argument("--candidate", required=True) + parser.add_argument("--max-drop", type=float, required=True) + return parser.parse_args() + + +def _load_report(path: str) -> Tuple[Optional[Dict[str, Any]], Optional[str]]: + report_path = Path(path) + if not report_path.exists(): + return None, f"missing_file:{path}" + if not report_path.is_file(): + return None, f"not_a_file:{path}" + try: + payload = json.loads(report_path.read_text(encoding="utf-8")) + except Exception as exc: + return None, f"invalid_json:{path}:{exc.__class__.__name__}" + if not isinstance(payload, dict): + return None, f"invalid_report_object:{path}" + return payload, None + + +def _is_synthetic_or_blocked(report: Dict[str, Any]) -> bool: + meta = report.get("meta") + if not isinstance(meta, dict): + return False + if bool(meta.get("synthetic")): + return True + blocker = meta.get("blocker") + return isinstance(blocker, str) and blocker.strip() != "" + + +def _extract_direct_score( + report: Dict[str, Any], +) -> Tuple[Optional[float], str, Optional[str]]: + for field in DIRECT_SCORE_FIELDS: + value = _safe_float(report.get(field)) + if value is not None: + return value, "root", field + + for container_name in DIRECT_CONTAINERS: + container = report.get(container_name) + if not isinstance(container, dict): + continue + for field in DIRECT_SCORE_FIELDS: + value = _safe_float(container.get(field)) + if value is not None: + return value, container_name, field + + return None, "missing", None + + +def _normalize_case_pass(case: Dict[str, Any]) -> Optional[bool]: + for field in PASS_BOOL_FIELDS: + value = case.get(field) + if isinstance(value, bool): + return value + + for field in PASS_STR_FIELDS: + value = case.get(field) + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in PASS_STR_VALUES: + return True + if normalized in FAIL_STR_VALUES: + return False + + return None + + +def _extract_case_list( + report: Dict[str, Any], +) -> Tuple[Optional[List[Dict[str, Any]]], str]: + for field in CASE_LIST_FIELDS: + payload = report.get(field) + if isinstance(payload, list): + cases = [item for item in payload if isinstance(item, dict)] + if cases: + return cases, f"root.{field}" + + for container_name in DIRECT_CONTAINERS: + container = report.get(container_name) + if not isinstance(container, dict): + continue + for field in CASE_LIST_FIELDS: + payload = container.get(field) + if isinstance(payload, list): + cases = [item for item in payload if isinstance(item, dict)] + if cases: + return cases, f"{container_name}.{field}" + + return None, "missing" + + +def _extract_case_pass_rate( + report: Dict[str, Any], +) -> Tuple[Optional[float], str, Dict[str, Any]]: + cases, case_source = _extract_case_list(report) + if not cases: + return None, "missing", {"case_source": case_source} + + resolved = 0 + passed = 0 + unresolved = 0 + for case in cases: + verdict = _normalize_case_pass(case) + if verdict is None: + unresolved += 1 + continue + resolved += 1 + if verdict: + passed += 1 + + details: Dict[str, Any] = { + "case_source": case_source, + "case_count": len(cases), + "resolved_case_count": resolved, + "unresolved_case_count": unresolved, + "passed_case_count": passed, + } + + if resolved <= 0: + return None, "missing", details + + pass_rate = (float(passed) / float(resolved)) * 100.0 + return pass_rate, "case_pass_rate", details + + +def _extract_quality_score( + report: Dict[str, Any], +) -> Tuple[Optional[float], str, Dict[str, Any]]: + direct_value, direct_source, direct_field = _extract_direct_score(report) + if direct_value is not None: + return ( + direct_value, + "direct", + {"score_source": direct_source, "score_field": direct_field}, + ) + + case_rate, case_source, case_details = _extract_case_pass_rate(report) + if case_rate is not None: + return case_rate, case_source, case_details + + return None, "missing", {"direct_score_source": direct_source, **case_details} + + +def main() -> None: + args = _parse_args() + baseline_report, baseline_load_error = _load_report(args.baseline) + candidate_report, candidate_load_error = _load_report(args.candidate) + + if baseline_load_error or candidate_load_error: + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "missing_or_unreadable_quality_artifact", + "baseline": args.baseline, + "candidate": args.candidate, + "baseline_error": baseline_load_error, + "candidate_error": candidate_load_error, + "required_max_drop": float(args.max_drop), + } + ) + raise SystemExit(0) + + assert baseline_report is not None + assert candidate_report is not None + + baseline_score, baseline_source, baseline_details = _extract_quality_score( + baseline_report + ) + candidate_score, candidate_source, candidate_details = _extract_quality_score( + candidate_report + ) + + if _is_synthetic_or_blocked(baseline_report) or _is_synthetic_or_blocked( + candidate_report + ): + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "synthetic_or_blocked_report_detected", + "baseline": args.baseline, + "candidate": args.candidate, + "baseline_score": baseline_score, + "candidate_score": candidate_score, + "baseline_source": baseline_source, + "candidate_source": candidate_source, + "required_max_drop": float(args.max_drop), + "baseline_details": baseline_details, + "candidate_details": candidate_details, + } + ) + raise SystemExit(0) + + if baseline_score is None or candidate_score is None: + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "missing_quality_signal", + "baseline": args.baseline, + "candidate": args.candidate, + "baseline_score": baseline_score, + "candidate_score": candidate_score, + "baseline_source": baseline_source, + "candidate_source": candidate_source, + "required_max_drop": float(args.max_drop), + "baseline_details": baseline_details, + "candidate_details": candidate_details, + } + ) + raise SystemExit(0) + + drop = float(baseline_score) - float(candidate_score) + passed = drop <= float(args.max_drop) + + payload: Dict[str, Any] = { + "baseline": args.baseline, + "candidate": args.candidate, + "baseline_score": round(float(baseline_score), 6), + "candidate_score": round(float(candidate_score), 6), + "baseline_source": baseline_source, + "candidate_source": candidate_source, + "score_drop": round(drop, 6), + "required_max_drop": float(args.max_drop), + "quality_gate": "PASS" if passed else "FAIL", + "baseline_details": baseline_details, + "candidate_details": candidate_details, + } + + if passed: + _emit({"status": "PASS", **payload}) + raise SystemExit(0) + + _emit({"status": "FAIL", **payload}) + raise SystemExit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/perf/assert_converter_cpu.py b/scripts/perf/assert_converter_cpu.py new file mode 100644 index 000000000..1524d2022 --- /dev/null +++ b/scripts/perf/assert_converter_cpu.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import argparse +import json +import math +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +def _safe_float(value: Any) -> Optional[float]: + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + try: + return float(value) + except ValueError: + return None + return None + + +def _percentile(values: List[float], percentile: float) -> float: + if not values: + return 0.0 + ordered = sorted(values) + if len(ordered) == 1: + return float(ordered[0]) + rank = (len(ordered) - 1) * percentile + low = math.floor(rank) + high = math.ceil(rank) + if low == high: + return float(ordered[low]) + weight = rank - low + return float(ordered[low] * (1.0 - weight) + ordered[high] * weight) + + +def _load_report(path: str) -> Dict[str, Any]: + report_path = Path(path) + payload = json.loads(report_path.read_text(encoding="utf-8")) + if not isinstance(payload, dict): + raise ValueError(f"invalid report JSON object: {path}") + return payload + + +def _is_synthetic_or_blocked(report: Dict[str, Any]) -> bool: + meta = report.get("meta") + if not isinstance(meta, dict): + return False + if bool(meta.get("synthetic")): + return True + blocker = meta.get("blocker") + return isinstance(blocker, str) and blocker.strip() != "" + + +def _converter_cpu_p95_from_summary(report: Dict[str, Any]) -> Optional[float]: + summary = report.get("summary") + if not isinstance(summary, dict): + return None + converter_block = summary.get("converter_cpu_ms") + if not isinstance(converter_block, dict): + return None + return _safe_float(converter_block.get("p95")) + + +def _converter_cpu_p95_from_buckets(report: Dict[str, Any]) -> Optional[float]: + buckets = report.get("buckets") + if not isinstance(buckets, dict): + return None + values: List[float] = [] + for bucket in buckets.values(): + if not isinstance(bucket, dict): + continue + converter_block = bucket.get("converter_cpu_ms") + if not isinstance(converter_block, dict): + continue + value = _safe_float(converter_block.get("p95")) + if value is not None: + values.append(value) + if not values: + return None + return max(values) + + +def _converter_cpu_p95_from_samples(report: Dict[str, Any]) -> Optional[float]: + payload = report.get("samples") + if not isinstance(payload, list): + return None + values: List[float] = [] + for sample in payload: + if not isinstance(sample, dict): + continue + value = _safe_float(sample.get("converter_cpu_ms")) + if value is not None: + values.append(value) + if not values: + return None + return _percentile(values, 0.95) + + +def _converter_cpu_p95(report: Dict[str, Any]) -> Tuple[Optional[float], str]: + for source, getter in ( + ("summary", _converter_cpu_p95_from_summary), + ("buckets", _converter_cpu_p95_from_buckets), + ("samples", _converter_cpu_p95_from_samples), + ): + value = getter(report) + if value is not None: + return value, source + return None, "missing" + + +def _emit(payload: Dict[str, Any]) -> None: + print(json.dumps(payload, ensure_ascii=False)) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Assert converter_cpu_ms p95 reduction against baseline" + ) + parser.add_argument("--before", required=True) + parser.add_argument("--after", required=True) + parser.add_argument("--cpu-p95-reduce", type=float, required=True) + return parser.parse_args() + + +def main() -> None: + args = _parse_args() + + before_report = _load_report(args.before) + after_report = _load_report(args.after) + + before_p95, before_source = _converter_cpu_p95(before_report) + after_p95, after_source = _converter_cpu_p95(after_report) + + if _is_synthetic_or_blocked(before_report) or _is_synthetic_or_blocked( + after_report + ): + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "synthetic_or_blocked_report_detected", + "before": args.before, + "after": args.after, + "before_converter_cpu_p95": before_p95, + "after_converter_cpu_p95": after_p95, + "before_converter_cpu_source": before_source, + "after_converter_cpu_source": after_source, + "required_cpu_p95_reduce_percent": float(args.cpu_p95_reduce), + } + ) + raise SystemExit(0) + + if before_p95 is None or after_p95 is None: + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "missing_converter_cpu_p95_signal", + "before": args.before, + "after": args.after, + "before_converter_cpu_p95": before_p95, + "after_converter_cpu_p95": after_p95, + "before_converter_cpu_source": before_source, + "after_converter_cpu_source": after_source, + "required_cpu_p95_reduce_percent": float(args.cpu_p95_reduce), + } + ) + raise SystemExit(0) + + if before_p95 <= 0: + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "non_positive_baseline_converter_cpu_p95", + "before": args.before, + "after": args.after, + "before_converter_cpu_p95": before_p95, + "after_converter_cpu_p95": after_p95, + "before_converter_cpu_source": before_source, + "after_converter_cpu_source": after_source, + "required_cpu_p95_reduce_percent": float(args.cpu_p95_reduce), + } + ) + raise SystemExit(0) + + reduction_percent = ((before_p95 - after_p95) / before_p95) * 100.0 + if reduction_percent >= float(args.cpu_p95_reduce): + _emit( + { + "status": "PASS", + "before": args.before, + "after": args.after, + "before_converter_cpu_p95": round(before_p95, 6), + "after_converter_cpu_p95": round(after_p95, 6), + "before_converter_cpu_source": before_source, + "after_converter_cpu_source": after_source, + "converter_cpu_p95_reduce_percent": round(reduction_percent, 3), + "required_cpu_p95_reduce_percent": float(args.cpu_p95_reduce), + } + ) + raise SystemExit(0) + + _emit( + { + "status": "FAIL", + "before": args.before, + "after": args.after, + "before_converter_cpu_p95": round(before_p95, 6), + "after_converter_cpu_p95": round(after_p95, 6), + "before_converter_cpu_source": before_source, + "after_converter_cpu_source": after_source, + "converter_cpu_p95_reduce_percent": round(reduction_percent, 3), + "required_cpu_p95_reduce_percent": float(args.cpu_p95_reduce), + } + ) + raise SystemExit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/perf/assert_latency.py b/scripts/perf/assert_latency.py new file mode 100644 index 000000000..b039065aa --- /dev/null +++ b/scripts/perf/assert_latency.py @@ -0,0 +1,320 @@ +from __future__ import annotations + +import argparse +import json +import math +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple + + +def _safe_float(value: Any) -> Optional[float]: + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + try: + return float(value) + except ValueError: + return None + return None + + +def _safe_int(value: Any) -> Optional[int]: + if isinstance(value, bool): + return None + if isinstance(value, int): + return int(value) + if isinstance(value, float): + if value.is_integer(): + return int(value) + return None + if isinstance(value, str): + try: + parsed = float(value) + except ValueError: + return None + if parsed.is_integer(): + return int(parsed) + return None + + +def _percentile(values: List[float], percentile: float) -> float: + if not values: + return 0.0 + ordered = sorted(values) + if len(ordered) == 1: + return float(ordered[0]) + rank = (len(ordered) - 1) * percentile + low = math.floor(rank) + high = math.ceil(rank) + if low == high: + return float(ordered[low]) + weight = rank - low + return float(ordered[low] * (1.0 - weight) + ordered[high] * weight) + + +def _load_report(path: str) -> Dict[str, Any]: + report_path = Path(path) + payload = json.loads(report_path.read_text(encoding="utf-8")) + if not isinstance(payload, dict): + raise ValueError(f"invalid report JSON object: {path}") + return payload + + +def _is_synthetic_or_blocked(report: Dict[str, Any]) -> bool: + meta = report.get("meta") + if not isinstance(meta, dict): + return False + if bool(meta.get("synthetic")): + return True + blocker = meta.get("blocker") + return isinstance(blocker, str) and blocker.strip() != "" + + +def _has_valid_request_or_sample_signal(report: Dict[str, Any]) -> bool: + summary = report.get("summary") + request_count: Optional[int] = None + if isinstance(summary, dict): + request_count = _safe_int(summary.get("request_count")) + + if request_count is not None and request_count > 0: + return True + + samples = report.get("samples") + if isinstance(samples, list) and any( + isinstance(sample, dict) for sample in samples + ): + return True + + return False + + +def _p95_from_summary(report: Dict[str, Any], metric: str) -> Optional[float]: + summary = report.get("summary") + if not isinstance(summary, dict): + return None + block = summary.get(metric) + if not isinstance(block, dict): + return None + return _safe_float(block.get("p95")) + + +def _p95_from_buckets(report: Dict[str, Any], metric: str) -> Optional[float]: + buckets = report.get("buckets") + if not isinstance(buckets, dict): + return None + values: List[float] = [] + for bucket in buckets.values(): + if not isinstance(bucket, dict): + continue + block = bucket.get(metric) + if not isinstance(block, dict): + continue + value = _safe_float(block.get("p95")) + if value is not None: + values.append(value) + if not values: + return None + return max(values) + + +def _p95_from_samples(report: Dict[str, Any], metric: str) -> Optional[float]: + payload = report.get("samples") + if not isinstance(payload, list): + return None + values: List[float] = [] + for sample in payload: + if not isinstance(sample, dict): + continue + value = _safe_float(sample.get(metric)) + if value is not None: + values.append(value) + if not values: + return None + return _percentile(values, 0.95) + + +def _metric_p95(report: Dict[str, Any], metric: str) -> Tuple[Optional[float], str]: + getters: Tuple[ + Tuple[str, Callable[[Dict[str, Any], str], Optional[float]]], ... + ] = ( + ("summary", _p95_from_summary), + ("buckets", _p95_from_buckets), + ("samples", _p95_from_samples), + ) + for source, getter in getters: + value = getter(report, metric) + if value is not None: + return value, source + return None, "missing" + + +def _emit(payload: Dict[str, Any]) -> None: + print(json.dumps(payload, ensure_ascii=False)) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Assert latency improvements for first-token and full latency p95" + ) + parser.add_argument("--before", required=True) + parser.add_argument("--after", required=True) + parser.add_argument("--first-token-p95-improve", type=float, required=True) + parser.add_argument("--full-p95-improve", type=float, required=True) + return parser.parse_args() + + +def main() -> None: + args = _parse_args() + before_report = _load_report(args.before) + after_report = _load_report(args.after) + + before_first_token, before_first_token_source = _metric_p95( + before_report, "first_token_ms" + ) + after_first_token, after_first_token_source = _metric_p95( + after_report, "first_token_ms" + ) + before_full, before_full_source = _metric_p95(before_report, "full_latency_ms") + after_full, after_full_source = _metric_p95(after_report, "full_latency_ms") + + if _is_synthetic_or_blocked(before_report) or _is_synthetic_or_blocked( + after_report + ): + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "synthetic_or_blocked_report_detected", + "before": args.before, + "after": args.after, + "before_first_token_p95": before_first_token, + "after_first_token_p95": after_first_token, + "before_full_latency_p95": before_full, + "after_full_latency_p95": after_full, + "required_first_token_p95_improve_percent": float( + args.first_token_p95_improve + ), + "required_full_latency_p95_improve_percent": float( + args.full_p95_improve + ), + } + ) + raise SystemExit(0) + + before_has_signal = _has_valid_request_or_sample_signal(before_report) + after_has_signal = _has_valid_request_or_sample_signal(after_report) + + if not before_has_signal or not after_has_signal: + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "no_requests_or_samples", + "before": args.before, + "after": args.after, + "before_has_signal": before_has_signal, + "after_has_signal": after_has_signal, + "before_first_token_p95": before_first_token, + "after_first_token_p95": after_first_token, + "before_first_token_source": before_first_token_source, + "after_first_token_source": after_first_token_source, + "before_full_latency_p95": before_full, + "after_full_latency_p95": after_full, + "before_full_latency_source": before_full_source, + "after_full_latency_source": after_full_source, + "required_first_token_p95_improve_percent": float( + args.first_token_p95_improve + ), + "required_full_latency_p95_improve_percent": float( + args.full_p95_improve + ), + } + ) + raise SystemExit(0) + + if ( + before_first_token is None + or after_first_token is None + or before_full is None + or after_full is None + ): + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "missing_latency_p95_signal", + "before": args.before, + "after": args.after, + "before_first_token_p95": before_first_token, + "after_first_token_p95": after_first_token, + "before_first_token_source": before_first_token_source, + "after_first_token_source": after_first_token_source, + "before_full_latency_p95": before_full, + "after_full_latency_p95": after_full, + "before_full_latency_source": before_full_source, + "after_full_latency_source": after_full_source, + "required_first_token_p95_improve_percent": float( + args.first_token_p95_improve + ), + "required_full_latency_p95_improve_percent": float( + args.full_p95_improve + ), + } + ) + raise SystemExit(0) + + if before_first_token <= 0 or before_full <= 0: + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "non_positive_baseline_latency_p95", + "before": args.before, + "after": args.after, + "before_first_token_p95": before_first_token, + "after_first_token_p95": after_first_token, + "before_full_latency_p95": before_full, + "after_full_latency_p95": after_full, + "required_first_token_p95_improve_percent": float( + args.first_token_p95_improve + ), + "required_full_latency_p95_improve_percent": float( + args.full_p95_improve + ), + } + ) + raise SystemExit(0) + + first_token_improve = ( + (before_first_token - after_first_token) / before_first_token + ) * 100.0 + full_improve = ((before_full - after_full) / before_full) * 100.0 + + first_token_ok = first_token_improve >= float(args.first_token_p95_improve) + full_ok = full_improve >= float(args.full_p95_improve) + + payload: Dict[str, Any] = { + "before": args.before, + "after": args.after, + "before_first_token_p95": round(before_first_token, 6), + "after_first_token_p95": round(after_first_token, 6), + "before_first_token_source": before_first_token_source, + "after_first_token_source": after_first_token_source, + "first_token_p95_improve_percent": round(first_token_improve, 3), + "required_first_token_p95_improve_percent": float(args.first_token_p95_improve), + "first_token_gate": "PASS" if first_token_ok else "FAIL", + "before_full_latency_p95": round(before_full, 6), + "after_full_latency_p95": round(after_full, 6), + "before_full_latency_source": before_full_source, + "after_full_latency_source": after_full_source, + "full_latency_p95_improve_percent": round(full_improve, 3), + "required_full_latency_p95_improve_percent": float(args.full_p95_improve), + "full_latency_gate": "PASS" if full_ok else "FAIL", + } + + if first_token_ok and full_ok: + _emit({"status": "PASS", **payload}) + raise SystemExit(0) + + _emit({"status": "FAIL", **payload}) + raise SystemExit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/perf/assert_preview_pool.py b/scripts/perf/assert_preview_pool.py new file mode 100644 index 000000000..701b489b2 --- /dev/null +++ b/scripts/perf/assert_preview_pool.py @@ -0,0 +1,411 @@ +from __future__ import annotations + +import argparse +import json +import math +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +def _percentile(values: List[float], percentile: float) -> float: + if not values: + return 0.0 + ordered = sorted(values) + if len(ordered) == 1: + return float(ordered[0]) + rank = (len(ordered) - 1) * percentile + low = math.floor(rank) + high = math.ceil(rank) + if low == high: + return float(ordered[low]) + weight = rank - low + return float(ordered[low] * (1.0 - weight) + ordered[high] * weight) + + +def _safe_float(value: Any) -> Optional[float]: + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + try: + return float(value) + except ValueError: + return None + return None + + +def _safe_int(value: Any) -> Optional[int]: + if isinstance(value, bool): + return None + if isinstance(value, int): + return int(value) + if isinstance(value, float): + if value.is_integer(): + return int(value) + return None + if isinstance(value, str): + try: + parsed = float(value) + except ValueError: + return None + if parsed.is_integer(): + return int(parsed) + return None + + +def _load_report(path: str) -> Dict[str, Any]: + report_path = Path(path) + payload = json.loads(report_path.read_text(encoding="utf-8")) + if not isinstance(payload, dict): + raise ValueError(f"invalid report JSON object: {path}") + return payload + + +def _is_synthetic(report: Dict[str, Any]) -> bool: + meta = report.get("meta") + if not isinstance(meta, dict): + return False + synthetic = bool(meta.get("synthetic")) + blocked = bool(meta.get("blocker")) + return synthetic or blocked + + +def _has_valid_request_or_sample_signal(report: Dict[str, Any]) -> bool: + summary = report.get("summary") + request_count: Optional[int] = None + if isinstance(summary, dict): + request_count = _safe_int(summary.get("request_count")) + + if request_count is not None and request_count > 0: + return True + + samples = report.get("samples") + if isinstance(samples, list) and any( + isinstance(sample, dict) for sample in samples + ): + return True + + return False + + +def _retry_p95_from_summary(report: Dict[str, Any]) -> Optional[float]: + summary = report.get("summary") + if not isinstance(summary, dict): + return None + retry_block = summary.get("retry_count") + if isinstance(retry_block, dict): + value = _safe_float(retry_block.get("p95")) + if value is not None: + return value + return None + + +def _retry_p95_from_buckets(report: Dict[str, Any]) -> Optional[float]: + buckets = report.get("buckets") + if not isinstance(buckets, dict): + return None + + values: List[float] = [] + for bucket in buckets.values(): + if not isinstance(bucket, dict): + continue + retry_block = bucket.get("retry_count") + if not isinstance(retry_block, dict): + continue + value = _safe_float(retry_block.get("p95")) + if value is not None: + values.append(value) + + if not values: + return None + return max(values) + + +def _retry_p95_from_samples(report: Dict[str, Any]) -> Optional[float]: + samples = report.get("samples") + if not isinstance(samples, list): + return None + + retry_values: List[float] = [] + for sample in samples: + if not isinstance(sample, dict): + continue + value = _safe_float(sample.get("retry_count")) + if value is not None: + retry_values.append(value) + + if not retry_values: + return None + return _percentile(retry_values, 0.95) + + +def _get_retry_p95(report: Dict[str, Any]) -> Optional[float]: + for getter in ( + _retry_p95_from_summary, + _retry_p95_from_buckets, + _retry_p95_from_samples, + ): + value = getter(report) + if value is not None: + return value + return None + + +def _parse_status_code(value: Any) -> Optional[int]: + code = _safe_int(value) + if code is None: + return None + if code < 100 or code > 599: + return None + return code + + +def _status_counts_from_summary(report: Dict[str, Any]) -> Optional[Dict[int, int]]: + summary = report.get("summary") + if not isinstance(summary, dict): + return None + buckets = summary.get("status_buckets") + if not isinstance(buckets, dict): + return None + + counts: Dict[int, int] = {} + for raw_status, raw_count in buckets.items(): + status = _parse_status_code(raw_status) + count = _safe_int(raw_count) + if status is None or count is None or count <= 0: + continue + counts[status] = counts.get(status, 0) + count + return counts if counts else None + + +def _status_counts_from_buckets(report: Dict[str, Any]) -> Optional[Dict[int, int]]: + buckets = report.get("buckets") + if not isinstance(buckets, dict): + return None + + counts: Dict[int, int] = {} + for key, bucket in buckets.items(): + if not isinstance(key, str) or not isinstance(bucket, dict): + continue + parts = key.split("|") + if not parts: + continue + status = _parse_status_code(parts[-1]) + count = _safe_int(bucket.get("count")) + if status is None or count is None or count <= 0: + continue + counts[status] = counts.get(status, 0) + count + return counts if counts else None + + +def _status_counts_from_samples(report: Dict[str, Any]) -> Optional[Dict[int, int]]: + payload = report.get("samples") + if not isinstance(payload, list): + return None + + counts: Dict[int, int] = {} + for sample in payload: + if not isinstance(sample, dict): + continue + status = _parse_status_code(sample.get("status")) + if status is None: + continue + counts[status] = counts.get(status, 0) + 1 + return counts if counts else None + + +def _status_counts(report: Dict[str, Any]) -> Tuple[Optional[Dict[int, int]], str]: + for source, getter in ( + ("summary", _status_counts_from_summary), + ("buckets", _status_counts_from_buckets), + ("samples", _status_counts_from_samples), + ): + counts = getter(report) + if counts is not None: + return counts, source + return None, "missing" + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Assert preview pool retry reduction against baseline" + ) + parser.add_argument("--before", required=True) + parser.add_argument("--after", required=True) + parser.add_argument("--retry-reduce", type=float, required=True) + return parser.parse_args() + + +def _emit(payload: Dict[str, Any]) -> None: + print(json.dumps(payload, ensure_ascii=False)) + + +def main() -> None: + args = _parse_args() + + before_report = _load_report(args.before) + after_report = _load_report(args.after) + + before_p95 = _get_retry_p95(before_report) + after_p95 = _get_retry_p95(after_report) + + if _is_synthetic(before_report) or _is_synthetic(after_report): + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "synthetic_or_blocked_report_detected", + "before": args.before, + "after": args.after, + "before_retry_p95": before_p95, + "after_retry_p95": after_p95, + "required_retry_reduce_percent": float(args.retry_reduce), + } + ) + raise SystemExit(0) + + before_has_signal = _has_valid_request_or_sample_signal(before_report) + after_has_signal = _has_valid_request_or_sample_signal(after_report) + + if not before_has_signal or not after_has_signal: + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "no_requests_or_samples", + "before": args.before, + "after": args.after, + "before_has_signal": before_has_signal, + "after_has_signal": after_has_signal, + "before_retry_p95": before_p95, + "after_retry_p95": after_p95, + "required_retry_reduce_percent": float(args.retry_reduce), + } + ) + raise SystemExit(0) + + retry_gate: Dict[str, Any] = { + "status": "NOT_EVALUATED", + "required_retry_reduce_percent": float(args.retry_reduce), + "before_retry_p95": before_p95, + "after_retry_p95": after_p95, + } + + retry_blocked_reason: Optional[str] = None + retry_ok = False + + if before_p95 is None or after_p95 is None: + retry_gate.update( + { + "status": "BLOCKED", + "reason": "missing_retry_count_p95_signal", + } + ) + retry_blocked_reason = "missing_retry_count_p95_signal" + elif before_p95 <= 0: + retry_gate.update( + { + "status": "BLOCKED", + "reason": "non_positive_baseline_retry_p95", + } + ) + retry_blocked_reason = "non_positive_baseline_retry_p95" + else: + reduction_percent = ((before_p95 - after_p95) / before_p95) * 100.0 + retry_ok = reduction_percent >= float(args.retry_reduce) + retry_gate.update( + { + "status": "PASS" if retry_ok else "FAIL", + "retry_reduce_percent": round(reduction_percent, 3), + "before_retry_p95": round(before_p95, 6), + "after_retry_p95": round(after_p95, 6), + } + ) + + before_status_counts, before_status_source = _status_counts(before_report) + after_status_counts, after_status_source = _status_counts(after_report) + + no_credential_500_gate: Dict[str, Any] = { + "status": "NOT_APPLICABLE", + "reason": "missing_status_signal", + "signal": "status_code_500_proxy", + "before_status_source": before_status_source, + "after_status_source": after_status_source, + } + no_credential_500_ok = False + no_credential_500_blocked_reason: Optional[str] = "missing_status_signal" + + if before_status_counts is not None and after_status_counts is not None: + before_total = sum(before_status_counts.values()) + after_total = sum(after_status_counts.values()) + before_500 = int(before_status_counts.get(500, 0)) + after_500 = int(after_status_counts.get(500, 0)) + + if before_total > 0 and after_total > 0: + before_500_rate = (float(before_500) / float(before_total)) * 100.0 + after_500_rate = (float(after_500) / float(after_total)) * 100.0 + + no_credential_500_gate.update( + { + "before_total": before_total, + "before_500_count": before_500, + "before_500_rate_percent": round(before_500_rate, 6), + "after_total": after_total, + "after_500_count": after_500, + "after_500_rate_percent": round(after_500_rate, 6), + } + ) + + if before_500 <= 0 and after_500 <= 0: + no_credential_500_gate.update( + { + "status": "NOT_APPLICABLE", + "reason": "no_status_500_signal_for_no_credential_proxy", + } + ) + no_credential_500_blocked_reason = ( + "no_status_500_signal_for_no_credential_proxy" + ) + else: + no_credential_500_ok = after_500_rate < before_500_rate + no_credential_500_gate.update( + { + "status": "PASS" if no_credential_500_ok else "FAIL", + "reason": "status_code_500_proxy", + } + ) + no_credential_500_blocked_reason = None + + base_payload = { + "before": args.before, + "after": args.after, + "retry_gate": retry_gate, + "no_credential_500_rate_gate": no_credential_500_gate, + } + + if retry_blocked_reason or no_credential_500_blocked_reason: + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": ";".join( + [ + reason + for reason in ( + retry_blocked_reason, + no_credential_500_blocked_reason, + ) + if reason + ] + ), + **base_payload, + } + ) + raise SystemExit(0) + + if retry_ok and no_credential_500_ok: + _emit({"status": "PASS", **base_payload}) + raise SystemExit(0) + + _emit({"status": "FAIL", **base_payload}) + raise SystemExit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/perf/assert_retry_behavior.py b/scripts/perf/assert_retry_behavior.py new file mode 100644 index 000000000..b866c9ff4 --- /dev/null +++ b/scripts/perf/assert_retry_behavior.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from statistics import median +from typing import Any, Dict, List, Optional + + +def _safe_float(value: Any) -> Optional[float]: + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + try: + return float(value) + except ValueError: + return None + return None + + +def _safe_int(value: Any) -> Optional[int]: + if isinstance(value, bool): + return None + if isinstance(value, int): + return int(value) + if isinstance(value, float): + if value.is_integer(): + return int(value) + return None + if isinstance(value, str): + try: + parsed = float(value) + except ValueError: + return None + if parsed.is_integer(): + return int(parsed) + return None + + +def _load_report(path: str) -> Dict[str, Any]: + report_path = Path(path) + payload = json.loads(report_path.read_text(encoding="utf-8")) + if not isinstance(payload, dict): + raise ValueError(f"invalid report JSON object: {path}") + return payload + + +def _is_synthetic_or_blocked(report: Dict[str, Any]) -> bool: + meta = report.get("meta") + if not isinstance(meta, dict): + return False + if bool(meta.get("synthetic")): + return True + blocker = meta.get("blocker") + return isinstance(blocker, str) and blocker.strip() != "" + + +def _samples(report: Dict[str, Any]) -> List[Dict[str, Any]]: + payload = report.get("samples") + if not isinstance(payload, list): + return [] + return [item for item in payload if isinstance(item, dict)] + + +def _per_retry_sleep_values(samples: List[Dict[str, Any]]) -> List[float]: + values: List[float] = [] + for sample in samples: + retry_count = _safe_int(sample.get("retry_count")) + retry_sleep_ms = _safe_float(sample.get("retry_sleep_ms")) + if retry_count is None or retry_sleep_ms is None: + continue + if retry_count <= 0 or retry_sleep_ms <= 0: + continue + values.append(retry_sleep_ms / float(retry_count)) + return values + + +def _derive_nominal_retry_sleep_ms( + before_samples: List[Dict[str, Any]], after_samples: List[Dict[str, Any]] +) -> Optional[float]: + before_values = _per_retry_sleep_values(before_samples) + if len(before_values) >= 3: + return float(median(before_values)) + merged = before_values + _per_retry_sleep_values(after_samples) + if len(merged) < 3: + return None + return float(median(merged)) + + +def _find_double_sleep_anomalies( + samples: List[Dict[str, Any]], nominal_sleep_ms: float +) -> List[Dict[str, Any]]: + anomalies: List[Dict[str, Any]] = [] + for idx, sample in enumerate(samples): + retry_count = _safe_int(sample.get("retry_count")) + retry_sleep_ms = _safe_float(sample.get("retry_sleep_ms")) + + if retry_count is None and retry_sleep_ms is None: + continue + + if retry_count is None: + retry_count = 0 + if retry_sleep_ms is None: + retry_sleep_ms = 0.0 + + if retry_count <= 0: + if retry_sleep_ms > 0: + anomalies.append( + { + "sample_index": idx, + "reason": "retry_sleep_without_retry_count", + "retry_count": retry_count, + "retry_sleep_ms": round(retry_sleep_ms, 6), + } + ) + continue + + expected_sleep_ms = nominal_sleep_ms * float(retry_count) + if expected_sleep_ms <= 0: + continue + + ratio = retry_sleep_ms / expected_sleep_ms + if ratio >= 1.8: + anomalies.append( + { + "sample_index": idx, + "reason": "duplicate_wait_ratio_exceeded", + "retry_count": retry_count, + "retry_sleep_ms": round(retry_sleep_ms, 6), + "expected_sleep_ms": round(expected_sleep_ms, 6), + "ratio": round(ratio, 6), + } + ) + + return anomalies + + +def _emit(payload: Dict[str, Any]) -> None: + print(json.dumps(payload, ensure_ascii=False)) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Assert retry behavior by checking duplicate-wait anomalies" + ) + parser.add_argument("--before", required=True) + parser.add_argument("--after", required=True) + parser.add_argument("--max-double-sleep-count", type=int, required=True) + return parser.parse_args() + + +def main() -> None: + args = _parse_args() + before_report = _load_report(args.before) + after_report = _load_report(args.after) + + if _is_synthetic_or_blocked(before_report) or _is_synthetic_or_blocked( + after_report + ): + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "synthetic_or_blocked_report_detected", + "before": args.before, + "after": args.after, + "max_double_sleep_count": int(args.max_double_sleep_count), + } + ) + raise SystemExit(0) + + before_samples = _samples(before_report) + after_samples = _samples(after_report) + nominal_sleep_ms = _derive_nominal_retry_sleep_ms(before_samples, after_samples) + + if nominal_sleep_ms is None: + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "insufficient_retry_sleep_signal", + "before": args.before, + "after": args.after, + "max_double_sleep_count": int(args.max_double_sleep_count), + } + ) + raise SystemExit(0) + + anomalies = _find_double_sleep_anomalies(after_samples, nominal_sleep_ms) + observed = len(anomalies) + threshold = int(args.max_double_sleep_count) + + base_payload: Dict[str, Any] = { + "before": args.before, + "after": args.after, + "nominal_retry_sleep_ms": round(nominal_sleep_ms, 6), + "observed_double_sleep_count": observed, + "max_double_sleep_count": threshold, + "anomaly_examples": anomalies[:10], + } + + if observed <= threshold: + _emit({"status": "PASS", **base_payload}) + raise SystemExit(0) + + _emit({"status": "FAIL", **base_payload}) + raise SystemExit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/perf/assert_throughput.py b/scripts/perf/assert_throughput.py new file mode 100644 index 000000000..84e0cede7 --- /dev/null +++ b/scripts/perf/assert_throughput.py @@ -0,0 +1,307 @@ +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + + +def _safe_float(value: Any) -> Optional[float]: + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + try: + return float(value) + except ValueError: + return None + return None + + +def _safe_int(value: Any) -> Optional[int]: + if isinstance(value, bool): + return None + if isinstance(value, int): + return int(value) + if isinstance(value, float): + if value.is_integer(): + return int(value) + return None + if isinstance(value, str): + try: + parsed = float(value) + except ValueError: + return None + if parsed.is_integer(): + return int(parsed) + return None + + +def _load_report(path: str) -> Dict[str, Any]: + report_path = Path(path) + payload = json.loads(report_path.read_text(encoding="utf-8")) + if not isinstance(payload, dict): + raise ValueError(f"invalid report JSON object: {path}") + return payload + + +def _is_synthetic_or_blocked(report: Dict[str, Any]) -> bool: + meta = report.get("meta") + if not isinstance(meta, dict): + return False + if bool(meta.get("synthetic")): + return True + blocker = meta.get("blocker") + return isinstance(blocker, str) and blocker.strip() != "" + + +def _has_valid_request_or_sample_signal(report: Dict[str, Any]) -> bool: + summary = report.get("summary") + request_count: Optional[int] = None + if isinstance(summary, dict): + request_count = _safe_int(summary.get("request_count")) + + if request_count is not None and request_count > 0: + return True + + samples = report.get("samples") + if isinstance(samples, list) and any( + isinstance(sample, dict) for sample in samples + ): + return True + + return False + + +def _summary_metric(report: Dict[str, Any], field: str) -> Optional[float]: + summary = report.get("summary") + if not isinstance(summary, dict): + return None + return _safe_float(summary.get(field)) + + +def _duration_seconds(report: Dict[str, Any]) -> Optional[float]: + meta = report.get("meta") + if not isinstance(meta, dict): + return None + duration = _safe_float(meta.get("duration")) + if duration is None or duration <= 0: + return None + return duration + + +def _samples_count(report: Dict[str, Any]) -> Optional[int]: + payload = report.get("samples") + if not isinstance(payload, list): + return None + return len([sample for sample in payload if isinstance(sample, dict)]) + + +def _summary_request_count(report: Dict[str, Any]) -> Optional[int]: + summary = report.get("summary") + if not isinstance(summary, dict): + return None + request_count = _safe_int(summary.get("request_count")) + if request_count is None or request_count < 0: + return None + return request_count + + +def _total_tokens_from_samples(report: Dict[str, Any]) -> Optional[int]: + payload = report.get("samples") + if not isinstance(payload, list): + return None + total = 0 + seen = False + for sample in payload: + if not isinstance(sample, dict): + continue + token_value = _safe_int(sample.get("total_tokens")) + if token_value is None or token_value < 0: + continue + total += token_value + seen = True + if not seen: + return None + return total + + +def _reqps(report: Dict[str, Any]) -> Tuple[Optional[float], str]: + summary_value = _summary_metric(report, "reqps") + if summary_value is not None: + return summary_value, "summary" + + duration = _duration_seconds(report) + request_count = _summary_request_count(report) + if duration is not None and request_count is not None: + return float(request_count) / duration, "summary+meta" + + sample_count = _samples_count(report) + if duration is not None and sample_count is not None: + return float(sample_count) / duration, "samples+meta" + + return None, "missing" + + +def _tokensps(report: Dict[str, Any]) -> Tuple[Optional[float], str]: + summary_value = _summary_metric(report, "tokensps") + if summary_value is not None: + return summary_value, "summary" + + duration = _duration_seconds(report) + total_tokens = _total_tokens_from_samples(report) + if duration is not None and total_tokens is not None: + return float(total_tokens) / duration, "samples+meta" + + return None, "missing" + + +def _emit(payload: Dict[str, Any]) -> None: + print(json.dumps(payload, ensure_ascii=False)) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Assert throughput deltas for reqps and tokensps against baseline" + ) + parser.add_argument("--before", required=True) + parser.add_argument("--after", required=True) + parser.add_argument("--min-delta", type=float, required=True) + return parser.parse_args() + + +def main() -> None: + args = _parse_args() + before_report = _load_report(args.before) + after_report = _load_report(args.after) + + before_reqps, before_reqps_source = _reqps(before_report) + after_reqps, after_reqps_source = _reqps(after_report) + before_tokensps, before_tokensps_source = _tokensps(before_report) + after_tokensps, after_tokensps_source = _tokensps(after_report) + + if _is_synthetic_or_blocked(before_report) or _is_synthetic_or_blocked( + after_report + ): + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "synthetic_or_blocked_report_detected", + "before": args.before, + "after": args.after, + "before_reqps": before_reqps, + "after_reqps": after_reqps, + "before_reqps_source": before_reqps_source, + "after_reqps_source": after_reqps_source, + "before_tokensps": before_tokensps, + "after_tokensps": after_tokensps, + "before_tokensps_source": before_tokensps_source, + "after_tokensps_source": after_tokensps_source, + "required_min_delta_percent": float(args.min_delta), + } + ) + raise SystemExit(0) + + before_has_signal = _has_valid_request_or_sample_signal(before_report) + after_has_signal = _has_valid_request_or_sample_signal(after_report) + + if not before_has_signal or not after_has_signal: + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "no_requests_or_samples", + "before": args.before, + "after": args.after, + "before_has_signal": before_has_signal, + "after_has_signal": after_has_signal, + "before_reqps": before_reqps, + "after_reqps": after_reqps, + "before_reqps_source": before_reqps_source, + "after_reqps_source": after_reqps_source, + "before_tokensps": before_tokensps, + "after_tokensps": after_tokensps, + "before_tokensps_source": before_tokensps_source, + "after_tokensps_source": after_tokensps_source, + "required_min_delta_percent": float(args.min_delta), + } + ) + raise SystemExit(0) + + if ( + before_reqps is None + or after_reqps is None + or before_tokensps is None + or after_tokensps is None + ): + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "missing_throughput_signal", + "before": args.before, + "after": args.after, + "before_reqps": before_reqps, + "after_reqps": after_reqps, + "before_reqps_source": before_reqps_source, + "after_reqps_source": after_reqps_source, + "before_tokensps": before_tokensps, + "after_tokensps": after_tokensps, + "before_tokensps_source": before_tokensps_source, + "after_tokensps_source": after_tokensps_source, + "required_min_delta_percent": float(args.min_delta), + } + ) + raise SystemExit(0) + + if before_reqps <= 0 or before_tokensps <= 0: + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "non_positive_baseline_throughput", + "before": args.before, + "after": args.after, + "before_reqps": before_reqps, + "after_reqps": after_reqps, + "before_reqps_source": before_reqps_source, + "after_reqps_source": after_reqps_source, + "before_tokensps": before_tokensps, + "after_tokensps": after_tokensps, + "before_tokensps_source": before_tokensps_source, + "after_tokensps_source": after_tokensps_source, + "required_min_delta_percent": float(args.min_delta), + } + ) + raise SystemExit(0) + + reqps_delta = ((after_reqps - before_reqps) / before_reqps) * 100.0 + tokensps_delta = ((after_tokensps - before_tokensps) / before_tokensps) * 100.0 + + reqps_ok = reqps_delta >= float(args.min_delta) + tokensps_ok = tokensps_delta >= float(args.min_delta) + + payload: Dict[str, Any] = { + "before": args.before, + "after": args.after, + "before_reqps": round(before_reqps, 6), + "after_reqps": round(after_reqps, 6), + "before_reqps_source": before_reqps_source, + "after_reqps_source": after_reqps_source, + "reqps_delta_percent": round(reqps_delta, 3), + "reqps_gate": "PASS" if reqps_ok else "FAIL", + "before_tokensps": round(before_tokensps, 6), + "after_tokensps": round(after_tokensps, 6), + "before_tokensps_source": before_tokensps_source, + "after_tokensps_source": after_tokensps_source, + "tokensps_delta_percent": round(tokensps_delta, 3), + "tokensps_gate": "PASS" if tokensps_ok else "FAIL", + "required_min_delta_percent": float(args.min_delta), + } + + if reqps_ok and tokensps_ok: + _emit({"status": "PASS", **payload}) + raise SystemExit(0) + + _emit({"status": "FAIL", **payload}) + raise SystemExit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/perf/assert_transport.py b/scripts/perf/assert_transport.py new file mode 100644 index 000000000..9ae3dceda --- /dev/null +++ b/scripts/perf/assert_transport.py @@ -0,0 +1,391 @@ +from __future__ import annotations + +import argparse +import json +import math +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +def _safe_float(value: Any) -> Optional[float]: + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + try: + return float(value) + except ValueError: + return None + return None + + +def _safe_int(value: Any) -> Optional[int]: + if isinstance(value, bool): + return None + if isinstance(value, int): + return int(value) + if isinstance(value, float): + if value.is_integer(): + return int(value) + return None + if isinstance(value, str): + try: + parsed = float(value) + except ValueError: + return None + if parsed.is_integer(): + return int(parsed) + return None + + +def _percentile(values: List[float], percentile: float) -> float: + if not values: + return 0.0 + ordered = sorted(values) + if len(ordered) == 1: + return float(ordered[0]) + rank = (len(ordered) - 1) * percentile + low = math.floor(rank) + high = math.ceil(rank) + if low == high: + return float(ordered[low]) + weight = rank - low + return float(ordered[low] * (1.0 - weight) + ordered[high] * weight) + + +def _load_report(path: str) -> Dict[str, Any]: + report_path = Path(path) + payload = json.loads(report_path.read_text(encoding="utf-8")) + if not isinstance(payload, dict): + raise ValueError(f"invalid report JSON object: {path}") + return payload + + +def _is_synthetic_or_blocked(report: Dict[str, Any]) -> bool: + meta = report.get("meta") + if not isinstance(meta, dict): + return False + if bool(meta.get("synthetic")): + return True + blocker = meta.get("blocker") + return isinstance(blocker, str) and blocker.strip() != "" + + +def _has_valid_request_or_sample_signal(report: Dict[str, Any]) -> bool: + summary = report.get("summary") + request_count: Optional[int] = None + if isinstance(summary, dict): + request_count = _safe_int(summary.get("request_count")) + + if request_count is not None and request_count > 0: + return True + + samples = report.get("samples") + if isinstance(samples, list) and any( + isinstance(sample, dict) for sample in samples + ): + return True + + return False + + +def _ttfb_p95_from_summary(report: Dict[str, Any]) -> Optional[float]: + summary = report.get("summary") + if not isinstance(summary, dict): + return None + ttfb_block = summary.get("ttfb_ms") + if not isinstance(ttfb_block, dict): + return None + return _safe_float(ttfb_block.get("p95")) + + +def _ttfb_p95_from_buckets(report: Dict[str, Any]) -> Optional[float]: + buckets = report.get("buckets") + if not isinstance(buckets, dict): + return None + values: List[float] = [] + for bucket in buckets.values(): + if not isinstance(bucket, dict): + continue + ttfb_block = bucket.get("ttfb_ms") + if not isinstance(ttfb_block, dict): + continue + value = _safe_float(ttfb_block.get("p95")) + if value is not None: + values.append(value) + if not values: + return None + return max(values) + + +def _ttfb_p95_from_samples(report: Dict[str, Any]) -> Optional[float]: + payload = report.get("samples") + if not isinstance(payload, list): + return None + values: List[float] = [] + for sample in payload: + if not isinstance(sample, dict): + continue + value = _safe_float(sample.get("ttfb_ms")) + if value is not None: + values.append(value) + if not values: + return None + return _percentile(values, 0.95) + + +def _ttfb_p95(report: Dict[str, Any]) -> Tuple[Optional[float], str]: + for source, getter in ( + ("summary", _ttfb_p95_from_summary), + ("buckets", _ttfb_p95_from_buckets), + ("samples", _ttfb_p95_from_samples), + ): + value = getter(report) + if value is not None: + return value, source + return None, "missing" + + +def _parse_status_code(value: Any) -> Optional[int]: + code = _safe_int(value) + if code is None: + return None + if code < 100 or code > 599: + return None + return code + + +def _status_counts_from_summary(report: Dict[str, Any]) -> Optional[Dict[int, int]]: + summary = report.get("summary") + if not isinstance(summary, dict): + return None + buckets = summary.get("status_buckets") + if not isinstance(buckets, dict): + return None + + counts: Dict[int, int] = {} + for raw_status, raw_count in buckets.items(): + status = _parse_status_code(raw_status) + count = _safe_int(raw_count) + if status is None or count is None or count <= 0: + continue + counts[status] = counts.get(status, 0) + count + return counts if counts else None + + +def _status_counts_from_buckets(report: Dict[str, Any]) -> Optional[Dict[int, int]]: + buckets = report.get("buckets") + if not isinstance(buckets, dict): + return None + + counts: Dict[int, int] = {} + for key, bucket in buckets.items(): + if not isinstance(key, str) or not isinstance(bucket, dict): + continue + parts = key.split("|") + if not parts: + continue + status = _parse_status_code(parts[-1]) + count = _safe_int(bucket.get("count")) + if status is None or count is None or count <= 0: + continue + counts[status] = counts.get(status, 0) + count + return counts if counts else None + + +def _status_counts_from_samples(report: Dict[str, Any]) -> Optional[Dict[int, int]]: + payload = report.get("samples") + if not isinstance(payload, list): + return None + + counts: Dict[int, int] = {} + for sample in payload: + if not isinstance(sample, dict): + continue + status = _parse_status_code(sample.get("status")) + if status is None: + continue + counts[status] = counts.get(status, 0) + 1 + return counts if counts else None + + +def _status_counts(report: Dict[str, Any]) -> Tuple[Optional[Dict[int, int]], str]: + for source, getter in ( + ("summary", _status_counts_from_summary), + ("buckets", _status_counts_from_buckets), + ("samples", _status_counts_from_samples), + ): + counts = getter(report) + if counts is not None: + return counts, source + return None, "missing" + + +def _connection_error_rate(counts: Dict[int, int]) -> Tuple[int, int, float]: + total = sum(counts.values()) + if total <= 0: + return 0, 0, 0.0 + errors = sum(count for status, count in counts.items() if 400 <= status <= 599) + rate = (float(errors) / float(total)) * 100.0 + return total, errors, rate + + +def _emit(payload: Dict[str, Any]) -> None: + print(json.dumps(payload, ensure_ascii=False)) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Assert transport improvements and connection error-rate stability" + ) + parser.add_argument("--before", required=True) + parser.add_argument("--after", required=True) + parser.add_argument("--ttfb-p95-improve", type=float, required=True) + return parser.parse_args() + + +def main() -> None: + args = _parse_args() + before_report = _load_report(args.before) + after_report = _load_report(args.after) + + before_ttfb, before_ttfb_source = _ttfb_p95(before_report) + after_ttfb, after_ttfb_source = _ttfb_p95(after_report) + + if _is_synthetic_or_blocked(before_report) or _is_synthetic_or_blocked( + after_report + ): + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "synthetic_or_blocked_report_detected", + "before": args.before, + "after": args.after, + "before_ttfb_p95": before_ttfb, + "after_ttfb_p95": after_ttfb, + "required_ttfb_p95_improve_percent": float(args.ttfb_p95_improve), + } + ) + raise SystemExit(0) + + before_has_signal = _has_valid_request_or_sample_signal(before_report) + after_has_signal = _has_valid_request_or_sample_signal(after_report) + + if not before_has_signal or not after_has_signal: + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "no_requests_or_samples", + "before": args.before, + "after": args.after, + "before_has_signal": before_has_signal, + "after_has_signal": after_has_signal, + "before_ttfb_p95": before_ttfb, + "after_ttfb_p95": after_ttfb, + "before_ttfb_source": before_ttfb_source, + "after_ttfb_source": after_ttfb_source, + "required_ttfb_p95_improve_percent": float(args.ttfb_p95_improve), + } + ) + raise SystemExit(0) + + if before_ttfb is None or after_ttfb is None: + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "missing_ttfb_p95_signal", + "before": args.before, + "after": args.after, + "before_ttfb_p95": before_ttfb, + "after_ttfb_p95": after_ttfb, + "before_ttfb_source": before_ttfb_source, + "after_ttfb_source": after_ttfb_source, + "required_ttfb_p95_improve_percent": float(args.ttfb_p95_improve), + } + ) + raise SystemExit(0) + + if before_ttfb <= 0: + _emit( + { + "status": "SKIP(BLOCKED)", + "reason": "non_positive_baseline_ttfb_p95", + "before": args.before, + "after": args.after, + "before_ttfb_p95": before_ttfb, + "after_ttfb_p95": after_ttfb, + "required_ttfb_p95_improve_percent": float(args.ttfb_p95_improve), + } + ) + raise SystemExit(0) + + ttfb_improve = ((before_ttfb - after_ttfb) / before_ttfb) * 100.0 + ttfb_ok = ttfb_improve >= float(args.ttfb_p95_improve) + + before_status_counts, before_status_source = _status_counts(before_report) + after_status_counts, after_status_source = _status_counts(after_report) + + connection_gate: Dict[str, Any] = { + "applied": False, + "status": "NOT_APPLICABLE", + "reason": "insufficient_connection_error_signal", + } + + connection_ok = True + + if before_status_counts is not None and after_status_counts is not None: + before_total, before_errors, before_rate = _connection_error_rate( + before_status_counts + ) + after_total, after_errors, after_rate = _connection_error_rate( + after_status_counts + ) + + signal_exists = before_errors > 0 or after_errors > 0 + if signal_exists: + connection_ok = after_rate <= before_rate + connection_gate = { + "applied": True, + "status": "PASS" if connection_ok else "FAIL", + "before_total": before_total, + "before_error_count": before_errors, + "before_error_rate_percent": round(before_rate, 6), + "before_status_source": before_status_source, + "after_total": after_total, + "after_error_count": after_errors, + "after_error_rate_percent": round(after_rate, 6), + "after_status_source": after_status_source, + } + else: + connection_gate = { + "applied": False, + "status": "NOT_APPLICABLE", + "reason": "no_4xx_or_5xx_signal", + "before_total": before_total, + "before_status_source": before_status_source, + "after_total": after_total, + "after_status_source": after_status_source, + } + + base_payload: Dict[str, Any] = { + "before": args.before, + "after": args.after, + "before_ttfb_p95": round(before_ttfb, 6), + "after_ttfb_p95": round(after_ttfb, 6), + "before_ttfb_source": before_ttfb_source, + "after_ttfb_source": after_ttfb_source, + "ttfb_p95_improve_percent": round(ttfb_improve, 3), + "required_ttfb_p95_improve_percent": float(args.ttfb_p95_improve), + "ttfb_gate": "PASS" if ttfb_ok else "FAIL", + "connection_error_rate_gate": connection_gate, + } + + if ttfb_ok and connection_ok: + _emit({"status": "PASS", **base_payload}) + raise SystemExit(0) + + _emit({"status": "FAIL", **base_payload}) + raise SystemExit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/perf/bench.py b/scripts/perf/bench.py new file mode 100644 index 000000000..3a0a46112 --- /dev/null +++ b/scripts/perf/bench.py @@ -0,0 +1,556 @@ +from __future__ import annotations + +import argparse +import asyncio +import json +import math +import os +import random +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import httpx + + +METRIC_FIELDS = ( + "ttfb_ms", + "first_token_ms", + "full_latency_ms", + "retry_count", + "converter_cpu_ms", +) + + +def _percentile(values: List[float], percentile: float) -> float: + if not values: + return 0.0 + ordered = sorted(values) + if len(ordered) == 1: + return float(ordered[0]) + rank = (len(ordered) - 1) * percentile + low = math.floor(rank) + high = math.ceil(rank) + if low == high: + return float(ordered[low]) + weight = rank - low + return float(ordered[low] * (1.0 - weight) + ordered[high] * weight) + + +def _metric_quantiles(values: List[float]) -> Dict[str, float]: + if not values: + return {"p50": 0.0, "p95": 0.0, "p99": 0.0, "avg": 0.0} + avg = sum(values) / len(values) + return { + "p50": round(_percentile(values, 0.50), 3), + "p95": round(_percentile(values, 0.95), 3), + "p99": round(_percentile(values, 0.99), 3), + "avg": round(avg, 3), + } + + +def _parse_float_header(headers: httpx.Headers, key: str) -> Optional[float]: + raw = headers.get(key) + if raw is None: + return None + try: + return float(raw) + except (TypeError, ValueError): + return None + + +def _parse_int_header(headers: httpx.Headers, key: str) -> Optional[int]: + raw = headers.get(key) + if raw is None: + return None + try: + return int(raw) + except (TypeError, ValueError): + return None + + +def _safe_json_loads(raw: bytes) -> Optional[Dict[str, Any]]: + if not raw: + return None + try: + parsed = json.loads(raw.decode("utf-8")) + except Exception: + return None + return parsed if isinstance(parsed, dict) else None + + +def _extract_total_tokens(route: str, payload: Optional[Dict[str, Any]]) -> int: + if not payload: + return 0 + + if route == "direct-gemini": + usage = payload.get("usageMetadata") + if not isinstance(usage, dict): + nested = payload.get("response") + usage = nested.get("usageMetadata") if isinstance(nested, dict) else None + if isinstance(usage, dict): + for key in ("totalTokenCount", "candidatesTokenCount", "promptTokenCount"): + value = usage.get(key) + if isinstance(value, int): + return value + return 0 + + usage = payload.get("usage") + if isinstance(usage, dict): + input_tokens = usage.get("input_tokens") or usage.get("prompt_tokens") or 0 + output_tokens = ( + usage.get("output_tokens") or usage.get("completion_tokens") or 0 + ) + try: + return int(input_tokens) + int(output_tokens) + except (TypeError, ValueError): + return 0 + return 0 + + +def _build_request_spec( + args: argparse.Namespace, +) -> Tuple[str, Dict[str, str], Dict[str, Any]]: + prompt = "Give a concise answer in two bullet points about reliable latency benchmarking." + + if args.route == "direct-gemini": + url = f"{args.base_url}/v1/models/{args.model}:generateContent" + headers = { + "Content-Type": "application/json", + "x-goog-api-key": args.api_key, + } + body = { + "contents": [ + { + "role": "user", + "parts": [{"text": prompt}], + } + ], + "generationConfig": { + "temperature": args.temperature, + "topP": args.top_p, + "maxOutputTokens": args.max_output_tokens, + }, + "tools": [], + } + return url, headers, body + + url = f"{args.base_url}/v1/messages" + headers = { + "Content-Type": "application/json", + "x-api-key": args.api_key, + "anthropic-version": "2023-06-01", + } + body = { + "model": args.model, + "max_tokens": args.max_output_tokens, + "temperature": args.temperature, + "top_p": args.top_p, + "messages": [{"role": "user", "content": prompt}], + "tools": [], + "stream": False, + } + return url, headers, body + + +def _build_sample( + *, + route: str, + model: str, + status: int, + response_headers: httpx.Headers, + response_body: bytes, + local_started_wall: float, + local_ttfb_ms: float, + local_full_ms: float, +) -> Dict[str, Any]: + t_req_in = _parse_float_header(response_headers, "x-gcli-obs-t-req-in") + if t_req_in is None: + t_req_in = local_started_wall + + t_upstream_send = _parse_float_header( + response_headers, "x-gcli-obs-t-upstream-send" + ) + if t_upstream_send is None: + t_upstream_send = t_req_in + + t_first_byte = _parse_float_header(response_headers, "x-gcli-obs-t-first-byte") + if t_first_byte is None: + t_first_byte = t_req_in + local_ttfb_ms / 1000.0 + + t_first_token = _parse_float_header(response_headers, "x-gcli-obs-t-first-token") + if t_first_token is None: + t_first_token = t_first_byte + + t_done = _parse_float_header(response_headers, "x-gcli-obs-t-done") + if t_done is None: + t_done = t_req_in + local_full_ms / 1000.0 + + retry_count = _parse_int_header(response_headers, "x-gcli-obs-retry-count") + if retry_count is None: + retry_count = 0 + + retry_sleep_ms = _parse_float_header(response_headers, "x-gcli-obs-retry-sleep-ms") + if retry_sleep_ms is None: + retry_sleep_ms = 0.0 + + converter_cpu_ms = _parse_float_header( + response_headers, "x-gcli-obs-converter-cpu-ms" + ) + if converter_cpu_ms is None: + converter_cpu_ms = 0.0 + + payload = _safe_json_loads(response_body) + total_tokens = _extract_total_tokens(route, payload) + + return { + "route": route, + "model": model, + "stream": False, + "status": int(status), + "t_req_in": round(t_req_in, 6), + "t_upstream_send": round(t_upstream_send, 6), + "t_first_byte": round(t_first_byte, 6), + "t_first_token": round(t_first_token, 6), + "t_done": round(t_done, 6), + "retry_count": int(retry_count), + "retry_sleep_ms": round(retry_sleep_ms, 3), + "converter_cpu_ms": round(converter_cpu_ms, 3), + "ttfb_ms": round(max(0.0, (t_first_byte - t_req_in) * 1000.0), 3), + "first_token_ms": round(max(0.0, (t_first_token - t_req_in) * 1000.0), 3), + "full_latency_ms": round(max(0.0, (t_done - t_req_in) * 1000.0), 3), + "total_tokens": int(total_tokens), + } + + +async def _probe_connectivity( + client: httpx.AsyncClient, + url: str, + headers: Dict[str, str], + body: Dict[str, Any], +) -> None: + try: + response = await client.post(url, headers=headers, json=body) + await response.aclose() + except Exception as exc: # pragma: no cover - runtime/network dependent + raise RuntimeError( + f"upstream unreachable: {type(exc).__name__}: {exc}" + ) from exc + + +async def _run_one_request( + client: httpx.AsyncClient, + route: str, + model: str, + url: str, + headers: Dict[str, str], + body: Dict[str, Any], +) -> Dict[str, Any]: + local_started_wall = time.time() + local_started_perf = time.perf_counter() + first_byte_perf: Optional[float] = None + chunks: List[bytes] = [] + + async with client.stream("POST", url, headers=headers, json=body) as response: + async for chunk in response.aiter_bytes(): + if first_byte_perf is None: + first_byte_perf = time.perf_counter() + if chunk: + chunks.append(chunk) + + local_done_perf = time.perf_counter() + if first_byte_perf is None: + first_byte_perf = local_done_perf + + local_ttfb_ms = max(0.0, (first_byte_perf - local_started_perf) * 1000.0) + local_full_ms = max(0.0, (local_done_perf - local_started_perf) * 1000.0) + + return _build_sample( + route=route, + model=model, + status=response.status_code, + response_headers=response.headers, + response_body=b"".join(chunks), + local_started_wall=local_started_wall, + local_ttfb_ms=local_ttfb_ms, + local_full_ms=local_full_ms, + ) + + +async def _run_live_benchmark( + args: argparse.Namespace, +) -> Tuple[List[Dict[str, Any]], float]: + url, headers, body = _build_request_spec(args) + timeout = httpx.Timeout(connect=5.0, read=60.0, write=30.0, pool=30.0) + samples: List[Dict[str, Any]] = [] + + started_at = time.perf_counter() + end_at = started_at + args.duration + + async with httpx.AsyncClient(timeout=timeout) as client: + await _probe_connectivity(client, url, headers, body) + + async def worker() -> None: + while time.perf_counter() < end_at: + sample = await _run_one_request( + client=client, + route=args.route, + model=args.model, + url=url, + headers=headers, + body=body, + ) + samples.append(sample) + + workers = [ + asyncio.create_task(worker()) for _ in range(max(1, args.concurrency)) + ] + await asyncio.gather(*workers) + + elapsed = max(0.001, time.perf_counter() - started_at) + return samples, elapsed + + +def _build_synthetic_samples( + args: argparse.Namespace, +) -> Tuple[List[Dict[str, Any]], float]: + seed_input = f"{args.route}|{args.model}|{args.duration}|{args.concurrency}" + seed = sum(ord(ch) for ch in seed_input) + rng = random.Random(seed) + + count = max(80, min(400, max(1, args.concurrency) * 12)) + base = time.time() + + if args.route == "direct-gemini": + ttfb_center, ttfb_jitter = 140.0, 25.0 + token_center, token_jitter = 150.0, 28.0 + full_center, full_jitter = 820.0, 120.0 + converter_center = 0.0 + else: + ttfb_center, ttfb_jitter = 200.0, 35.0 + token_center, token_jitter = 235.0, 40.0 + full_center, full_jitter = 980.0, 160.0 + converter_center = 26.0 + + samples: List[Dict[str, Any]] = [] + for idx in range(count): + t_req_in = base + idx * 0.013 + t_upstream_send = t_req_in + max(0.0, rng.gauss(0.008, 0.002)) + ttfb_ms = max(30.0, rng.gauss(ttfb_center, ttfb_jitter)) + first_token_ms = max(ttfb_ms, rng.gauss(token_center, token_jitter)) + full_latency_ms = max( + first_token_ms + 20.0, rng.gauss(full_center, full_jitter) + ) + + retry_count = 0 + retry_sleep_ms = 0.0 + status = 200 + if rng.random() < 0.08: + retry_count = 1 if rng.random() < 0.9 else 2 + retry_sleep_ms = retry_count * max(250.0, rng.gauss(900.0, 180.0)) + if rng.random() < 0.02: + status = 429 + + converter_cpu_ms = max(0.0, rng.gauss(converter_center, 7.0)) + if args.route == "direct-gemini": + converter_cpu_ms = 0.0 + + tokens = max(64, int(rng.gauss(720.0, 110.0))) + t_first_byte = t_req_in + ttfb_ms / 1000.0 + t_first_token = t_req_in + first_token_ms / 1000.0 + t_done = t_req_in + full_latency_ms / 1000.0 + + samples.append( + { + "route": args.route, + "model": args.model, + "stream": False, + "status": status, + "t_req_in": round(t_req_in, 6), + "t_upstream_send": round(t_upstream_send, 6), + "t_first_byte": round(t_first_byte, 6), + "t_first_token": round(t_first_token, 6), + "t_done": round(t_done, 6), + "retry_count": int(retry_count), + "retry_sleep_ms": round(retry_sleep_ms, 3), + "converter_cpu_ms": round(converter_cpu_ms, 3), + "ttfb_ms": round(ttfb_ms, 3), + "first_token_ms": round(first_token_ms, 3), + "full_latency_ms": round(full_latency_ms, 3), + "total_tokens": int(tokens), + } + ) + + synthetic_elapsed = max(1.0, min(float(args.duration), 30.0)) + return samples, synthetic_elapsed + + +def _build_bucket_summary(samples: List[Dict[str, Any]]) -> Dict[str, Any]: + grouped: Dict[str, List[Dict[str, Any]]] = {} + for sample in samples: + key = ( + f"{sample['route']}|{sample['model']}|" + f"{str(sample['stream']).lower()}|{sample['status']}" + ) + grouped.setdefault(key, []).append(sample) + + summary: Dict[str, Any] = {} + for key, group in grouped.items(): + bucket = { + "count": len(group), + "ttfb_ms": _metric_quantiles([float(s["ttfb_ms"]) for s in group]), + "first_token_ms": _metric_quantiles( + [float(s["first_token_ms"]) for s in group] + ), + "full_latency_ms": _metric_quantiles( + [float(s["full_latency_ms"]) for s in group] + ), + "retry_count": _metric_quantiles([float(s["retry_count"]) for s in group]), + "converter_cpu_ms": _metric_quantiles( + [float(s["converter_cpu_ms"]) for s in group] + ), + } + summary[key] = bucket + return summary + + +def _build_report( + args: argparse.Namespace, + samples: List[Dict[str, Any]], + elapsed: float, + synthetic: bool, + blocker: Optional[str], +) -> Dict[str, Any]: + elapsed = max(0.001, elapsed) + ttfb_values = [float(s["ttfb_ms"]) for s in samples] + first_token_values = [float(s["first_token_ms"]) for s in samples] + full_values = [float(s["full_latency_ms"]) for s in samples] + retry_values = [float(s["retry_count"]) for s in samples] + converter_values = [float(s["converter_cpu_ms"]) for s in samples] + total_tokens = sum(int(s.get("total_tokens", 0)) for s in samples) + success_count = sum(1 for s in samples if int(s.get("status", 0)) == 200) + + status_buckets: Dict[str, int] = {} + for sample in samples: + status = str(sample.get("status", "unknown")) + status_buckets[status] = status_buckets.get(status, 0) + 1 + + report = { + "meta": { + "route": args.route, + "model": args.model, + "stream": False, + "duration": args.duration, + "concurrency": args.concurrency, + "base_url": args.base_url, + "params": { + "temperature": args.temperature, + "topP": args.top_p, + "maxOutputTokens": args.max_output_tokens, + "tools": [], + }, + "synthetic": synthetic, + "generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + }, + "summary": { + "request_count": len(samples), + "success_count": success_count, + "status_buckets": status_buckets, + "ttfb_ms": _metric_quantiles(ttfb_values), + "first_token_ms": _metric_quantiles(first_token_values), + "full_latency_ms": _metric_quantiles(full_values), + "retry_count": _metric_quantiles(retry_values), + "converter_cpu_ms": _metric_quantiles(converter_values), + "reqps": round(len(samples) / elapsed, 3), + "tokensps": round(total_tokens / elapsed, 3), + }, + "buckets": _build_bucket_summary(samples), + "samples": samples, + } + + if blocker: + report["meta"]["blocker"] = blocker + + return report + + +def _default_api_key() -> str: + for env_key in ("BENCH_API_KEY", "API_PASSWORD", "PASSWORD"): + value = os.getenv(env_key) + if value: + return value + return "pwd" + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Generate baseline perf report") + parser.add_argument( + "--route", + required=True, + choices=["direct-gemini", "claude-compat"], + ) + parser.add_argument("--model", required=True) + parser.add_argument("--duration", type=int, default=60) + parser.add_argument("--concurrency", type=int, default=8) + parser.add_argument("--out", required=True) + parser.add_argument( + "--base-url", default=os.getenv("BENCH_BASE_URL", "http://127.0.0.1:7861") + ) + parser.add_argument("--api-key", default=_default_api_key()) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--top-p", type=float, default=0.95) + parser.add_argument("--max-output-tokens", type=int, default=256) + parser.add_argument( + "--no-synthetic-on-failure", + action="store_true", + help="Fail when endpoint is unreachable instead of writing deterministic synthetic baseline.", + ) + return parser.parse_args() + + +def main() -> None: + args = _parse_args() + blocker: Optional[str] = None + + try: + samples, elapsed = asyncio.run(_run_live_benchmark(args)) + synthetic = False + except Exception as exc: # pragma: no cover - runtime/network dependent + blocker = str(exc) + if args.no_synthetic_on_failure: + raise SystemExit(f"bench failed: {blocker}") from exc + samples, elapsed = _build_synthetic_samples(args) + synthetic = True + + report = _build_report( + args=args, + samples=samples, + elapsed=elapsed, + synthetic=synthetic, + blocker=blocker, + ) + + out_path = Path(args.out) + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text( + json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8" + ) + + stdout_report = { + "route": report["meta"]["route"], + "model": report["meta"]["model"], + "synthetic": report["meta"]["synthetic"], + "ttfb_ms": report["summary"]["ttfb_ms"], + "first_token_ms": report["summary"]["first_token_ms"], + "full_latency_ms": report["summary"]["full_latency_ms"], + "retry_count": report["summary"]["retry_count"], + "converter_cpu_ms": report["summary"]["converter_cpu_ms"], + "reqps": report["summary"]["reqps"], + "tokensps": report["summary"]["tokensps"], + "out": str(out_path), + } + print(json.dumps(stdout_report, ensure_ascii=False)) + + +if __name__ == "__main__": + main() diff --git a/scripts/perf/rollout_guard.py b/scripts/perf/rollout_guard.py new file mode 100644 index 000000000..342587063 --- /dev/null +++ b/scripts/perf/rollout_guard.py @@ -0,0 +1,676 @@ +from __future__ import annotations + +import argparse +import asyncio +import json +import sys +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +PROJECT_ROOT = Path(__file__).resolve().parents[2] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +from config import ( # noqa: E402 + get_rollback_trigger_latency_p95_ms, + get_rollback_trigger_quality_drop_pct, + get_rollback_trigger_throughput_drop_pct, + get_rollout_stage_percent, + reload_config, +) +from scripts.eval.assert_quality import ( # noqa: E402 + _extract_quality_score, + _is_synthetic_or_blocked as _quality_is_synthetic_or_blocked, + _load_report as _quality_load_report, +) +from scripts.perf.assert_latency import ( # noqa: E402 + _has_valid_request_or_sample_signal as _latency_has_valid_request_or_sample_signal, + _is_synthetic_or_blocked as _latency_is_synthetic_or_blocked, + _load_report as _load_perf_report_or_raise, + _metric_p95, +) +from scripts.perf.assert_throughput import ( # noqa: E402 + _has_valid_request_or_sample_signal as _throughput_has_valid_request_or_sample_signal, + _is_synthetic_or_blocked as _throughput_is_synthetic_or_blocked, + _reqps, + _tokensps, +) +from src.storage_adapter import get_storage_adapter # noqa: E402 + +STAGE_LADDER: Tuple[int, ...] = (5, 20, 50, 100) +LATENCY_POLICY_MODE_ABSOLUTE_P95_CAP = "absolute_p95_cap" +LATENCY_POLICY_MODE_RELATIVE_FULL_P95_IMPROVE = "relative_full_p95_improve" +LATENCY_POLICY_MODES: Tuple[str, ...] = ( + LATENCY_POLICY_MODE_ABSOLUTE_P95_CAP, + LATENCY_POLICY_MODE_RELATIVE_FULL_P95_IMPROVE, +) + + +def _emit(payload: Dict[str, Any]) -> None: + print(json.dumps(payload, ensure_ascii=False)) + + +def _non_negative_float(value: str) -> float: + parsed = float(value) + if parsed < 0: + raise argparse.ArgumentTypeError("value must be >= 0") + return parsed + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Rollout/rollback decision guard based on perf + quality artifacts" + ) + parser.add_argument("--before-perf", required=True) + parser.add_argument("--after-perf", required=True) + parser.add_argument("--baseline-quality", required=True) + parser.add_argument("--candidate-quality", required=True) + parser.add_argument("--rollout-stage-percent", type=int, choices=STAGE_LADDER) + parser.add_argument("--rollback-trigger-latency-p95-ms", type=_non_negative_float) + parser.add_argument( + "--latency-policy-mode", + choices=LATENCY_POLICY_MODES, + help="Latency rollback policy mode", + ) + parser.add_argument( + "--rollback-trigger-latency-p95-improve-pct", type=_non_negative_float + ) + parser.add_argument( + "--rollback-trigger-throughput-drop-pct", type=_non_negative_float + ) + parser.add_argument("--rollback-trigger-quality-drop-pct", type=_non_negative_float) + parser.add_argument( + "--apply", + action="store_true", + help="Persist target rollout stage + stage-mapped feature flags", + ) + return parser.parse_args() + + +def get_next_stage_percent(current_stage_percent: int) -> int: + if current_stage_percent not in STAGE_LADDER: + raise ValueError(f"unsupported stage percent: {current_stage_percent}") + + stage_index = STAGE_LADDER.index(current_stage_percent) + if stage_index >= len(STAGE_LADDER) - 1: + return STAGE_LADDER[-1] + return STAGE_LADDER[stage_index + 1] + + +def get_previous_stage_percent(current_stage_percent: int) -> int: + if current_stage_percent not in STAGE_LADDER: + raise ValueError(f"unsupported stage percent: {current_stage_percent}") + + stage_index = STAGE_LADDER.index(current_stage_percent) + if stage_index <= 0: + return STAGE_LADDER[0] + return STAGE_LADDER[stage_index - 1] + + +def stage_percent_to_feature_flags(stage_percent: int) -> Dict[str, bool]: + if stage_percent not in STAGE_LADDER: + raise ValueError(f"unsupported stage percent: {stage_percent}") + + return { + "ff_retry_policy_v2": stage_percent >= 5, + "ff_http2_pool_tuning": stage_percent >= 20, + "ff_converter_fast_path": stage_percent >= 50, + "ff_preview_credential_scheduler_v2": stage_percent >= 100, + } + + +def _load_perf_report(path: str) -> Tuple[Optional[Dict[str, Any]], Optional[str]]: + try: + payload = _load_perf_report_or_raise(path) + except FileNotFoundError: + return None, f"missing_file:{path}" + except Exception as exc: + return ( + None, + f"invalid_or_unreadable_perf_artifact:{path}:{exc.__class__.__name__}", + ) + return payload, None + + +def _latency_gate( + *, + before_report: Dict[str, Any], + after_report: Dict[str, Any], + latency_threshold_ms: float, + latency_policy_mode: str, + latency_improve_pct_threshold: float, +) -> Dict[str, Any]: + threshold_latency_p95_ms = float(latency_threshold_ms) + required_full_latency_p95_improve_percent = float(latency_improve_pct_threshold) + + if latency_policy_mode == LATENCY_POLICY_MODE_RELATIVE_FULL_P95_IMPROVE: + if _latency_is_synthetic_or_blocked( + before_report + ) or _latency_is_synthetic_or_blocked(after_report): + return { + "status": "BLOCKED", + "reason": "synthetic_or_blocked_report_detected", + "latency_policy_mode": latency_policy_mode, + "threshold_latency_p95_ms": threshold_latency_p95_ms, + "required_full_latency_p95_improve_percent": required_full_latency_p95_improve_percent, + } + + before_has_signal = _latency_has_valid_request_or_sample_signal(before_report) + after_has_signal = _latency_has_valid_request_or_sample_signal(after_report) + if not before_has_signal or not after_has_signal: + return { + "status": "BLOCKED", + "reason": "no_requests_or_samples", + "before_has_signal": before_has_signal, + "after_has_signal": after_has_signal, + "latency_policy_mode": latency_policy_mode, + "threshold_latency_p95_ms": threshold_latency_p95_ms, + "required_full_latency_p95_improve_percent": required_full_latency_p95_improve_percent, + } + + before_full_latency_p95, before_full_latency_source = _metric_p95( + before_report, "full_latency_ms" + ) + after_full_latency_p95, after_full_latency_source = _metric_p95( + after_report, "full_latency_ms" + ) + if before_full_latency_p95 is None or after_full_latency_p95 is None: + return { + "status": "BLOCKED", + "reason": "missing_latency_p95_signal", + "before_full_latency_source": before_full_latency_source, + "after_full_latency_source": after_full_latency_source, + "latency_policy_mode": latency_policy_mode, + "threshold_latency_p95_ms": threshold_latency_p95_ms, + "required_full_latency_p95_improve_percent": required_full_latency_p95_improve_percent, + } + + before_full_latency_p95_float = float(before_full_latency_p95) + after_full_latency_p95_float = float(after_full_latency_p95) + if before_full_latency_p95_float <= 0: + return { + "status": "BLOCKED", + "reason": "non_positive_baseline_latency_p95", + "before_full_latency_p95_ms": round(before_full_latency_p95_float, 6), + "after_full_latency_p95_ms": round(after_full_latency_p95_float, 6), + "latency_policy_mode": latency_policy_mode, + "threshold_latency_p95_ms": threshold_latency_p95_ms, + "required_full_latency_p95_improve_percent": required_full_latency_p95_improve_percent, + } + + full_latency_p95_improve_percent = ( + (before_full_latency_p95_float - after_full_latency_p95_float) + / before_full_latency_p95_float + ) * 100.0 + passed = ( + full_latency_p95_improve_percent + >= required_full_latency_p95_improve_percent + ) + return { + "status": "PASS" if passed else "FAIL", + "reason": "full_latency_p95_improve_within_threshold" + if passed + else "full_latency_p95_improve_below_threshold", + "before_full_latency_p95_ms": round(before_full_latency_p95_float, 6), + "after_full_latency_p95_ms": round(after_full_latency_p95_float, 6), + "before_full_latency_source": before_full_latency_source, + "after_full_latency_source": after_full_latency_source, + "full_latency_p95_improve_percent": round( + full_latency_p95_improve_percent, 6 + ), + "latency_policy_mode": latency_policy_mode, + "threshold_latency_p95_ms": threshold_latency_p95_ms, + "required_full_latency_p95_improve_percent": required_full_latency_p95_improve_percent, + } + + if _latency_is_synthetic_or_blocked(after_report): + return { + "status": "BLOCKED", + "reason": "synthetic_or_blocked_report_detected", + "latency_policy_mode": latency_policy_mode, + "threshold_latency_p95_ms": threshold_latency_p95_ms, + "required_full_latency_p95_improve_percent": required_full_latency_p95_improve_percent, + } + + if not _latency_has_valid_request_or_sample_signal(after_report): + return { + "status": "BLOCKED", + "reason": "no_requests_or_samples", + "latency_policy_mode": latency_policy_mode, + "threshold_latency_p95_ms": threshold_latency_p95_ms, + "required_full_latency_p95_improve_percent": required_full_latency_p95_improve_percent, + } + + after_full_latency_p95, after_full_latency_source = _metric_p95( + after_report, "full_latency_ms" + ) + if after_full_latency_p95 is None: + return { + "status": "BLOCKED", + "reason": "missing_latency_p95_signal", + "after_full_latency_source": after_full_latency_source, + "latency_policy_mode": latency_policy_mode, + "threshold_latency_p95_ms": threshold_latency_p95_ms, + "required_full_latency_p95_improve_percent": required_full_latency_p95_improve_percent, + } + + passed = float(after_full_latency_p95) <= threshold_latency_p95_ms + return { + "status": "PASS" if passed else "FAIL", + "reason": "latency_p95_within_threshold" + if passed + else "latency_p95_exceeds_threshold", + "after_full_latency_p95_ms": round(float(after_full_latency_p95), 6), + "after_full_latency_source": after_full_latency_source, + "latency_policy_mode": latency_policy_mode, + "threshold_latency_p95_ms": threshold_latency_p95_ms, + "required_full_latency_p95_improve_percent": required_full_latency_p95_improve_percent, + } + + +def _throughput_gate( + before_report: Dict[str, Any], + after_report: Dict[str, Any], + throughput_drop_pct_threshold: float, +) -> Dict[str, Any]: + if _throughput_is_synthetic_or_blocked( + before_report + ) or _throughput_is_synthetic_or_blocked(after_report): + return { + "status": "BLOCKED", + "reason": "synthetic_or_blocked_report_detected", + } + + before_has_signal = _throughput_has_valid_request_or_sample_signal(before_report) + after_has_signal = _throughput_has_valid_request_or_sample_signal(after_report) + if not before_has_signal or not after_has_signal: + return { + "status": "BLOCKED", + "reason": "no_requests_or_samples", + "before_has_signal": before_has_signal, + "after_has_signal": after_has_signal, + } + + before_reqps, before_reqps_source = _reqps(before_report) + after_reqps, after_reqps_source = _reqps(after_report) + before_tokensps, before_tokensps_source = _tokensps(before_report) + after_tokensps, after_tokensps_source = _tokensps(after_report) + + if ( + before_reqps is None + or after_reqps is None + or before_tokensps is None + or after_tokensps is None + ): + return { + "status": "BLOCKED", + "reason": "missing_throughput_signal", + "before_reqps": before_reqps, + "after_reqps": after_reqps, + "before_reqps_source": before_reqps_source, + "after_reqps_source": after_reqps_source, + "before_tokensps": before_tokensps, + "after_tokensps": after_tokensps, + "before_tokensps_source": before_tokensps_source, + "after_tokensps_source": after_tokensps_source, + } + + if float(before_reqps) <= 0 or float(before_tokensps) <= 0: + return { + "status": "BLOCKED", + "reason": "non_positive_baseline_throughput", + "before_reqps": round(float(before_reqps), 6), + "before_tokensps": round(float(before_tokensps), 6), + } + + reqps_drop_pct = ( + (float(before_reqps) - float(after_reqps)) / float(before_reqps) + ) * 100.0 + tokensps_drop_pct = ( + (float(before_tokensps) - float(after_tokensps)) / float(before_tokensps) + ) * 100.0 + threshold = float(throughput_drop_pct_threshold) + + reqps_passed = reqps_drop_pct <= threshold + tokensps_passed = tokensps_drop_pct <= threshold + passed = reqps_passed and tokensps_passed + + failed_metrics = [] + if not reqps_passed: + failed_metrics.append("reqps") + if not tokensps_passed: + failed_metrics.append("tokensps") + + return { + "status": "PASS" if passed else "FAIL", + "reason": "throughput_drop_within_threshold" + if passed + else "throughput_drop_exceeds_threshold", + "before_reqps": round(float(before_reqps), 6), + "after_reqps": round(float(after_reqps), 6), + "before_reqps_source": before_reqps_source, + "after_reqps_source": after_reqps_source, + "reqps_drop_pct": round(reqps_drop_pct, 6), + "before_tokensps": round(float(before_tokensps), 6), + "after_tokensps": round(float(after_tokensps), 6), + "before_tokensps_source": before_tokensps_source, + "after_tokensps_source": after_tokensps_source, + "tokensps_drop_pct": round(tokensps_drop_pct, 6), + "throughput_drop_pct_threshold": threshold, + "failed_metrics": failed_metrics, + } + + +def _quality_gate( + baseline_quality_path: str, + candidate_quality_path: str, + quality_drop_pct_threshold: float, +) -> Dict[str, Any]: + baseline_report, baseline_error = _quality_load_report(baseline_quality_path) + candidate_report, candidate_error = _quality_load_report(candidate_quality_path) + if baseline_error or candidate_error: + return { + "status": "BLOCKED", + "reason": "missing_or_unreadable_quality_artifact", + "baseline_error": baseline_error, + "candidate_error": candidate_error, + } + + assert baseline_report is not None + assert candidate_report is not None + + if _quality_is_synthetic_or_blocked( + baseline_report + ) or _quality_is_synthetic_or_blocked(candidate_report): + return { + "status": "BLOCKED", + "reason": "synthetic_or_blocked_report_detected", + } + + baseline_score, baseline_source, baseline_details = _extract_quality_score( + baseline_report + ) + candidate_score, candidate_source, candidate_details = _extract_quality_score( + candidate_report + ) + + if baseline_score is None or candidate_score is None: + return { + "status": "BLOCKED", + "reason": "missing_quality_signal", + "baseline_score": baseline_score, + "candidate_score": candidate_score, + "baseline_source": baseline_source, + "candidate_source": candidate_source, + "baseline_details": baseline_details, + "candidate_details": candidate_details, + } + + baseline_score_float = float(baseline_score) + candidate_score_float = float(candidate_score) + score_drop = baseline_score_float - candidate_score_float + allowed_drop = max(baseline_score_float, 0.0) * ( + float(quality_drop_pct_threshold) / 100.0 + ) + passed = score_drop <= allowed_drop + + return { + "status": "PASS" if passed else "FAIL", + "reason": "quality_drop_within_threshold" + if passed + else "quality_drop_exceeds_threshold", + "baseline_score": round(baseline_score_float, 6), + "candidate_score": round(candidate_score_float, 6), + "score_drop": round(score_drop, 6), + "quality_drop_pct_threshold": float(quality_drop_pct_threshold), + "quality_drop_max_allowed": round(allowed_drop, 6), + "baseline_source": baseline_source, + "candidate_source": candidate_source, + "baseline_details": baseline_details, + "candidate_details": candidate_details, + } + + +async def _resolve_effective_thresholds( + rollout_stage_percent: Optional[int], + rollback_trigger_latency_p95_ms: Optional[float], + latency_policy_mode: Optional[str], + rollback_trigger_latency_p95_improve_pct: Optional[float], + rollback_trigger_throughput_drop_pct: Optional[float], + rollback_trigger_quality_drop_pct: Optional[float], +) -> Dict[str, Any]: + current_stage_percent = ( + int(rollout_stage_percent) + if rollout_stage_percent is not None + else int(await get_rollout_stage_percent()) + ) + latency_threshold_ms = ( + float(rollback_trigger_latency_p95_ms) + if rollback_trigger_latency_p95_ms is not None + else float(await get_rollback_trigger_latency_p95_ms()) + ) + resolved_latency_policy_mode = ( + str(latency_policy_mode) + if latency_policy_mode is not None + else LATENCY_POLICY_MODE_ABSOLUTE_P95_CAP + ) + if resolved_latency_policy_mode not in LATENCY_POLICY_MODES: + raise ValueError( + f"unsupported latency policy mode: {resolved_latency_policy_mode}" + ) + latency_improve_pct_threshold = ( + float(rollback_trigger_latency_p95_improve_pct) + if rollback_trigger_latency_p95_improve_pct is not None + else 0.0 + ) + throughput_drop_pct_threshold = ( + float(rollback_trigger_throughput_drop_pct) + if rollback_trigger_throughput_drop_pct is not None + else float(await get_rollback_trigger_throughput_drop_pct()) + ) + quality_drop_pct_threshold = ( + float(rollback_trigger_quality_drop_pct) + if rollback_trigger_quality_drop_pct is not None + else float(await get_rollback_trigger_quality_drop_pct()) + ) + + return { + "current_stage_percent": current_stage_percent, + "latency_threshold_ms": latency_threshold_ms, + "latency_policy_mode": resolved_latency_policy_mode, + "latency_improve_pct_threshold": latency_improve_pct_threshold, + "throughput_drop_pct_threshold": throughput_drop_pct_threshold, + "quality_drop_pct_threshold": quality_drop_pct_threshold, + } + + +async def _persist_rollout_stage_with_feature_flags( + target_stage_percent: int, +) -> Dict[str, Any]: + persisted_config = { + "rollout_stage_percent": int(target_stage_percent), + **stage_percent_to_feature_flags(target_stage_percent), + } + + storage_adapter = await get_storage_adapter() + for key, value in persisted_config.items(): + await storage_adapter.set_config(key, value) + await reload_config() + return persisted_config + + +async def evaluate_rollout_decision( + *, + before_perf_path: str, + after_perf_path: str, + baseline_quality_path: str, + candidate_quality_path: str, + rollout_stage_percent: Optional[int] = None, + rollback_trigger_latency_p95_ms: Optional[float] = None, + latency_policy_mode: Optional[str] = None, + rollback_trigger_latency_p95_improve_pct: Optional[float] = None, + rollback_trigger_throughput_drop_pct: Optional[float] = None, + rollback_trigger_quality_drop_pct: Optional[float] = None, + apply: bool = False, +) -> Dict[str, Any]: + thresholds = await _resolve_effective_thresholds( + rollout_stage_percent=rollout_stage_percent, + rollback_trigger_latency_p95_ms=rollback_trigger_latency_p95_ms, + latency_policy_mode=latency_policy_mode, + rollback_trigger_latency_p95_improve_pct=rollback_trigger_latency_p95_improve_pct, + rollback_trigger_throughput_drop_pct=rollback_trigger_throughput_drop_pct, + rollback_trigger_quality_drop_pct=rollback_trigger_quality_drop_pct, + ) + + current_stage_percent = int(thresholds["current_stage_percent"]) + latency_threshold_ms = float(thresholds["latency_threshold_ms"]) + resolved_latency_policy_mode = str(thresholds["latency_policy_mode"]) + latency_improve_pct_threshold = float(thresholds["latency_improve_pct_threshold"]) + throughput_drop_pct_threshold = float(thresholds["throughput_drop_pct_threshold"]) + quality_drop_pct_threshold = float(thresholds["quality_drop_pct_threshold"]) + + next_stage_percent = get_next_stage_percent(current_stage_percent) + rollback_stage_percent = get_previous_stage_percent(current_stage_percent) + + before_perf_report, before_perf_error = _load_perf_report(before_perf_path) + after_perf_report, after_perf_error = _load_perf_report(after_perf_path) + + if before_perf_error or after_perf_error: + latency_gate = { + "status": "BLOCKED", + "reason": "missing_or_unreadable_perf_artifact", + "before_perf_error": before_perf_error, + "after_perf_error": after_perf_error, + "latency_policy_mode": resolved_latency_policy_mode, + "threshold_latency_p95_ms": latency_threshold_ms, + "required_full_latency_p95_improve_percent": latency_improve_pct_threshold, + } + throughput_gate = { + "status": "BLOCKED", + "reason": "missing_or_unreadable_perf_artifact", + "before_perf_error": before_perf_error, + "after_perf_error": after_perf_error, + } + else: + assert before_perf_report is not None + assert after_perf_report is not None + latency_gate = _latency_gate( + before_report=before_perf_report, + after_report=after_perf_report, + latency_threshold_ms=latency_threshold_ms, + latency_policy_mode=resolved_latency_policy_mode, + latency_improve_pct_threshold=latency_improve_pct_threshold, + ) + throughput_gate = _throughput_gate( + before_report=before_perf_report, + after_report=after_perf_report, + throughput_drop_pct_threshold=throughput_drop_pct_threshold, + ) + + quality_gate = _quality_gate( + baseline_quality_path=baseline_quality_path, + candidate_quality_path=candidate_quality_path, + quality_drop_pct_threshold=quality_drop_pct_threshold, + ) + + gates = { + "latency": latency_gate, + "throughput": throughput_gate, + "quality": quality_gate, + } + blocked_gates = [ + name for name, gate in gates.items() if gate.get("status") == "BLOCKED" + ] + failed_gates = [ + name for name, gate in gates.items() if gate.get("status") == "FAIL" + ] + blocked_reasons = [ + f"{name}:{gate.get('reason', 'unknown')}" + for name, gate in gates.items() + if gate.get("status") == "BLOCKED" + ] + + if blocked_gates: + decision = "HOLD_BLOCKED" + target_stage_percent = current_stage_percent + elif failed_gates: + decision = "ROLLBACK" + target_stage_percent = rollback_stage_percent + else: + decision = "PROMOTE" + target_stage_percent = next_stage_percent + + apply_requested = bool(apply) + dry_run = not apply_requested + applied = False + apply_error = None + apply_skipped_reason = None + persisted_config: Dict[str, Any] = {} + + if apply_requested: + if decision == "HOLD_BLOCKED": + apply_skipped_reason = "decision_hold_blocked" + else: + try: + persisted_config = await _persist_rollout_stage_with_feature_flags( + target_stage_percent + ) + applied = True + except Exception as exc: + apply_error = f"{exc.__class__.__name__}:{exc}" + + return { + "decision": decision, + "current_stage_percent": current_stage_percent, + "next_stage_percent": next_stage_percent, + "rollback_stage_percent": rollback_stage_percent, + "target_stage_percent": target_stage_percent, + "feature_flags_for_target_stage": stage_percent_to_feature_flags( + target_stage_percent + ), + "thresholds": { + "latency_policy_mode": resolved_latency_policy_mode, + "rollback_trigger_latency_p95_ms": latency_threshold_ms, + "rollback_trigger_latency_p95_improve_pct": latency_improve_pct_threshold, + "rollback_trigger_throughput_drop_pct": throughput_drop_pct_threshold, + "rollback_trigger_quality_drop_pct": quality_drop_pct_threshold, + }, + "blocked_gates": blocked_gates, + "blocked_reason": ";".join(blocked_reasons) if blocked_reasons else None, + "failed_gates": failed_gates, + "gates": gates, + "before_perf": before_perf_path, + "after_perf": after_perf_path, + "baseline_quality": baseline_quality_path, + "candidate_quality": candidate_quality_path, + "apply_requested": apply_requested, + "dry_run": dry_run, + "applied": applied, + "apply_skipped_reason": apply_skipped_reason, + "apply_error": apply_error, + "persisted_config": persisted_config, + } + + +def main() -> None: + args = _parse_args() + result = asyncio.run( + evaluate_rollout_decision( + before_perf_path=args.before_perf, + after_perf_path=args.after_perf, + baseline_quality_path=args.baseline_quality, + candidate_quality_path=args.candidate_quality, + rollout_stage_percent=args.rollout_stage_percent, + rollback_trigger_latency_p95_ms=args.rollback_trigger_latency_p95_ms, + latency_policy_mode=args.latency_policy_mode, + rollback_trigger_latency_p95_improve_pct=args.rollback_trigger_latency_p95_improve_pct, + rollback_trigger_throughput_drop_pct=args.rollback_trigger_throughput_drop_pct, + rollback_trigger_quality_drop_pct=args.rollback_trigger_quality_drop_pct, + apply=args.apply, + ) + ) + _emit(result) + + +if __name__ == "__main__": + main() diff --git a/src/api/geminicli.py b/src/api/geminicli.py index a10e70234..5e3d622dd 100644 --- a/src/api/geminicli.py +++ b/src/api/geminicli.py @@ -21,7 +21,12 @@ import httpx from fastapi import Response -from config import get_code_assist_endpoint, get_auto_ban_error_codes, get_proxy_config +from config import ( + get_code_assist_endpoint, + get_auto_ban_error_codes, + get_ff_retry_policy_v2, + get_proxy_config, +) from log import log from src.credential_manager import credential_manager @@ -90,6 +95,164 @@ def _compute_capacity_retry_delay(base_interval: float, attempt: int) -> float: return min(12.0, exp_backoff + jitter) +_OBS_METRIC_KEYS = ( + "t_req_in", + "t_upstream_send", + "t_first_byte", + "t_first_token", + "t_done", + "retry_count", + "retry_sleep_ms", + "converter_cpu_ms", +) + +_OBS_HEADER_MAP = { + "route": "x-gcli-obs-route", + "model": "x-gcli-obs-model", + "stream": "x-gcli-obs-stream", + "status": "x-gcli-obs-status", + "t_req_in": "x-gcli-obs-t-req-in", + "t_upstream_send": "x-gcli-obs-t-upstream-send", + "t_first_byte": "x-gcli-obs-t-first-byte", + "t_first_token": "x-gcli-obs-t-first-token", + "t_done": "x-gcli-obs-t-done", + "retry_count": "x-gcli-obs-retry-count", + "retry_sleep_ms": "x-gcli-obs-retry-sleep-ms", + "converter_cpu_ms": "x-gcli-obs-converter-cpu-ms", +} + + +def create_observability_context( + route: str, model: str, stream: bool +) -> Dict[str, Any]: + now = time.time() + return { + "route": route, + "model": model, + "stream": bool(stream), + "status": None, + "t_req_in": now, + "t_upstream_send": None, + "t_first_byte": None, + "t_first_token": None, + "t_done": None, + "retry_count": 0, + "retry_sleep_ms": 0.0, + "converter_cpu_ms": 0.0, + } + + +def add_converter_cpu_elapsed( + metrics_ctx: Optional[Dict[str, Any]], elapsed_seconds: float +) -> None: + if not isinstance(metrics_ctx, dict): + return + elapsed_ms = max(0.0, float(elapsed_seconds)) * 1000.0 + metrics_ctx["converter_cpu_ms"] = ( + float(metrics_ctx.get("converter_cpu_ms", 0.0) or 0.0) + elapsed_ms + ) + + +def build_observability_headers( + metrics_ctx: Optional[Dict[str, Any]], +) -> Dict[str, str]: + if not isinstance(metrics_ctx, dict): + return {} + + headers: Dict[str, str] = {} + + for key in _OBS_METRIC_KEYS: + value = metrics_ctx.get(key) + if value is None: + continue + header_key = _OBS_HEADER_MAP.get(key) + if header_key is None: + continue + if key == "retry_count": + headers[header_key] = str(int(value)) + else: + headers[header_key] = f"{float(value):.6f}" + + route = metrics_ctx.get("route") + if route: + headers[_OBS_HEADER_MAP["route"]] = str(route) + + model = metrics_ctx.get("model") + if model: + headers[_OBS_HEADER_MAP["model"]] = str(model) + + stream = metrics_ctx.get("stream") + if stream is not None: + headers[_OBS_HEADER_MAP["stream"]] = "true" if bool(stream) else "false" + + status = metrics_ctx.get("status") + if status is not None: + headers[_OBS_HEADER_MAP["status"]] = str(status) + + return headers + + +def _ensure_observability_context( + metrics_ctx: Optional[Dict[str, Any]], +) -> Optional[Dict[str, Any]]: + if not isinstance(metrics_ctx, dict): + return None + for key in _OBS_METRIC_KEYS: + if key not in metrics_ctx: + if key == "retry_count": + metrics_ctx[key] = 0 + else: + metrics_ctx[key] = ( + 0.0 if key in {"retry_sleep_ms", "converter_cpu_ms"} else None + ) + return metrics_ctx + + +def _mark_timestamp( + metrics_ctx: Optional[Dict[str, Any]], key: str, overwrite: bool = False +) -> None: + if not isinstance(metrics_ctx, dict): + return + if overwrite or metrics_ctx.get(key) is None: + metrics_ctx[key] = time.time() + + +def _mark_status(metrics_ctx: Optional[Dict[str, Any]], status_code: int) -> None: + if not isinstance(metrics_ctx, dict): + return + metrics_ctx["status"] = int(status_code) + + +def _inject_observability_headers( + response: Response, + metrics_ctx: Optional[Dict[str, Any]], +) -> None: + headers = build_observability_headers(metrics_ctx) + if not headers: + return + for key, value in headers.items(): + response.headers[key] = value + + +def _record_retry_sleep( + metrics_ctx: Optional[Dict[str, Any]], sleep_seconds: float +) -> None: + if not isinstance(metrics_ctx, dict): + return + sleep_ms = max(0.0, float(sleep_seconds)) * 1000.0 + metrics_ctx["retry_count"] = int(metrics_ctx.get("retry_count", 0) or 0) + 1 + metrics_ctx["retry_sleep_ms"] = ( + float(metrics_ctx.get("retry_sleep_ms", 0.0) or 0.0) + sleep_ms + ) + + +async def _sleep_with_observability( + metrics_ctx: Optional[Dict[str, Any]], sleep_seconds: float +) -> None: + _record_retry_sleep(metrics_ctx, sleep_seconds) + await asyncio.sleep(sleep_seconds) + + # ==================== 全局凭证管理器 ==================== # 使用全局单例 credential_manager,自动初始化 @@ -146,6 +309,7 @@ async def stream_request( body: Dict[str, Any], native: bool = False, headers: Optional[Dict[str, str]] = None, + metrics_ctx: Optional[Dict[str, Any]] = None, ): """ 流式请求函数 @@ -160,6 +324,12 @@ async def stream_request( """ # 获取有效凭证 model_name = body.get("model", "") + if metrics_ctx is None: + metrics_ctx = create_observability_context( + route="geminicli-stream", model=model_name, stream=True + ) + metrics_ctx = _ensure_observability_context(metrics_ctx) + _mark_timestamp(metrics_ctx, "t_req_in") # 1. 获取有效凭证 cred_result = await credential_manager.get_valid_credential( @@ -168,11 +338,15 @@ async def stream_request( if not cred_result: # 如果返回值是None,直接返回错误500 - yield Response( + _mark_status(metrics_ctx, 500) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + resp = Response( content=json.dumps({"error": "当前无可用凭证"}), status_code=500, media_type="application/json", ) + _inject_observability_headers(resp, metrics_ctx) + yield resp return current_file, credential_data = cred_result @@ -195,17 +369,22 @@ async def stream_request( except Exception as e: log.error(f"准备请求失败: {e}") - yield Response( + _mark_status(metrics_ctx, 500) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + resp = Response( content=json.dumps({"error": f"准备请求失败: {str(e)}"}), status_code=500, media_type="application/json", ) + _inject_observability_headers(resp, metrics_ctx) + yield resp return # 3. 调用stream_post_async进行请求 retry_config = await get_retry_config() max_retries = retry_config["max_retries"] retry_interval = retry_config["retry_interval"] + retry_policy_v2_enabled = await get_ff_retry_policy_v2() proxy_enabled = bool(await get_proxy_config()) # 流式请求采用结构化超时:限制连接/写入/连接池等待,读取保持无限制以兼容长SSE空隙 stream_timeout = httpx.Timeout(connect=30.0, read=None, write=30.0, pool=30.0) @@ -243,8 +422,10 @@ async def refresh_credential_fast() -> bool: success_recorded = False # 标记是否已记录成功 need_retry = False # 标记是否需要重试 keep_current_credential = False # 标记是否保留当前凭证 + retry_wait_applied = False # 标记本次attempt是否已执行等待 status_code: Optional[int] = None attempt_started_at = time.time() + _mark_timestamp(metrics_ctx, "t_upstream_send") try: async for chunk in stream_post_async( @@ -254,11 +435,22 @@ async def refresh_credential_fast() -> bool: headers=auth_headers, timeout=stream_timeout, ): + _mark_timestamp(metrics_ctx, "t_first_byte") # 判断是否是Response对象 if isinstance(chunk, Response): status_code = chunk.status_code last_error_response = chunk # 记录最后一次错误 + if success_recorded: + log.warning( + f"[GEMINICLI STREAM] 首token后收到错误响应,禁止重试 (status={status_code})" + ) + _mark_status(metrics_ctx, chunk.status_code) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + _inject_observability_headers(chunk, metrics_ctx) + yield chunk + return + # 缓存错误解析结果,避免重复decode error_body = None try: @@ -303,7 +495,8 @@ async def refresh_credential_fast() -> bool: # 对于没有触发cd且非容量耗尽的429/503错误,保留当前凭证重试;否则预热下一个凭证 if ( - (status_code == 429 or status_code == 503) + retry_policy_v2_enabled + and (status_code == 429 or status_code == 503) and cooldown_until is None and not is_model_capacity_exhausted ): @@ -344,16 +537,22 @@ async def refresh_credential_fast() -> bool: max_retries, retry_interval, mode="geminicli", + metrics_ctx=metrics_ctx, + retry_policy_v2_enabled=retry_policy_v2_enabled, ) if should_retry and attempt < max_retries: need_retry = True + retry_wait_applied = retry_policy_v2_enabled break # 跳出内层循环,准备重试 else: # 不重试,直接返回原始错误 log.error( - f"[GEMINICLI STREAM] 达到最大重试次数或不应重试,返回原始错误" + "[GEMINICLI STREAM] 达到最大重试次数或不应重试,返回原始错误" ) + _mark_status(metrics_ctx, chunk.status_code) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + _inject_observability_headers(chunk, metrics_ctx) yield chunk return elif status_code == 404 and "preview" in model_name.lower(): @@ -400,8 +599,11 @@ async def refresh_credential_fast() -> bool: break else: log.error( - f"[GEMINICLI STREAM] 达到最大重试次数,返回404错误" + "[GEMINICLI STREAM] 达到最大重试次数,返回404错误" ) + _mark_status(metrics_ctx, chunk.status_code) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + _inject_observability_headers(chunk, metrics_ctx) yield chunk return else: @@ -418,12 +620,16 @@ async def refresh_credential_fast() -> bool: model_name=model_name, error_message=error_body, ) + _mark_status(metrics_ctx, chunk.status_code) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + _inject_observability_headers(chunk, metrics_ctx) yield chunk return else: # 不是Response,说明是真流,直接yield返回 # 只在第一个chunk时记录成功 if not success_recorded: + _mark_timestamp(metrics_ctx, "t_first_token") await record_api_call_success( credential_manager, current_file, @@ -440,6 +646,8 @@ async def refresh_credential_fast() -> bool: # 流式请求完成,检查结果 if success_recorded: log.debug(f"[GEMINICLI STREAM] 流式响应完成,模型: {model_name}") + _mark_status(metrics_ctx, 200) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) return # 统一处理重试 @@ -456,7 +664,8 @@ async def refresh_credential_fast() -> bool: log.info( f"[GEMINICLI STREAM] {status_label}无冷却时间,保留当前凭证重试: {current_file}" ) - await asyncio.sleep(retry_interval) + if not retry_wait_applied: + await _sleep_with_observability(metrics_ctx, retry_interval) continue # 使用预热的凭证任务,避免等待 @@ -475,22 +684,30 @@ async def refresh_credential_fast() -> bool: if token and project_id: auth_headers["Authorization"] = f"Bearer {token}" final_payload["project"] = project_id - await asyncio.sleep(retry_interval) + if not retry_wait_applied: + await _sleep_with_observability( + metrics_ctx, retry_interval + ) continue # 重试 except Exception as e: log.warning(f"[GEMINICLI STREAM] 预热凭证任务失败: {e}") next_cred_task = None # 如果预热的凭证不可用,则同步获取 - await asyncio.sleep(retry_interval) + if not retry_wait_applied: + await _sleep_with_observability(metrics_ctx, retry_interval) if not await refresh_credential_fast(): log.error("[GEMINICLI STREAM] 重试时无可用凭证或刷新失败") - yield Response( + _mark_status(metrics_ctx, 500) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + resp = Response( content=json.dumps({"error": "当前无可用凭证"}), status_code=500, media_type="application/json", ) + _inject_observability_headers(resp, metrics_ctx) + yield resp return continue # 重试 @@ -502,13 +719,18 @@ async def refresh_credential_fast() -> bool: f"耗时: {elapsed:.2f}s, url: {target_url}, proxy={'on' if proxy_enabled else 'off'}, " f"http2=True, timeout(connect/write/pool=30s, read=None), 凭证: {current_file}" ) + if success_recorded: + log.warning("[GEMINICLI STREAM] 首token后发生异常,停止流且不重试") + _mark_status(metrics_ctx, 500) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + return if attempt < max_retries: delay = _compute_capacity_retry_delay(retry_interval, attempt) log.info( f"[GEMINICLI STREAM] 异常后重试 (attempt {attempt + 2}/{max_retries + 1}), " f"等待 {delay:.1f}s..." ) - await asyncio.sleep(delay) + await _sleep_with_observability(metrics_ctx, delay) continue else: # 所有重试都失败,返回最后一次的错误(如果有) @@ -517,20 +739,28 @@ async def refresh_credential_fast() -> bool: f"detail={repr(e)}" ) if last_error_response: + _mark_status(metrics_ctx, last_error_response.status_code) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + _inject_observability_headers(last_error_response, metrics_ctx) yield last_error_response else: # 如果没有记录到错误响应,返回500错误 - yield Response( + _mark_status(metrics_ctx, 500) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + resp = Response( content=json.dumps({"error": f"流式请求异常: {str(e)}"}), status_code=500, media_type="application/json", ) + _inject_observability_headers(resp, metrics_ctx) + yield resp return async def non_stream_request( body: Dict[str, Any], headers: Optional[Dict[str, str]] = None, + metrics_ctx: Optional[Dict[str, Any]] = None, ) -> Response: """ 非流式请求函数 @@ -545,6 +775,12 @@ async def non_stream_request( """ # 获取有效凭证 model_name = body.get("model", "") + if metrics_ctx is None: + metrics_ctx = create_observability_context( + route="geminicli-non-stream", model=model_name, stream=False + ) + metrics_ctx = _ensure_observability_context(metrics_ctx) + _mark_timestamp(metrics_ctx, "t_req_in") # 1. 获取有效凭证 cred_result = await credential_manager.get_valid_credential( @@ -553,11 +789,15 @@ async def non_stream_request( if not cred_result: # 如果返回值是None,直接返回错误500 - return Response( + _mark_status(metrics_ctx, 500) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + resp = Response( content=json.dumps({"error": "当前无可用凭证"}), status_code=500, media_type="application/json", ) + _inject_observability_headers(resp, metrics_ctx) + return resp current_file, credential_data = cred_result @@ -579,16 +819,21 @@ async def non_stream_request( except Exception as e: log.error(f"准备请求失败: {e}") - return Response( + _mark_status(metrics_ctx, 500) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + resp = Response( content=json.dumps({"error": f"准备请求失败: {str(e)}"}), status_code=500, media_type="application/json", ) + _inject_observability_headers(resp, metrics_ctx) + return resp # 3. 调用post_async进行请求 retry_config = await get_retry_config() max_retries = retry_config["max_retries"] retry_interval = retry_config["retry_interval"] + retry_policy_v2_enabled = await get_ff_retry_policy_v2() DISABLE_ERROR_CODES = await get_auto_ban_error_codes() # 禁用凭证的错误码 last_error_response = None # 记录最后一次的错误响应 @@ -620,11 +865,25 @@ async def refresh_credential_fast() -> bool: return False for attempt in range(max_retries + 1): + retry_wait_applied = False try: + _mark_timestamp(metrics_ctx, "t_upstream_send") + non_stream_timeout = httpx.Timeout( + connect=20.0, + read=300.0, + write=30.0, + pool=20.0, + ) response = await post_async( - url=target_url, json=final_payload, headers=auth_headers, timeout=300.0 + url=target_url, + json=final_payload, + headers=auth_headers, + timeout=non_stream_timeout, ) + _mark_timestamp(metrics_ctx, "t_first_byte") + _mark_timestamp(metrics_ctx, "t_first_token") + status_code = response.status_code # 成功 @@ -635,10 +894,13 @@ async def refresh_credential_fast() -> bool: mode="geminicli", model_name=model_name, ) + _mark_status(metrics_ctx, 200) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) # 创建响应头,移除压缩相关的header避免重复解压 response_headers = dict(response.headers) response_headers.pop("content-encoding", None) response_headers.pop("content-length", None) + response_headers.update(build_observability_headers(metrics_ctx)) return Response( content=response.content, status_code=200, headers=response_headers @@ -689,12 +951,15 @@ async def refresh_credential_fast() -> bool: status_code == 429 and error_reason == "MODEL_CAPACITY_EXHAUSTED" ) - # 对于没有触发cd且非容量耗尽的429错误,不预热新凭证 - if not ( - (status_code == 429 or status_code == 503) + keep_current_credential = ( + retry_policy_v2_enabled + and (status_code == 429 or status_code == 503) and cooldown_until is None and not is_model_capacity_exhausted - ): + ) + + # 对于没有触发cd且非容量耗尽的429错误,不预热新凭证 + if not keep_current_credential: # 并行预热下一个凭证,不阻塞当前处理 if next_cred_task is None and attempt < max_retries: next_cred_task = asyncio.create_task( @@ -710,11 +975,7 @@ async def refresh_credential_fast() -> bool: ) # 无cd的429/503保留当前凭证重试,无需记录错误 - if not ( - (status_code == 429 or status_code == 503) - and cooldown_until is None - and not is_model_capacity_exhausted - ): + if not keep_current_credential: await record_api_call_error( credential_manager, current_file, @@ -735,6 +996,8 @@ async def refresh_credential_fast() -> bool: max_retries, retry_interval, mode="geminicli", + metrics_ctx=metrics_ctx, + retry_policy_v2_enabled=retry_policy_v2_enabled, ) if should_retry and attempt < max_retries: @@ -742,17 +1005,15 @@ async def refresh_credential_fast() -> bool: log.info( f"[NON-STREAM] 重试请求 (attempt {attempt + 2}/{max_retries + 1})..." ) + retry_wait_applied = retry_policy_v2_enabled # 对于没有冷却时间且非容量耗尽的429错误,保留当前凭证重试 - if ( - (status_code == 429 or status_code == 503) - and cooldown_until is None - and not is_model_capacity_exhausted - ): + if keep_current_credential: log.info( f"[NON-STREAM] {status_code}无冷却时间,保留当前凭证重试: {current_file}" ) - await asyncio.sleep(retry_interval) + if not retry_wait_applied: + await _sleep_with_observability(metrics_ctx, retry_interval) continue # 使用预热的凭证任务,避免等待 @@ -771,26 +1032,37 @@ async def refresh_credential_fast() -> bool: if token and project_id: auth_headers["Authorization"] = f"Bearer {token}" final_payload["project"] = project_id - await asyncio.sleep(retry_interval) + if not retry_wait_applied: + await _sleep_with_observability( + metrics_ctx, retry_interval + ) continue # 重试 except Exception as e: log.warning(f"[NON-STREAM] 预热凭证任务失败: {e}") next_cred_task = None # 如果预热的凭证不可用,则同步获取 - await asyncio.sleep(retry_interval) + if not retry_wait_applied: + await _sleep_with_observability(metrics_ctx, retry_interval) if not await refresh_credential_fast(): log.error("[NON-STREAM] 重试时无可用凭证或刷新失败") - return Response( + _mark_status(metrics_ctx, 500) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + resp = Response( content=json.dumps({"error": "当前无可用凭证"}), status_code=500, media_type="application/json", ) + _inject_observability_headers(resp, metrics_ctx) + return resp continue # 重试 else: # 不重试,直接返回原始错误 - log.error(f"[NON-STREAM] 达到最大重试次数或不应重试,返回原始错误") + log.error("[NON-STREAM] 达到最大重试次数或不应重试,返回原始错误") + _mark_status(metrics_ctx, status_code) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + _inject_observability_headers(last_error_response, metrics_ctx) return last_error_response elif status_code == 404 and "preview" in model_name.lower(): # 特殊处理:preview模型返回404,说明该凭证不支持preview模型 @@ -850,25 +1122,36 @@ async def refresh_credential_fast() -> bool: if token and project_id: auth_headers["Authorization"] = f"Bearer {token}" final_payload["project"] = project_id - await asyncio.sleep(retry_interval) + if not retry_wait_applied: + await _sleep_with_observability( + metrics_ctx, retry_interval + ) continue # 重试 except Exception as e: log.warning(f"[NON-STREAM] 预热凭证任务失败: {e}") next_cred_task = None # 如果预热的凭证不可用,则同步获取 - await asyncio.sleep(retry_interval) + if not retry_wait_applied: + await _sleep_with_observability(metrics_ctx, retry_interval) if not await refresh_credential_fast(): log.error("[NON-STREAM] 重试时无可用凭证或刷新失败") - return Response( + _mark_status(metrics_ctx, 500) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + resp = Response( content=json.dumps({"error": "当前无可用凭证"}), status_code=500, media_type="application/json", ) + _inject_observability_headers(resp, metrics_ctx) + return resp continue # 重试 else: - log.error(f"[NON-STREAM] 达到最大重试次数,返回404错误") + log.error("[NON-STREAM] 达到最大重试次数,返回404错误") + _mark_status(metrics_ctx, status_code) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + _inject_observability_headers(last_error_response, metrics_ctx) return last_error_response else: # 错误码不在重试范围内,直接返回 @@ -884,6 +1167,9 @@ async def refresh_credential_fast() -> bool: model_name=model_name, error_message=error_text, ) + _mark_status(metrics_ctx, status_code) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + _inject_observability_headers(last_error_response, metrics_ctx) return last_error_response except Exception as e: @@ -897,7 +1183,7 @@ async def refresh_credential_fast() -> bool: f"[NON-STREAM] 异常后重试 (attempt {attempt + 2}/{max_retries + 1}), " f"等待 {delay:.1f}s..." ) - await asyncio.sleep(delay) + await _sleep_with_observability(metrics_ctx, delay) continue else: # 所有重试都失败,返回最后一次的错误(如果有)或500错误 @@ -906,23 +1192,37 @@ async def refresh_credential_fast() -> bool: f"detail={repr(e)}" ) if last_error_response: + _mark_status(metrics_ctx, last_error_response.status_code) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + _inject_observability_headers(last_error_response, metrics_ctx) return last_error_response else: - return Response( + _mark_status(metrics_ctx, 500) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + resp = Response( content=json.dumps({"error": f"请求异常: {str(e)}"}), status_code=500, media_type="application/json", ) + _inject_observability_headers(resp, metrics_ctx) + return resp # 所有重试都失败,返回最后一次的原始错误 log.error("[NON-STREAM] 所有重试均失败") if last_error_response is not None: + _mark_status(metrics_ctx, last_error_response.status_code) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + _inject_observability_headers(last_error_response, metrics_ctx) return last_error_response - return Response( + _mark_status(metrics_ctx, 500) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + resp = Response( content=json.dumps({"error": "请求失败,且未收到上游错误响应"}), status_code=500, media_type="application/json", ) + _inject_observability_headers(resp, metrics_ctx) + return resp # ==================== 测试代码 ==================== @@ -964,7 +1264,7 @@ async def test_stream_request(): chunk_count += 1 if isinstance(chunk, Response): # 错误响应 - print(f"\n❌ 错误响应:") + print("\n❌ 错误响应:") print(f" 状态码: {chunk.status_code}") print(f" Content-Type: {chunk.headers.get('content-type', 'N/A')}") try: @@ -1024,7 +1324,7 @@ async def test_non_stream_request(): # 尝试解析JSON try: json_data = json.loads(content) - print(f"响应内容 (格式化JSON):") + print("响应内容 (格式化JSON):") print(json.dumps(json_data, indent=2, ensure_ascii=False)) except json.JSONDecodeError: print("(非JSON格式)") diff --git a/src/api/utils.py b/src/api/utils.py index d899a2c7d..7d683a0b4 100644 --- a/src/api/utils.py +++ b/src/api/utils.py @@ -69,6 +69,8 @@ async def handle_error_with_retry( max_retries: int, retry_interval: float, mode: str = "geminicli", + metrics_ctx: Optional[Dict[str, Any]] = None, + retry_policy_v2_enabled: bool = True, ) -> bool: """ 统一处理错误和重试逻辑 @@ -91,6 +93,20 @@ async def handle_error_with_retry( Returns: bool: True表示需要继续重试,False表示不需要重试 """ + + def _record_retry_sleep(sleep_seconds: float) -> None: + if not isinstance(metrics_ctx, dict): + return + sleep_ms = max(0.0, float(sleep_seconds)) * 1000.0 + metrics_ctx["retry_count"] = int(metrics_ctx.get("retry_count", 0) or 0) + 1 + metrics_ctx["retry_sleep_ms"] = ( + float(metrics_ctx.get("retry_sleep_ms", 0.0) or 0.0) + sleep_ms + ) + + async def _sleep_with_metrics(sleep_seconds: float) -> None: + _record_retry_sleep(sleep_seconds) + await asyncio.sleep(sleep_seconds) + # 优先检查自动封禁 should_auto_ban = await check_should_auto_ban(status_code) @@ -104,7 +120,8 @@ async def handle_error_with_retry( f"[{mode.upper()} RETRY] Retrying with next credential after auto-ban " f"(status {status_code}, attempt {attempt + 1}/{max_retries})" ) - await asyncio.sleep(retry_interval) + if retry_policy_v2_enabled: + await _sleep_with_metrics(retry_interval) return True return False @@ -118,7 +135,8 @@ async def handle_error_with_retry( f"[{mode.upper()} RETRY] {status_code} error encountered, retrying " f"(attempt {attempt + 1}/{max_retries})" ) - await asyncio.sleep(retry_interval) + if retry_policy_v2_enabled: + await _sleep_with_metrics(retry_interval) return True # 其他错误不进行重试 @@ -473,7 +491,7 @@ async def collect_streaming_response(stream_generator) -> Response: # 去掉嵌套的 "response" 包装(Antigravity格式 -> 标准Gemini格式) if "response" in merged_response and "candidates" not in merged_response: - log.debug(f"[STREAM COLLECTOR] 展开response包装") + log.debug("[STREAM COLLECTOR] 展开response包装") merged_response = merged_response["response"] # 返回纯JSON格式 diff --git a/src/converter/anthropic2gemini.py b/src/converter/anthropic2gemini.py index ebb458712..cf8ff1d9e 100644 --- a/src/converter/anthropic2gemini.py +++ b/src/converter/anthropic2gemini.py @@ -12,6 +12,7 @@ from typing import Any, AsyncIterator, Dict, List, Optional from fastapi import Response +from config import get_ff_converter_fast_path from log import log from src.converter.utils import merge_system_messages @@ -437,6 +438,87 @@ def _extract_tool_result_output(content: Any) -> str: return str(content) +def _can_use_text_only_fast_path(payload: Dict[str, Any]) -> bool: + """判断是否可以走纯文本 fast-path。""" + if payload.get("thinking") is not None: + return False + + if payload.get("tools"): + return False + + if payload.get("tool_choice"): + return False + + messages = payload.get("messages") + if not isinstance(messages, list): + return False + + for msg in messages: + if not isinstance(msg, dict): + return False + + role = msg.get("role", "user") + if role not in ("user", "assistant", "model", "system"): + return False + + if role == "system": + continue + + raw_content = msg.get("content", "") + if isinstance(raw_content, str): + continue + + if isinstance(raw_content, list): + for item in raw_content: + if isinstance(item, dict): + if item.get("type") != "text": + return False + elif item is None: + continue + continue + + if raw_content is None: + continue + + return True + + +def _convert_text_only_messages_to_contents_fast( + messages: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """纯文本 messages 的轻量转换。""" + contents: List[Dict[str, Any]] = [] + + for msg in messages: + role = msg.get("role", "user") + if role == "system": + continue + + gemini_role = "model" if role in ("assistant", "model") else "user" + raw_content = msg.get("content", "") + + parts: List[Dict[str, Any]] = [] + if isinstance(raw_content, str): + if _is_non_whitespace_text(raw_content): + parts.append({"text": str(raw_content)}) + elif isinstance(raw_content, list): + for item in raw_content: + if isinstance(item, dict): + text = item.get("text", "") + else: + text = item + + if _is_non_whitespace_text(text): + parts.append({"text": str(text)}) + elif _is_non_whitespace_text(raw_content): + parts.append({"text": str(raw_content)}) + + if parts: + contents.append({"role": gemini_role, "parts": parts}) + + return contents + + def convert_messages_to_contents( messages: List[Dict[str, Any]], *, include_thinking: bool = True ) -> List[Dict[str, Any]]: @@ -818,32 +900,42 @@ async def anthropic_to_gemini_request(payload: Dict[str, Any]) -> Dict[str, Any] if not isinstance(messages, list): messages = [] + use_text_only_fast_path = bool( + await get_ff_converter_fast_path() + ) and _can_use_text_only_fast_path(payload) + # [CRITICAL FIX] 过滤并修复 Thinking 块签名 # 在转换前先过滤无效的 thinking 块 - filter_invalid_thinking_blocks(messages) + if not use_text_only_fast_path: + filter_invalid_thinking_blocks(messages) # 构建生成配置 generation_config = build_generation_config(payload) - # 转换消息内容(始终包含thinking块,由响应端处理) - contents = convert_messages_to_contents(messages, include_thinking=True) + if use_text_only_fast_path: + contents = _convert_text_only_messages_to_contents_fast(messages) + tools = None + tool_config = None + else: + # 转换消息内容(始终包含thinking块,由响应端处理) + contents = convert_messages_to_contents(messages, include_thinking=True) - # [CRITICAL FIX] 移除尾部无签名的 thinking 块 - # 对真实请求应用额外的清理 - for content in contents: - role = content.get("role", "") - if role == "model": # 只处理 model/assistant 消息 - parts = content.get("parts", []) - if isinstance(parts, list): - remove_trailing_unsigned_thinking(parts) + # [CRITICAL FIX] 移除尾部无签名的 thinking 块 + # 对真实请求应用额外的清理 + for content in contents: + role = content.get("role", "") + if role == "model": # 只处理 model/assistant 消息 + parts = content.get("parts", []) + if isinstance(parts, list): + remove_trailing_unsigned_thinking(parts) - contents = reorganize_tool_messages(contents) + contents = reorganize_tool_messages(contents) - # 转换工具 - tools = convert_tools(payload.get("tools")) + # 转换工具 + tools = convert_tools(payload.get("tools")) - # 转换 tool_choice - tool_config = convert_tool_choice_to_tool_config(payload.get("tool_choice")) + # 转换 tool_choice + tool_config = convert_tool_choice_to_tool_config(payload.get("tool_choice")) # 构建基础请求数据 gemini_request = { @@ -1040,6 +1132,7 @@ async def gemini_stream_to_anthropic_stream( input_tokens = 0 output_tokens = 0 finish_reason: Optional[str] = None + debug_logging_enabled = log.get_current_level() == "debug" def _sse_event(event: str, data: Dict[str, Any]) -> bytes: """生成 SSE 事件""" @@ -1067,41 +1160,50 @@ def _close_block() -> Optional[bytes]: f"[GEMINI_TO_ANTHROPIC] 收到 Response 对象,状态码: {chunk.status_code},直接转发错误" ) # 直接转发错误响应内容,不做格式转换 - error_content = ( - chunk.body - if isinstance(chunk.body, bytes) - else chunk.body.encode("utf-8") - ) + body = chunk.body + if isinstance(body, (bytes, bytearray, memoryview)): + error_content = bytes(body) + else: + error_content = str(body or "").encode("utf-8") yield error_content return # 记录接收到的原始chunk - log.debug( - f"[GEMINI_TO_ANTHROPIC] Raw chunk: {chunk[:200] if chunk else b''}" - ) + if debug_logging_enabled: + log.debug( + f"[GEMINI_TO_ANTHROPIC] Raw chunk: {chunk[:200] if chunk else b''}" + ) # 解析 Gemini 流式块 if not chunk or not chunk.startswith(b"data: "): - log.debug( - f"[GEMINI_TO_ANTHROPIC] Skipping chunk (not SSE format or empty)" - ) + if debug_logging_enabled: + log.debug( + "[GEMINI_TO_ANTHROPIC] Skipping chunk (not SSE format or empty)" + ) continue - raw = chunk[6:].strip() + raw = chunk[6:].rstrip() if raw == b"[DONE]": - log.debug(f"[GEMINI_TO_ANTHROPIC] Received [DONE] marker") + if debug_logging_enabled: + log.debug("[GEMINI_TO_ANTHROPIC] Received [DONE] marker") break - log.debug(f"[GEMINI_TO_ANTHROPIC] Parsing JSON: {raw[:200]}") + if debug_logging_enabled: + log.debug(f"[GEMINI_TO_ANTHROPIC] Parsing JSON: {raw[:200]}") try: - data = json.loads(raw.decode("utf-8", errors="ignore")) + data = json.loads(raw) + except Exception as e: + try: + data = json.loads(raw.decode("utf-8", errors="ignore")) + except Exception: + log.warning(f"[GEMINI_TO_ANTHROPIC] JSON parse error: {e}") + continue + + if debug_logging_enabled: log.debug( f"[GEMINI_TO_ANTHROPIC] Parsed data: {json.dumps(data, ensure_ascii=False)[:300]}" ) - except Exception as e: - log.warning(f"[GEMINI_TO_ANTHROPIC] JSON parse error: {e}") - continue # 处理 GeminiCLI 的 response 包装格式 if "response" in data: diff --git a/src/credential_manager.py b/src/credential_manager.py index 861c09418..99ba11b09 100644 --- a/src/credential_manager.py +++ b/src/credential_manager.py @@ -7,11 +7,13 @@ from datetime import datetime, timezone from typing import Any, Dict, List, Optional, Tuple +from config import get_ff_preview_credential_scheduler_v2 from log import log from src.google_oauth_api import Credentials from src.storage_adapter import get_storage_adapter + class CredentialManager: """ 统一凭证管理器 @@ -22,6 +24,12 @@ def __init__(self): # 核心状态 self._initialized = False self._storage_adapter = None + self._inflight_lock = asyncio.Lock() + self._inflight_counts: Dict[str, Dict[str, Dict[str, float]]] = { + "geminicli": {}, + "antigravity": {}, + } + self._inflight_ttl_seconds = 180.0 # 并发控制(简化) # 后端数据库自行处理并发,credential_manager 不再使用本地锁 @@ -66,10 +74,27 @@ async def get_valid_credential( # 最多重试3次 max_retries = 3 + scheduler_v2_enabled = await get_ff_preview_credential_scheduler_v2() for attempt in range(max_retries): - result = await self._storage_adapter._backend.get_next_available_credential( - mode=mode, model_name=model_name - ) + try: + if scheduler_v2_enabled: + in_flight_snapshot = await self._get_inflight_snapshot(mode) + result = await self._storage_adapter._backend.get_next_available_credential( + mode=mode, + model_name=model_name, + scheduling_hints={"in_flight": in_flight_snapshot}, + ) + else: + result = await self._storage_adapter._backend.get_next_available_credential( + mode=mode, model_name=model_name + ) + except TypeError: + # 兼容旧后端签名 + result = ( + await self._storage_adapter._backend.get_next_available_credential( + mode=mode, model_name=model_name + ) + ) # 如果没有可用凭证,直接返回None if not result: @@ -78,11 +103,14 @@ async def get_valid_credential( return None filename, credential_data = result + await self._acquire_inflight(mode, filename) # Token 刷新检查 if await self._should_refresh_token(credential_data): log.debug(f"Token需要刷新 - 文件: {filename} (mode={mode})") - refreshed_data = await self._refresh_token(credential_data, filename, mode=mode) + refreshed_data = await self._refresh_token( + credential_data, filename, mode=mode + ) if refreshed_data: # 刷新成功,返回凭证 credential_data = refreshed_data @@ -90,7 +118,10 @@ async def get_valid_credential( return filename, credential_data else: # 刷新失败(_refresh_token内部已自动禁用失效凭证) - log.warning(f"Token刷新失败,尝试获取下一个凭证: {filename} (mode={mode}, attempt={attempt+1}/{max_retries})") + log.warning( + f"Token刷新失败,尝试获取下一个凭证: {filename} (mode={mode}, attempt={attempt + 1}/{max_retries})" + ) + await self._release_inflight(mode, filename) # 继续循环,尝试获取下一个可用凭证 continue else: @@ -98,10 +129,14 @@ async def get_valid_credential( return filename, credential_data # 重试次数用尽 - log.error(f"重试{max_retries}次后仍无可用凭证 (mode={mode}, model_name={model_name})") + log.error( + f"重试{max_retries}次后仍无可用凭证 (mode={mode}, model_name={model_name})" + ) return None - async def add_credential(self, credential_name: str, credential_data: Dict[str, Any]): + async def add_credential( + self, credential_name: str, credential_data: Dict[str, Any] + ): """ 新增或更新一个凭证 存储层会自动处理轮换顺序 @@ -110,16 +145,22 @@ async def add_credential(self, credential_name: str, credential_data: Dict[str, await self._storage_adapter.store_credential(credential_name, credential_data) log.info(f"Credential added/updated: {credential_name}") - async def add_antigravity_credential(self, credential_name: str, credential_data: Dict[str, Any]): + async def add_antigravity_credential( + self, credential_name: str, credential_data: Dict[str, Any] + ): """ 新增或更新一个Antigravity凭证 存储层会自动处理轮换顺序 """ await self._ensure_initialized() - await self._storage_adapter.store_credential(credential_name, credential_data, mode="antigravity") + await self._storage_adapter.store_credential( + credential_name, credential_data, mode="antigravity" + ) log.info(f"Antigravity credential added/updated: {credential_name}") - async def remove_credential(self, credential_name: str, mode: str = "geminicli") -> bool: + async def remove_credential( + self, credential_name: str, mode: str = "geminicli" + ) -> bool: """删除一个凭证""" await self._ensure_initialized() try: @@ -130,9 +171,16 @@ async def remove_credential(self, credential_name: str, mode: str = "geminicli") log.error(f"Error removing credential {credential_name}: {e}") return False - async def update_credential_state(self, credential_name: str, state_updates: Dict[str, Any], mode: str = "geminicli"): + async def update_credential_state( + self, + credential_name: str, + state_updates: Dict[str, Any], + mode: str = "geminicli", + ): """更新凭证状态""" - log.debug(f"[CredMgr] update_credential_state 开始: credential_name={credential_name}, state_updates={state_updates}, mode={mode}") + log.debug( + f"[CredMgr] update_credential_state 开始: credential_name={credential_name}, state_updates={state_updates}, mode={mode}" + ) log.debug(f"[CredMgr] 调用 _ensure_initialized...") await self._ensure_initialized() log.debug(f"[CredMgr] _ensure_initialized 完成") @@ -141,20 +189,28 @@ async def update_credential_state(self, credential_name: str, state_updates: Dic success = await self._storage_adapter.update_credential_state( credential_name, state_updates, mode=mode ) - log.debug(f"[CredMgr] storage_adapter.update_credential_state 返回: {success}") + log.debug( + f"[CredMgr] storage_adapter.update_credential_state 返回: {success}" + ) if success: log.debug(f"Updated credential state: {credential_name} (mode={mode})") else: - log.warning(f"Failed to update credential state: {credential_name} (mode={mode})") + log.warning( + f"Failed to update credential state: {credential_name} (mode={mode})" + ) return success except Exception as e: log.error(f"Error updating credential state {credential_name}: {e}") return False - async def set_cred_disabled(self, credential_name: str, disabled: bool, mode: str = "geminicli"): + async def set_cred_disabled( + self, credential_name: str, disabled: bool, mode: str = "geminicli" + ): """设置凭证的启用/禁用状态""" try: - log.info(f"[CredMgr] set_cred_disabled 开始: credential_name={credential_name}, disabled={disabled}, mode={mode}") + log.info( + f"[CredMgr] set_cred_disabled 开始: credential_name={credential_name}, disabled={disabled}, mode={mode}" + ) success = await self.update_credential_state( credential_name, {"disabled": disabled}, mode=mode ) @@ -163,7 +219,9 @@ async def set_cred_disabled(self, credential_name: str, disabled: bool, mode: st action = "disabled" if disabled else "enabled" log.info(f"Credential {action}: {credential_name} (mode={mode})") else: - log.warning(f"[CredMgr] 设置禁用状态失败: credential_name={credential_name}, disabled={disabled}") + log.warning( + f"[CredMgr] 设置禁用状态失败: credential_name={credential_name}, disabled={disabled}" + ) return success except Exception as e: log.error(f"Error setting credential disabled state {credential_name}: {e}") @@ -190,21 +248,27 @@ async def get_creds_summary(self) -> List[Dict[str, Any]]: log.error(f"Error getting credentials summary: {e}") return [] - async def get_or_fetch_user_email(self, credential_name: str, mode: str = "geminicli") -> Optional[str]: + async def get_or_fetch_user_email( + self, credential_name: str, mode: str = "geminicli" + ) -> Optional[str]: """获取或获取用户邮箱地址""" try: # 确保已初始化 await self._ensure_initialized() - + # 从状态中获取缓存的邮箱 - state = await self._storage_adapter.get_credential_state(credential_name, mode=mode) + state = await self._storage_adapter.get_credential_state( + credential_name, mode=mode + ) cached_email = state.get("user_email") if state else None if cached_email: return cached_email # 如果没有缓存,从凭证数据获取 - credential_data = await self._storage_adapter.get_credential(credential_name, mode=mode) + credential_data = await self._storage_adapter.get_credential( + credential_name, mode=mode + ) if not credential_data: return None @@ -222,7 +286,9 @@ async def get_or_fetch_user_email(self, credential_name: str, mode: str = "gemin if token_refreshed: log.info(f"Token已自动刷新: {credential_name} (mode={mode})") updated_data = credentials.to_dict() - await self._storage_adapter.store_credential(credential_name, updated_data, mode=mode) + await self._storage_adapter.store_credential( + credential_name, updated_data, mode=mode + ) # 获取邮箱 email = await get_user_email(credentials) @@ -248,7 +314,7 @@ async def record_api_call_result( cooldown_until: Optional[float] = None, mode: str = "geminicli", model_name: Optional[str] = None, - error_message: Optional[str] = None + error_message: Optional[str] = None, ): """ 记录API调用结果 @@ -263,10 +329,11 @@ async def record_api_call_result( error_message: 错误信息(如果失败) """ await self._ensure_initialized() + await self._release_inflight(mode, credential_name) try: if success: - # 条件写入:仅当凭证有错误状态或模型冷却时才写 DB,零内存缓存 - # fire-and-forget,不阻塞请求链路 + # 条件写入:仅当凭证有错误状态或模型冷却时才写 DB,零内存缓存 + # fire-and-forget,不阻塞请求链路 asyncio.create_task( self._storage_adapter._backend.record_success( credential_name, model_name=model_name, mode=mode @@ -284,11 +351,13 @@ async def record_api_call_result( "error_messages": error_messages, } - await self.update_credential_state(credential_name, state_updates, mode=mode) + await self.update_credential_state( + credential_name, state_updates, mode=mode + ) # 设置模型级冷却 if cooldown_until is not None and model_name: - if hasattr(self._storage_adapter._backend, 'set_model_cooldown'): + if hasattr(self._storage_adapter._backend, "set_model_cooldown"): await self._storage_adapter._backend.set_model_cooldown( credential_name, model_name, cooldown_until, mode=mode ) @@ -300,11 +369,58 @@ async def record_api_call_result( except Exception as e: log.error(f"Error recording API call result for {credential_name}: {e}") + async def _acquire_inflight(self, mode: str, credential_name: str) -> None: + name = credential_name.rsplit("/", 1)[-1] + async with self._inflight_lock: + bucket = self._inflight_counts.setdefault(mode, {}) + now = time.time() + self._prune_stale_inflight_locked(bucket, now) + state = bucket.setdefault(name, {"count": 0.0, "last_touch": now}) + state["count"] = float(state.get("count", 0.0)) + 1.0 + state["last_touch"] = now + + async def _release_inflight(self, mode: str, credential_name: str) -> None: + name = credential_name.rsplit("/", 1)[-1] + async with self._inflight_lock: + bucket = self._inflight_counts.setdefault(mode, {}) + now = time.time() + self._prune_stale_inflight_locked(bucket, now) + state = bucket.get(name) + if not state: + return + next_count = float(state.get("count", 0.0)) - 1.0 + if next_count <= 0: + bucket.pop(name, None) + return + state["count"] = next_count + state["last_touch"] = now + + async def _get_inflight_snapshot(self, mode: str) -> Dict[str, int]: + async with self._inflight_lock: + bucket = self._inflight_counts.setdefault(mode, {}) + now = time.time() + self._prune_stale_inflight_locked(bucket, now) + return { + filename: int(max(0.0, float(state.get("count", 0.0)))) + for filename, state in bucket.items() + } + + def _prune_stale_inflight_locked( + self, bucket: Dict[str, Dict[str, float]], now: float + ) -> None: + for filename in list(bucket.keys()): + state = bucket[filename] + last_touch = float(state.get("last_touch", now)) + if now - last_touch >= self._inflight_ttl_seconds: + bucket.pop(filename, None) + async def _should_refresh_token(self, credential_data: Dict[str, Any]) -> bool: """检查token是否需要刷新""" try: # 如果没有access_token或过期时间,需要刷新 - if not credential_data.get("access_token") and not credential_data.get("token"): + if not credential_data.get("access_token") and not credential_data.get( + "token" + ): log.debug("没有access_token,需要刷新") return True @@ -319,7 +435,9 @@ async def _should_refresh_token(self, credential_data: Dict[str, Any]) -> bool: if "+" in expiry_str: file_expiry = datetime.fromisoformat(expiry_str) elif expiry_str.endswith("Z"): - file_expiry = datetime.fromisoformat(expiry_str.replace("Z", "+00:00")) + file_expiry = datetime.fromisoformat( + expiry_str.replace("Z", "+00:00") + ) else: file_expiry = datetime.fromisoformat(expiry_str) else: @@ -338,13 +456,15 @@ async def _should_refresh_token(self, credential_data: Dict[str, Any]) -> bool: f"Token时间检查: " f"当前UTC时间={now.isoformat()}, " f"过期时间={file_expiry.isoformat()}, " - f"剩余时间={int(time_left/60)}分{int(time_left%60)}秒" + f"剩余时间={int(time_left / 60)}分{int(time_left % 60)}秒" ) if time_left > 120: # 2分钟缓冲 return False else: - log.debug(f"Token即将过期(剩余{int(time_left/60)}分钟),需要刷新") + log.debug( + f"Token即将过期(剩余{int(time_left / 60)}分钟),需要刷新" + ) return True except Exception as e: @@ -369,7 +489,9 @@ async def _refresh_token( log.error(f"没有refresh_token,无法刷新: {filename} (mode={mode})") # 自动禁用没有refresh_token的凭证 try: - await self.update_credential_state(filename, {"disabled": True}, mode=mode) + await self.update_credential_state( + filename, {"disabled": True}, mode=mode + ) log.warning(f"凭证已自动禁用(缺少refresh_token): {filename}") except Exception as e: log.error(f"禁用凭证失败 {filename}: {e}") @@ -389,7 +511,9 @@ async def _refresh_token( credential_data["expiry"] = creds.expires_at.isoformat() # 保存到存储 - await self._storage_adapter.store_credential(filename, credential_data, mode=mode) + await self._storage_adapter.store_credential( + filename, credential_data, mode=mode + ) log.info(f"Token刷新成功并已保存: {filename} (mode={mode})") return credential_data @@ -400,24 +524,30 @@ async def _refresh_token( # 尝试提取HTTP状态码(TokenError可能携带status_code属性) status_code = None - if hasattr(e, 'status_code'): + if hasattr(e, "status_code"): status_code = e.status_code # 检查是否是凭证永久失效的错误(只有明确的400/403等才判定为永久失效) - is_permanent_failure = self._is_permanent_refresh_failure(error_msg, status_code) + is_permanent_failure = self._is_permanent_refresh_failure( + error_msg, status_code + ) if is_permanent_failure: log.warning(f"检测到凭证永久失效 (HTTP {status_code}): {filename}") # 记录失效状态 if status_code: - await self.record_api_call_result(filename, False, status_code, mode=mode) + await self.record_api_call_result( + filename, False, status_code, mode=mode + ) else: await self.record_api_call_result(filename, False, 400, mode=mode) # 禁用失效凭证 try: # 直接禁用该凭证(随机选择机制会自动跳过它) - disabled_ok = await self.update_credential_state(filename, {"disabled": True}, mode=mode) + disabled_ok = await self.update_credential_state( + filename, {"disabled": True}, mode=mode + ) if disabled_ok: log.warning(f"永久失效凭证已禁用: {filename}") else: @@ -426,11 +556,15 @@ async def _refresh_token( log.error(f"禁用永久失效凭证时出错 {filename}: {e2}") else: # 网络错误或其他临时性错误,不封禁凭证 - log.warning(f"Token刷新失败但非永久性错误 (HTTP {status_code}),不封禁凭证: {filename}") + log.warning( + f"Token刷新失败但非永久性错误 (HTTP {status_code}),不封禁凭证: {filename}" + ) return None - def _is_permanent_refresh_failure(self, error_msg: str, status_code: Optional[int] = None) -> bool: + def _is_permanent_refresh_failure( + self, error_msg: str, status_code: Optional[int] = None + ) -> bool: """ 判断是否是凭证永久失效的错误 @@ -476,6 +610,7 @@ def _is_permanent_refresh_failure(self, error_msg: str, status_code: Optional[in log.debug("未匹配到明确的永久失效模式,判定为临时错误") return False + class _CredentialManagerSingleton: """单例包装器,支持懒加载和自动初始化""" @@ -498,6 +633,7 @@ async def _get_or_create(self) -> CredentialManager: def __getattr__(self, name): """代理所有方法调用到真实的 CredentialManager 实例""" + async def _async_wrapper(*args, **kwargs): manager = await self._get_or_create() method = getattr(manager, name) diff --git a/src/httpx_client.py b/src/httpx_client.py index 521fb0865..10b1570e4 100644 --- a/src/httpx_client.py +++ b/src/httpx_client.py @@ -4,19 +4,128 @@ 保持通用性,不与特定业务逻辑耦合 """ +import asyncio from contextlib import asynccontextmanager import json as jsonlib from typing import Any, AsyncGenerator, Dict, Optional import httpx -from config import get_proxy_config +from config import get_ff_http2_pool_tuning, get_proxy_config from log import log class HttpxClientManager: """通用HTTP客户端管理器(启用 HTTP/2 以匹配 Google API 预期)""" + def __init__(self) -> None: + self._request_client: Optional[httpx.AsyncClient] = None + self._streaming_client: Optional[httpx.AsyncClient] = None + self._active_proxy_config: Optional[str] = None + self._client_lock = asyncio.Lock() + + @staticmethod + def _build_limits(streaming: bool) -> httpx.Limits: + if streaming: + return httpx.Limits( + max_connections=120, + max_keepalive_connections=40, + keepalive_expiry=90.0, + ) + return httpx.Limits( + max_connections=80, + max_keepalive_connections=20, + keepalive_expiry=45.0, + ) + + @staticmethod + def _build_default_timeout(streaming: bool) -> httpx.Timeout: + if streaming: + return httpx.Timeout(connect=20.0, read=None, write=30.0, pool=20.0) + return httpx.Timeout(connect=20.0, read=120.0, write=30.0, pool=20.0) + + async def _close_client_safely(self, client: Optional[httpx.AsyncClient]) -> None: + if client is None: + return + try: + await client.aclose() + except Exception as e: + log.warning(f"Error closing httpx client: {e}") + + async def _create_pooled_client( + self, + *, + proxy_config: Optional[str], + streaming: bool, + ) -> httpx.AsyncClient: + client_kwargs = { + "http2": True, + "limits": self._build_limits(streaming), + "timeout": self._build_default_timeout(streaming), + } + if proxy_config: + client_kwargs["proxy"] = proxy_config + return httpx.AsyncClient(**client_kwargs) + + async def _get_or_create_pooled_client( + self, *, streaming: bool + ) -> httpx.AsyncClient: + current_proxy_config = await get_proxy_config() + + stale_request_client: Optional[httpx.AsyncClient] = None + stale_streaming_client: Optional[httpx.AsyncClient] = None + + async with self._client_lock: + proxy_changed = current_proxy_config != self._active_proxy_config + if proxy_changed: + stale_request_client = self._request_client + stale_streaming_client = self._streaming_client + self._request_client = None + self._streaming_client = None + self._active_proxy_config = current_proxy_config + + if self._active_proxy_config is None and current_proxy_config is not None: + self._active_proxy_config = current_proxy_config + + if streaming: + if self._streaming_client is None: + self._streaming_client = await self._create_pooled_client( + proxy_config=current_proxy_config, + streaming=True, + ) + selected_client = self._streaming_client + else: + if self._request_client is None: + self._request_client = await self._create_pooled_client( + proxy_config=current_proxy_config, + streaming=False, + ) + selected_client = self._request_client + + await self._close_client_safely(stale_request_client) + await self._close_client_safely(stale_streaming_client) + + return selected_client + + async def close_all_clients(self) -> None: + stale_request_client: Optional[httpx.AsyncClient] = None + stale_streaming_client: Optional[httpx.AsyncClient] = None + async with self._client_lock: + stale_request_client = self._request_client + stale_streaming_client = self._streaming_client + self._request_client = None + self._streaming_client = None + self._active_proxy_config = None + + await self._close_client_safely(stale_request_client) + await self._close_client_safely(stale_streaming_client) + + @staticmethod + def _should_use_oneoff_client( + timeout: Optional[Any], kwargs: Dict[str, Any] + ) -> bool: + return timeout is not None or bool(kwargs) + async def get_client_kwargs( self, timeout: Optional[float] = 30.0, **kwargs ) -> Dict[str, Any]: @@ -37,31 +146,49 @@ async def get_client_kwargs( @asynccontextmanager async def get_client( - self, timeout: float = 30.0, **kwargs + self, timeout: Optional[Any] = None, **kwargs ) -> AsyncGenerator[httpx.AsyncClient, None]: """获取配置好的异步HTTP客户端""" - client_kwargs = await self.get_client_kwargs(timeout=timeout, **kwargs) + if not await get_ff_http2_pool_tuning(): + client_kwargs = await self.get_client_kwargs(timeout=timeout, **kwargs) + async with httpx.AsyncClient(**client_kwargs) as client: + yield client + return - async with httpx.AsyncClient(**client_kwargs) as client: - yield client + if self._should_use_oneoff_client(timeout=timeout, kwargs=kwargs): + client_kwargs = await self.get_client_kwargs(timeout=timeout, **kwargs) + async with httpx.AsyncClient(**client_kwargs) as client: + yield client + return + + client = await self._get_or_create_pooled_client(streaming=False) + yield client @asynccontextmanager async def get_streaming_client( - self, timeout: Optional[float] = None, **kwargs + self, timeout: Optional[Any] = None, **kwargs ) -> AsyncGenerator[httpx.AsyncClient, None]: """获取用于流式请求的HTTP客户端(无超时限制)""" - client_kwargs = await self.get_client_kwargs(timeout=timeout, **kwargs) + if not await get_ff_http2_pool_tuning(): + client_kwargs = await self.get_client_kwargs(timeout=timeout, **kwargs) + client = httpx.AsyncClient(**client_kwargs) + try: + yield client + finally: + await self._close_client_safely(client) + return - # 创建独立的客户端实例用于流式处理 - client = httpx.AsyncClient(**client_kwargs) - try: - yield client - finally: - # 确保无论发生什么都关闭客户端 + if self._should_use_oneoff_client(timeout=timeout, kwargs=kwargs): + client_kwargs = await self.get_client_kwargs(timeout=timeout, **kwargs) + client = httpx.AsyncClient(**client_kwargs) try: - await client.aclose() - except Exception as e: - log.warning(f"Error closing streaming client: {e}") + yield client + finally: + await self._close_client_safely(client) + return + + client = await self._get_or_create_pooled_client(streaming=True) + yield client # 全局HTTP客户端管理器实例 @@ -88,11 +215,15 @@ def _encode_json_body_safely(payload: Any) -> bytes: # 通用的异步方法 async def get_async( - url: str, headers: Optional[Dict[str, str]] = None, timeout: float = 30.0, **kwargs + url: str, + headers: Optional[Dict[str, str]] = None, + timeout: Optional[Any] = 30.0, + **kwargs, ) -> httpx.Response: """通用异步GET请求""" - async with http_client.get_client(timeout=timeout, **kwargs) as client: - return await client.get(url, headers=headers) + request_timeout = kwargs.pop("request_timeout", timeout) + async with http_client.get_client(**kwargs) as client: + return await client.get(url, headers=headers, timeout=request_timeout) async def post_async( @@ -100,11 +231,12 @@ async def post_async( data: Any = None, json: Any = None, headers: Optional[Dict[str, str]] = None, - timeout: float = 600.0, + timeout: Optional[Any] = 600.0, **kwargs, ) -> httpx.Response: """通用异步POST请求""" - async with http_client.get_client(timeout=timeout, **kwargs) as client: + request_timeout = kwargs.pop("request_timeout", timeout) + async with http_client.get_client(**kwargs) as client: if json is not None and data is None: request_headers = dict(headers or {}) request_headers.setdefault("Content-Type", "application/json") @@ -113,8 +245,15 @@ async def post_async( url, content=safe_json_bytes, headers=request_headers, + timeout=request_timeout, ) - return await client.post(url, data=data, json=json, headers=headers) + return await client.post( + url, + data=data, + json=json, + headers=headers, + timeout=request_timeout, + ) async def stream_post_async( @@ -125,12 +264,17 @@ async def stream_post_async( **kwargs, ): """流式异步POST请求""" + request_timeout = kwargs.pop("timeout", None) async with http_client.get_streaming_client(**kwargs) as client: request_headers = dict(headers or {}) request_headers.setdefault("Content-Type", "application/json") safe_json_bytes = _encode_json_body_safely(body) async with client.stream( - "POST", url, content=safe_json_bytes, headers=request_headers + "POST", + url, + content=safe_json_bytes, + headers=request_headers, + timeout=request_timeout, ) as r: # 错误直接返回 if r.status_code != 200: diff --git a/src/panel/config_routes.py b/src/panel/config_routes.py index 5c090fa7b..93dfa1756 100644 --- a/src/panel/config_routes.py +++ b/src/panel/config_routes.py @@ -22,8 +22,6 @@ async def get_config(token: str = Depends(verify_panel_token)): """获取当前配置""" try: - - # 读取当前配置(包括环境变量和TOML文件中的配置) current_config = {} @@ -35,8 +33,12 @@ async def get_config(token: str = Depends(verify_panel_token)): # 代理端点配置 current_config["oauth_proxy_url"] = await config.get_oauth_proxy_url() current_config["googleapis_proxy_url"] = await config.get_googleapis_proxy_url() - current_config["resource_manager_api_url"] = await config.get_resource_manager_api_url() - current_config["service_usage_api_url"] = await config.get_service_usage_api_url() + current_config[ + "resource_manager_api_url" + ] = await config.get_resource_manager_api_url() + current_config[ + "service_usage_api_url" + ] = await config.get_service_usage_api_url() current_config["antigravity_api_url"] = await config.get_antigravity_api_url() # 自动封禁配置 @@ -44,20 +46,52 @@ async def get_config(token: str = Depends(verify_panel_token)): current_config["auto_ban_error_codes"] = await config.get_auto_ban_error_codes() # 429重试配置 - current_config["retry_429_max_retries"] = await config.get_retry_429_max_retries() + current_config[ + "retry_429_max_retries" + ] = await config.get_retry_429_max_retries() current_config["retry_429_enabled"] = await config.get_retry_429_enabled() current_config["retry_429_interval"] = await config.get_retry_429_interval() # 抗截断配置 - current_config["anti_truncation_max_attempts"] = await config.get_anti_truncation_max_attempts() + current_config[ + "anti_truncation_max_attempts" + ] = await config.get_anti_truncation_max_attempts() # 兼容性配置 - current_config["compatibility_mode_enabled"] = await config.get_compatibility_mode_enabled() + current_config[ + "compatibility_mode_enabled" + ] = await config.get_compatibility_mode_enabled() # 思维链返回配置 - current_config["return_thoughts_to_frontend"] = await config.get_return_thoughts_to_frontend() + current_config[ + "return_thoughts_to_frontend" + ] = await config.get_return_thoughts_to_frontend() # Antigravity流式转非流式配置 - current_config["antigravity_stream2nostream"] = await config.get_antigravity_stream2nostream() + current_config[ + "antigravity_stream2nostream" + ] = await config.get_antigravity_stream2nostream() + + # 灰度发布与回滚门禁配置 + current_config["ff_retry_policy_v2"] = await config.get_ff_retry_policy_v2() + current_config["ff_http2_pool_tuning"] = await config.get_ff_http2_pool_tuning() + current_config[ + "ff_converter_fast_path" + ] = await config.get_ff_converter_fast_path() + current_config[ + "ff_preview_credential_scheduler_v2" + ] = await config.get_ff_preview_credential_scheduler_v2() + current_config[ + "rollout_stage_percent" + ] = await config.get_rollout_stage_percent() + current_config[ + "rollback_trigger_latency_p95_ms" + ] = await config.get_rollback_trigger_latency_p95_ms() + current_config[ + "rollback_trigger_throughput_drop_pct" + ] = await config.get_rollback_trigger_throughput_drop_pct() + current_config[ + "rollback_trigger_quality_drop_pct" + ] = await config.get_rollback_trigger_quality_drop_pct() # 保活配置 current_config["keepalive_url"] = await config.get_keepalive_url() @@ -82,7 +116,9 @@ async def get_config(token: str = Depends(verify_panel_token)): if key not in env_locked_keys: current_config[key] = value - return JSONResponse(content={"config": current_config, "env_locked": list(env_locked_keys)}) + return JSONResponse( + content={"config": current_config, "env_locked": list(env_locked_keys)} + ) except Exception as e: log.error(f"获取配置失败: {e}") @@ -90,10 +126,11 @@ async def get_config(token: str = Depends(verify_panel_token)): @router.post("/save") -async def save_config(request: ConfigSaveRequest, token: str = Depends(verify_panel_token)): +async def save_config( + request: ConfigSaveRequest, token: str = Depends(verify_panel_token) +): """保存配置""" try: - new_config = request.config log.debug(f"收到的配置数据: {list(new_config.keys())}") @@ -105,7 +142,9 @@ async def save_config(request: ConfigSaveRequest, token: str = Depends(verify_pa not isinstance(new_config["retry_429_max_retries"], int) or new_config["retry_429_max_retries"] < 0 ): - raise HTTPException(status_code=400, detail="最大429重试次数必须是大于等于0的整数") + raise HTTPException( + status_code=400, detail="最大429重试次数必须是大于等于0的整数" + ) if "retry_429_enabled" in new_config: if not isinstance(new_config["retry_429_enabled"], bool): @@ -116,9 +155,13 @@ async def save_config(request: ConfigSaveRequest, token: str = Depends(verify_pa try: interval = float(new_config["retry_429_interval"]) if interval < 0.01 or interval > 10: - raise HTTPException(status_code=400, detail="429重试间隔必须在0.01-10秒之间") + raise HTTPException( + status_code=400, detail="429重试间隔必须在0.01-10秒之间" + ) except (ValueError, TypeError): - raise HTTPException(status_code=400, detail="429重试间隔必须是有效的数字") + raise HTTPException( + status_code=400, detail="429重试间隔必须是有效的数字" + ) if "anti_truncation_max_attempts" in new_config: if ( @@ -132,15 +175,87 @@ async def save_config(request: ConfigSaveRequest, token: str = Depends(verify_pa if "compatibility_mode_enabled" in new_config: if not isinstance(new_config["compatibility_mode_enabled"], bool): - raise HTTPException(status_code=400, detail="兼容性模式开关必须是布尔值") + raise HTTPException( + status_code=400, detail="兼容性模式开关必须是布尔值" + ) if "return_thoughts_to_frontend" in new_config: if not isinstance(new_config["return_thoughts_to_frontend"], bool): - raise HTTPException(status_code=400, detail="思维链返回开关必须是布尔值") + raise HTTPException( + status_code=400, detail="思维链返回开关必须是布尔值" + ) if "antigravity_stream2nostream" in new_config: if not isinstance(new_config["antigravity_stream2nostream"], bool): - raise HTTPException(status_code=400, detail="Antigravity流式转非流式开关必须是布尔值") + raise HTTPException( + status_code=400, detail="Antigravity流式转非流式开关必须是布尔值" + ) + + feature_flag_names = [ + "ff_retry_policy_v2", + "ff_http2_pool_tuning", + "ff_converter_fast_path", + "ff_preview_credential_scheduler_v2", + ] + for feature_flag_name in feature_flag_names: + if feature_flag_name in new_config and not isinstance( + new_config[feature_flag_name], bool + ): + raise HTTPException( + status_code=400, detail=f"{feature_flag_name} 必须是布尔值" + ) + + if "rollout_stage_percent" in new_config: + try: + rollout_stage = int(new_config["rollout_stage_percent"]) + except (TypeError, ValueError): + raise HTTPException(status_code=400, detail="灰度比例必须是有效整数") + if rollout_stage not in (5, 20, 50, 100): + raise HTTPException( + status_code=400, detail="灰度比例仅支持 5/20/50/100" + ) + new_config["rollout_stage_percent"] = rollout_stage + + if "rollback_trigger_latency_p95_ms" in new_config: + try: + latency_p95_ms = float(new_config["rollback_trigger_latency_p95_ms"]) + except (TypeError, ValueError): + raise HTTPException(status_code=400, detail="延迟阈值必须是有效数字") + if latency_p95_ms < 0 or latency_p95_ms > 60000: + raise HTTPException( + status_code=400, detail="延迟阈值必须在 0-60000 ms 之间" + ) + new_config["rollback_trigger_latency_p95_ms"] = latency_p95_ms + + if "rollback_trigger_throughput_drop_pct" in new_config: + try: + throughput_drop_pct = float( + new_config["rollback_trigger_throughput_drop_pct"] + ) + except (TypeError, ValueError): + raise HTTPException( + status_code=400, detail="吞吐下降阈值必须是有效数字" + ) + if throughput_drop_pct < 0 or throughput_drop_pct > 100: + raise HTTPException( + status_code=400, detail="吞吐下降阈值必须在 0-100 之间" + ) + new_config["rollback_trigger_throughput_drop_pct"] = throughput_drop_pct + + if "rollback_trigger_quality_drop_pct" in new_config: + try: + quality_drop_pct = float( + new_config["rollback_trigger_quality_drop_pct"] + ) + except (TypeError, ValueError): + raise HTTPException( + status_code=400, detail="质量下降阈值必须是有效数字" + ) + if quality_drop_pct < 0 or quality_drop_pct > 100: + raise HTTPException( + status_code=400, detail="质量下降阈值必须在 0-100 之间" + ) + new_config["rollback_trigger_quality_drop_pct"] = quality_drop_pct # 验证保活配置 if "keepalive_url" in new_config: @@ -151,13 +266,18 @@ async def save_config(request: ConfigSaveRequest, token: str = Depends(verify_pa try: interval = int(new_config["keepalive_interval"]) if interval < 5 or interval > 86400: - raise HTTPException(status_code=400, detail="保活间隔必须在 5-86400 秒之间") + raise HTTPException( + status_code=400, detail="保活间隔必须在 5-86400 秒之间" + ) new_config["keepalive_interval"] = interval except (ValueError, TypeError): raise HTTPException(status_code=400, detail="保活间隔必须是有效整数") # 验证服务器配置 if "host" in new_config: - if not isinstance(new_config["host"], str) or not new_config["host"].strip(): + if ( + not isinstance(new_config["host"], str) + or not new_config["host"].strip() + ): raise HTTPException(status_code=400, detail="服务器主机地址不能为空") if "port" in new_config: @@ -166,7 +286,9 @@ async def save_config(request: ConfigSaveRequest, token: str = Depends(verify_pa or new_config["port"] < 1 or new_config["port"] > 65535 ): - raise HTTPException(status_code=400, detail="端口号必须是1-65535之间的整数") + raise HTTPException( + status_code=400, detail="端口号必须是1-65535之间的整数" + ) if "api_password" in new_config: if not isinstance(new_config["api_password"], str): @@ -213,7 +335,9 @@ async def save_config(request: ConfigSaveRequest, token: str = Depends(verify_pa # 构建响应消息 response_data = { "message": "配置保存成功", - "saved_config": {k: v for k, v in new_config.items() if k not in env_locked_keys}, + "saved_config": { + k: v for k, v in new_config.items() if k not in env_locked_keys + }, } return JSONResponse(content=response_data) diff --git a/src/router/geminicli/anthropic.py b/src/router/geminicli/anthropic.py index a01ee21b7..c54f7c321 100644 --- a/src/router/geminicli/anthropic.py +++ b/src/router/geminicli/anthropic.py @@ -3,6 +3,8 @@ 通过GeminiCLI处理Anthropic/Claude格式请求的路由模块 """ +# ruff: noqa: E402 + import sys from pathlib import Path @@ -14,6 +16,8 @@ # 标准库 import asyncio import json +import time +from typing import Any # 第三方库 from fastapi import APIRouter, Depends, HTTPException, Request @@ -58,10 +62,10 @@ # ==================== API 路由 ==================== + @router.post("/v1/messages") async def messages( - claude_request: ClaudeRequest, - token: str = Depends(authenticate_bearer) + claude_request: ClaudeRequest, token: str = Depends(authenticate_bearer) ): """ 处理Anthropic/Claude格式的消息请求(流式和非流式) @@ -97,25 +101,28 @@ async def messages( # 转换为 Gemini 格式 (使用 converter) from src.converter.anthropic2gemini import anthropic_to_gemini_request + + converter_started_at = time.perf_counter() gemini_dict = await anthropic_to_gemini_request(normalized_dict) + converter_cpu_ms = max(0.0, (time.perf_counter() - converter_started_at) * 1000.0) + converter_cpu_header_value = f"{converter_cpu_ms:.3f}" # anthropic_to_gemini_request 不包含 model 字段,需要手动添加 gemini_dict["model"] = real_model # 规范化 Gemini 请求 (使用 geminicli 模式) from src.converter.gemini_fix import normalize_gemini_request + gemini_dict = await normalize_gemini_request(gemini_dict, mode="geminicli") # 准备API请求格式 - 提取model并将其他字段放入request中 - api_request = { - "model": gemini_dict.pop("model"), - "request": gemini_dict - } + api_request = {"model": gemini_dict.pop("model"), "request": gemini_dict} # ========== 非流式请求 ========== if not is_streaming: # 调用 API 层的非流式请求 from src.api.geminicli import non_stream_request + response = await non_stream_request(body=api_request) # 检查响应状态码 @@ -123,9 +130,19 @@ async def messages( # 提取响应体 if hasattr(response, "body"): - response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + response_raw_body: Any = response.body + response_body = ( + bytes(response_raw_body).decode() + if isinstance(response_raw_body, (bytes, bytearray, memoryview)) + else str(response_raw_body) + ) elif hasattr(response, "content"): - response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + response_content: Any = getattr(response, "content", "") + response_body = ( + bytes(response_content).decode() + if isinstance(response_content, (bytes, bytearray, memoryview)) + else str(response_content) + ) else: response_body = str(response) @@ -137,13 +154,16 @@ async def messages( # 转换为 Anthropic 格式 from src.converter.anthropic2gemini import gemini_to_anthropic_response + anthropic_response = gemini_to_anthropic_response( - gemini_response, - real_model, - status_code + gemini_response, real_model, status_code ) - return JSONResponse(content=anthropic_response, status_code=status_code) + return JSONResponse( + content=anthropic_response, + status_code=status_code, + headers={"x-gcli-obs-converter-cpu-ms": converter_cpu_header_value}, + ) # ========== 流式请求 ========== @@ -156,11 +176,14 @@ async def fake_stream_generator(): # 异步发送实际请求 async def get_response(): from src.api.geminicli import non_stream_request + response = await non_stream_request(body=api_request) return response # 创建请求任务 - response_task = create_managed_task(get_response(), name="anthropic_fake_stream_request") + response_task = create_managed_task( + get_response(), name="anthropic_fake_stream_request" + ) try: # 每3秒发送一次心跳,直到收到响应 @@ -191,23 +214,32 @@ async def get_response(): # 检查响应状态码 if hasattr(response, "status_code") and response.status_code != 200: # 错误响应 - 提取错误信息并以SSE格式返回 - log.error(f"Fake streaming got error response: status={response.status_code}") + log.error( + f"Fake streaming got error response: status={response.status_code}" + ) raw = None if hasattr(response, "body") and response.body: - raw = response.body.decode('utf-8') if isinstance(response.body, bytes) else response.body + raw = ( + response.body.decode("utf-8") + if isinstance(response.body, bytes) + else response.body + ) elif hasattr(response, "content") and response.content: - raw = response.content.decode('utf-8') if isinstance(response.content, bytes) else response.content + raw = ( + response.content.decode("utf-8") + if isinstance(response.content, bytes) + else response.content + ) error_body = raw or "" try: error_data = json.loads(error_body) # 转换错误为 Anthropic 格式 from src.converter.anthropic2gemini import gemini_to_anthropic_response + anthropic_error = gemini_to_anthropic_response( - error_data, - real_model, - response.status_code + error_data, real_model, response.status_code ) yield f"data: {json.dumps(anthropic_error)}\n\n".encode() except Exception: @@ -218,9 +250,17 @@ async def get_response(): # 处理成功响应 - 提取响应内容 if hasattr(response, "body"): - response_body = response.body.decode() if isinstance(response.body, bytes) else response.body + response_body = ( + response.body.decode() + if isinstance(response.body, bytes) + else response.body + ) elif hasattr(response, "content"): - response_body = response.content.decode() if isinstance(response.content, bytes) else response.content + response_body = ( + response.content.decode() + if isinstance(response.content, bytes) + else response.content + ) else: response_body = str(response) @@ -230,30 +270,39 @@ async def get_response(): # 检查是否是错误响应(有些错误可能status_code是200但包含error字段) if "error" in gemini_response: - log.error(f"Fake streaming got error in response body: {gemini_response['error']}") + log.error( + f"Fake streaming got error in response body: {gemini_response['error']}" + ) # 转换错误为 Anthropic 格式 from src.converter.anthropic2gemini import gemini_to_anthropic_response + anthropic_error = gemini_to_anthropic_response( - gemini_response, - real_model, - 200 + gemini_response, real_model, 200 ) yield f"data: {json.dumps(anthropic_error)}\n\n".encode() yield "data: [DONE]\n\n".encode() return # 使用统一的解析函数 - content, reasoning_content, finish_reason, images = parse_response_for_fake_stream(gemini_response) + content, reasoning_content, finish_reason, images = ( + parse_response_for_fake_stream(gemini_response) + ) log.debug(f"Anthropic extracted content: {content}") - log.debug(f"Anthropic extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...") + log.debug( + f"Anthropic extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}..." + ) log.debug(f"Anthropic extracted images count: {len(images)}") # 构建响应块 - chunks = build_anthropic_fake_stream_chunks(content, reasoning_content, finish_reason, real_model, images) + chunks = build_anthropic_fake_stream_chunks( + content, reasoning_content, finish_reason, real_model, images + ) for idx, chunk in enumerate(chunks): chunk_json = json.dumps(chunk) - log.debug(f"[FAKE_STREAM] Yielding chunk #{idx+1}: {chunk_json[:200]}") + log.debug( + f"[FAKE_STREAM] Yielding chunk #{idx + 1}: {chunk_json[:200]}" + ) yield f"data: {chunk_json}\n\n".encode() except Exception as e: @@ -261,10 +310,7 @@ async def get_response(): # 构建错误响应 error_chunk = { "type": "error", - "error": { - "type": "api_error", - "message": str(e) - } + "error": {"type": "api_error", "message": str(e)}, } yield f"data: {json.dumps(error_chunk)}\n\n".encode() @@ -286,28 +332,24 @@ async def anti_truncation_generator(): async def stream_request_wrapper(payload): # stream_request 返回异步生成器,需要包装成 StreamingResponse stream_gen = stream_request(body=payload, native=False) - return StreamingResponse(stream_gen, media_type="text/event-stream") + return StreamingResponse(stream_gen, media_type="text/event-stream") # type: ignore[arg-type] # 创建反截断处理器 processor = AntiTruncationStreamProcessor( - stream_request_wrapper, - anti_truncation_payload, - max_attempts + stream_request_wrapper, anti_truncation_payload, max_attempts ) # 包装以确保是bytes流 async def bytes_wrapper(): async for chunk in processor.process_stream(): if isinstance(chunk, str): - yield chunk.encode('utf-8') + yield chunk.encode("utf-8") else: yield chunk # 直接将整个流传递给转换器 async for anthropic_chunk in gemini_stream_to_anthropic_stream( - bytes_wrapper(), - real_model, - 200 + bytes_wrapper(), real_model, 200 ): if anthropic_chunk: yield anthropic_chunk @@ -328,57 +370,68 @@ async def gemini_chunk_wrapper(): if isinstance(chunk, Response): # 错误响应,不进行转换,直接传递 try: - error_content = chunk.body if isinstance(chunk.body, bytes) else (chunk.body or b'').encode('utf-8') - gemini_error = json.loads(error_content.decode('utf-8')) - from src.converter.anthropic2gemini import gemini_to_anthropic_response + raw_body = chunk.body + if isinstance(raw_body, (bytes, bytearray, memoryview)): + error_content = bytes(raw_body) + else: + error_content = str(raw_body or "").encode("utf-8") + gemini_error = json.loads(error_content.decode("utf-8")) + from src.converter.anthropic2gemini import ( + gemini_to_anthropic_response, + ) + anthropic_error = gemini_to_anthropic_response( - gemini_error, - real_model, - chunk.status_code + gemini_error, real_model, chunk.status_code ) - yield f"data: {json.dumps(anthropic_error)}\n\n".encode('utf-8') + yield f"data: {json.dumps(anthropic_error)}\n\n".encode("utf-8") except Exception: - yield f"data: {json.dumps({'type': 'error', 'error': {'type': 'api_error', 'message': 'Stream error'}})}\n\n".encode('utf-8') + yield f"data: {json.dumps({'type': 'error', 'error': {'type': 'api_error', 'message': 'Stream error'}})}\n\n".encode( + "utf-8" + ) yield b"data: [DONE]\n\n" return else: # 确保是bytes类型 if isinstance(chunk, str): - yield chunk.encode('utf-8') + yield chunk.encode("utf-8") else: yield chunk # 使用转换器处理整个流 async for anthropic_chunk in gemini_stream_to_anthropic_stream( - gemini_chunk_wrapper(), - real_model, - 200 + gemini_chunk_wrapper(), real_model, 200 ): if anthropic_chunk: yield anthropic_chunk # ========== 根据模式选择生成器 ========== if use_fake_streaming: - return StreamingResponse(fake_stream_generator(), media_type="text/event-stream") + response = StreamingResponse( + fake_stream_generator(), media_type="text/event-stream" + ) elif use_anti_truncation: log.info("启用流式抗截断功能") - return StreamingResponse(anti_truncation_generator(), media_type="text/event-stream") + response = StreamingResponse( + anti_truncation_generator(), media_type="text/event-stream" + ) else: - return StreamingResponse(normal_stream_generator(), media_type="text/event-stream") + response = StreamingResponse( + normal_stream_generator(), media_type="text/event-stream" + ) + + response.headers["x-gcli-obs-converter-cpu-ms"] = converter_cpu_header_value + return response @router.post("/v1/messages/count_tokens") -async def count_tokens( - request: Request, - _token: str = Depends(authenticate_bearer) -): +async def count_tokens(request: Request, _token: str = Depends(authenticate_bearer)): """ 处理Anthropic格式的token计数请求 - + Args: request: FastAPI请求对象 _token: Bearer认证令牌(由Depends验证) - + Returns: JSONResponse: 包含input_tokens的响应 """ @@ -387,19 +440,37 @@ async def count_tokens( except Exception as e: return JSONResponse( status_code=400, - content={"type": "error", "error": {"type": "invalid_request_error", "message": f"JSON 解析失败: {str(e)}"}} + content={ + "type": "error", + "error": { + "type": "invalid_request_error", + "message": f"JSON 解析失败: {str(e)}", + }, + }, ) if not isinstance(payload, dict): return JSONResponse( status_code=400, - content={"type": "error", "error": {"type": "invalid_request_error", "message": "请求体必须为 JSON object"}} + content={ + "type": "error", + "error": { + "type": "invalid_request_error", + "message": "请求体必须为 JSON object", + }, + }, ) if not payload.get("model") or not isinstance(payload.get("messages"), list): return JSONResponse( status_code=400, - content={"type": "error", "error": {"type": "invalid_request_error", "message": "缺少必填字段:model / messages"}} + content={ + "type": "error", + "error": { + "type": "invalid_request_error", + "message": "缺少必填字段:model / messages", + }, + }, ) try: @@ -466,7 +537,7 @@ async def count_tokens( "max_tokens": 1024, "messages": [ {"role": "user", "content": "Hello, tell me a joke in one sentence."} - ] + ], } # 测试Bearer令牌(模拟) @@ -477,12 +548,14 @@ def test_non_stream_request(): print("\n" + "=" * 80) print("【测试1】非流式请求 (POST /v1/messages)") print("=" * 80) - print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n") + print( + f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n" + ) response = client.post( "/v1/messages", json=test_request_body, - headers={"Authorization": test_token} + headers={"Authorization": test_token}, ) print("非流式响应数据:") @@ -497,7 +570,7 @@ def test_non_stream_request(): # 尝试解析JSON try: json_data = response.json() - print(f"响应内容 (格式化JSON):") + print("响应内容 (格式化JSON):") print(json.dumps(json_data, indent=2, ensure_ascii=False)) except json.JSONDecodeError: print("(非JSON格式)") @@ -513,7 +586,9 @@ def test_stream_request(): stream_request_body = test_request_body.copy() stream_request_body["stream"] = True - print(f"请求体: {json.dumps(stream_request_body, indent=2, ensure_ascii=False)}\n") + print( + f"请求体: {json.dumps(stream_request_body, indent=2, ensure_ascii=False)}\n" + ) print("流式响应数据 (每个chunk):") print("-" * 80) @@ -522,7 +597,7 @@ def test_stream_request(): "POST", "/v1/messages", json=stream_request_body, - headers={"Authorization": test_token} + headers={"Authorization": test_token}, ) as response: print(f"状态码: {response.status_code}") print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") @@ -537,24 +612,30 @@ def test_stream_request(): # 解码chunk try: - chunk_str = chunk.decode('utf-8') - print(f" 内容预览: {repr(chunk_str[:200] if len(chunk_str) > 200 else chunk_str)}") + chunk_str = chunk.decode("utf-8") + print( + f" 内容预览: {repr(chunk_str[:200] if len(chunk_str) > 200 else chunk_str)}" + ) # 如果是SSE格式,尝试解析每一行 - if chunk_str.startswith("event: ") or chunk_str.startswith("data: "): + if chunk_str.startswith("event: ") or chunk_str.startswith( + "data: " + ): # 按行分割,处理每个SSE事件 - for line in chunk_str.strip().split('\n'): + for line in chunk_str.strip().split("\n"): line = line.strip() if not line: continue if line == "data: [DONE]": - print(f" => 流结束标记") + print(" => 流结束标记") elif line.startswith("data: "): try: json_str = line[6:] # 去掉 "data: " 前缀 json_data = json.loads(json_str) - print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}") + print( + f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}" + ) except Exception as e: print(f" SSE解析失败: {e}") except Exception as e: @@ -572,7 +653,9 @@ def test_fake_stream_request(): fake_stream_request_body["model"] = "假流式/gemini-2.5-flash" fake_stream_request_body["stream"] = True - print(f"请求体: {json.dumps(fake_stream_request_body, indent=2, ensure_ascii=False)}\n") + print( + f"请求体: {json.dumps(fake_stream_request_body, indent=2, ensure_ascii=False)}\n" + ) print("假流式响应数据 (每个chunk):") print("-" * 80) @@ -581,7 +664,7 @@ def test_fake_stream_request(): "POST", "/v1/messages", json=fake_stream_request_body, - headers={"Authorization": test_token} + headers={"Authorization": test_token}, ) as response: print(f"状态码: {response.status_code}") print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n") @@ -590,14 +673,14 @@ def test_fake_stream_request(): for chunk in response.iter_bytes(): if chunk: chunk_count += 1 - chunk_str = chunk.decode('utf-8') + chunk_str = chunk.decode("utf-8") print(f"\nChunk #{chunk_count}:") print(f" 长度: {len(chunk_str)} 字节") # 解析chunk中的所有SSE事件 events = [] - for line in chunk_str.split('\n'): + for line in chunk_str.split("\n"): line = line.strip() if line.startswith("data: ") or line.startswith("event: "): events.append(line) @@ -637,4 +720,5 @@ def test_fake_stream_request(): except Exception as e: print(f"\n❌ 测试过程中出现异常: {e}") import traceback + traceback.print_exc() diff --git a/src/storage/sqlite_manager.py b/src/storage/sqlite_manager.py index 59b4cfd64..f8be5a9f8 100644 --- a/src/storage/sqlite_manager.py +++ b/src/storage/sqlite_manager.py @@ -5,6 +5,7 @@ import asyncio import json import os +import random import time from typing import Any, Dict, List, Optional, Tuple @@ -40,7 +41,7 @@ class SQLiteManager: ("rotation_order", "INTEGER DEFAULT 0"), ("call_count", "INTEGER DEFAULT 0"), ("created_at", "REAL DEFAULT (unixepoch())"), - ("updated_at", "REAL DEFAULT (unixepoch())") + ("updated_at", "REAL DEFAULT (unixepoch())"), ], "antigravity_credentials": [ ("disabled", "INTEGER DEFAULT 0"), @@ -52,8 +53,8 @@ class SQLiteManager: ("rotation_order", "INTEGER DEFAULT 0"), ("call_count", "INTEGER DEFAULT 0"), ("created_at", "REAL DEFAULT (unixepoch())"), - ("updated_at", "REAL DEFAULT (unixepoch())") - ] + ("updated_at", "REAL DEFAULT (unixepoch())"), + ], } def __init__(self): @@ -120,7 +121,7 @@ async def _ensure_schema_compatibility(self, db: aiosqlite.Connection) -> None: # 检查表是否存在 async with db.execute( "SELECT name FROM sqlite_master WHERE type='table' AND name=?", - (table_name,) + (table_name,), ) as cursor: if not await cursor.fetchone(): log.debug(f"Table {table_name} does not exist, will be created") @@ -135,14 +136,20 @@ async def _ensure_schema_compatibility(self, db: aiosqlite.Connection) -> None: for col_name, col_def in columns: if col_name not in existing_columns: try: - await db.execute(f"ALTER TABLE {table_name} ADD COLUMN {col_name} {col_def}") + await db.execute( + f"ALTER TABLE {table_name} ADD COLUMN {col_name} {col_def}" + ) log.info(f"Added missing column {table_name}.{col_name}") added_count += 1 except Exception as e: - log.error(f"Failed to add column {table_name}.{col_name}: {e}") + log.error( + f"Failed to add column {table_name}.{col_name}: {e}" + ) if added_count > 0: - log.info(f"Table {table_name}: added {added_count} missing column(s)") + log.info( + f"Table {table_name}: added {added_count} missing column(s)" + ) except Exception as e: log.error(f"Error ensuring schema compatibility: {e}") @@ -254,7 +261,7 @@ async def _repair_credential_filenames(self, db: aiosqlite.Connection): # 检查是否会产生冲突 async with db.execute( "SELECT COUNT(*) FROM credentials WHERE filename = ?", - (basename,) + (basename,), ) as check_cursor: count = (await check_cursor.fetchone())[0] @@ -262,21 +269,27 @@ async def _repair_credential_filenames(self, db: aiosqlite.Connection): # 无冲突,直接更新 await db.execute( "UPDATE credentials SET filename = ? WHERE filename = ?", - (basename, filename) + (basename, filename), ) repaired_count += 1 - log.info(f"Repaired credential filename: {filename} -> {basename}") + log.info( + f"Repaired credential filename: {filename} -> {basename}" + ) else: # 有冲突,删除带路径的旧记录(保留 basename 的记录) await db.execute( "DELETE FROM credentials WHERE filename = ?", - (filename,) + (filename,), ) repaired_count += 1 - log.warning(f"Removed duplicate credential with path: {filename} (kept {basename})") + log.warning( + f"Removed duplicate credential with path: {filename} (kept {basename})" + ) # 修复 antigravity_credentials 表 - async with db.execute("SELECT filename FROM antigravity_credentials") as cursor: + async with db.execute( + "SELECT filename FROM antigravity_credentials" + ) as cursor: rows = await cursor.fetchall() for (filename,) in rows: basename = os.path.basename(filename) @@ -284,7 +297,7 @@ async def _repair_credential_filenames(self, db: aiosqlite.Connection): # 检查是否会产生冲突 async with db.execute( "SELECT COUNT(*) FROM antigravity_credentials WHERE filename = ?", - (basename,) + (basename,), ) as check_cursor: count = (await check_cursor.fetchone())[0] @@ -292,18 +305,22 @@ async def _repair_credential_filenames(self, db: aiosqlite.Connection): # 无冲突,直接更新 await db.execute( "UPDATE antigravity_credentials SET filename = ? WHERE filename = ?", - (basename, filename) + (basename, filename), ) repaired_count += 1 - log.info(f"Repaired antigravity credential filename: {filename} -> {basename}") + log.info( + f"Repaired antigravity credential filename: {filename} -> {basename}" + ) else: # 有冲突,删除带路径的旧记录(保留 basename 的记录) await db.execute( "DELETE FROM antigravity_credentials WHERE filename = ?", - (filename,) + (filename,), ) repaired_count += 1 - log.warning(f"Removed duplicate antigravity credential with path: {filename} (kept {basename})") + log.warning( + f"Removed duplicate antigravity credential with path: {filename} (kept {basename})" + ) if repaired_count > 0: log.info(f"Repaired {repaired_count} credential filename(s)") @@ -354,18 +371,23 @@ def _get_table_name(self, mode: str) -> str: elif mode == "geminicli": return "credentials" else: - raise ValueError(f"Invalid mode: {mode}. Must be 'geminicli' or 'antigravity'") + raise ValueError( + f"Invalid mode: {mode}. Must be 'geminicli' or 'antigravity'" + ) # ============ SQL 方法 ============ async def get_next_available_credential( - self, mode: str = "geminicli", model_name: Optional[str] = None + self, + mode: str = "geminicli", + model_name: Optional[str] = None, + scheduling_hints: Optional[Dict[str, Any]] = None, ) -> Optional[Tuple[str, Dict[str, Any]]]: """ - 随机获取一个可用凭证(负载均衡) + 获取一个可用凭证(健康感知负载均衡) - 未禁用 - 如果提供了 model_name,还会检查模型级冷却和preview状态 - - 随机选择 + - 基于健康评分(in-flight / recent errors / call_count)选择 Args: mode: 凭证模式 ("geminicli" 或 "antigravity") @@ -383,105 +405,298 @@ async def get_next_available_credential( table_name = self._get_table_name(mode) async with aiosqlite.connect(self._db_path) as db: current_time = time.time() + in_flight_map: Dict[str, int] = {} + if isinstance(scheduling_hints, dict): + raw_in_flight = scheduling_hints.get("in_flight") + if isinstance(raw_in_flight, dict): + in_flight_map = { + str(k): int(v) + for k, v in raw_in_flight.items() + if isinstance(v, (int, float)) + } - # 确定模型名用于冷却检查 - if model_name: - # 所有模式都使用完整模型名 - pass + def has_error_code(error_codes_raw: str, code: int) -> bool: + try: + parsed = json.loads(error_codes_raw or "[]") + for item in parsed if isinstance(parsed, list) else []: + if item == code: + return True + if isinstance(item, str): + try: + if int(item) == code: + return True + except ValueError: + continue + except Exception: + return False + return False + + def get_error_count(error_codes_raw: str) -> int: + try: + parsed = json.loads(error_codes_raw or "[]") + return len(parsed) if isinstance(parsed, list) else 0 + except Exception: + return 0 + + def score_candidate(candidate: Dict[str, Any]) -> float: + in_flight = in_flight_map.get(candidate["filename"], 0) + recent_429 = 1 if candidate.get("has_429") else 0 + error_count = int(candidate.get("error_count", 0) or 0) + call_count = int(candidate.get("call_count", 0) or 0) + penalty = ( + (in_flight * 3.0) + + (recent_429 * 4.0) + + (min(error_count, 3) * 1.2) + + (min(call_count, 1000) / 1000.0) + ) + return penalty + random.uniform(0.0, 0.5) + + async def pick_from_candidates( + candidates: List[Dict[str, Any]], + pool_stats: Dict[str, Any], + ) -> Optional[Tuple[str, Dict[str, Any]]]: + if not candidates: + return None + + scored_candidates = [] + for candidate in candidates: + scored_candidates.append( + (score_candidate(candidate), candidate) + ) + scored_candidates.sort(key=lambda item: item[0]) + + selected_score, selected = scored_candidates[0] + selected_candidate: Dict[str, Any] = selected + await db.execute( + f""" + UPDATE {table_name} + SET call_count = call_count + 1, + updated_at = unixepoch() + WHERE filename = ? + """, + (selected_candidate["filename"],), + ) + await db.commit() + + if ( + mode == "geminicli" + and model_name + and "preview" in model_name.lower() + ): + log.info( + "[PREVIEW_POOL] " + f"model={model_name} " + f"available={int(pool_stats.get('available', 0))} " + f"cooling={int(pool_stats.get('cooling', 0))} " + f"ratio_429={float(pool_stats.get('ratio_429', 0.0)):.3f} " + f"selected={selected_candidate['filename']} " + f"score={float(selected_score):.3f} " + f"in_flight={int(in_flight_map.get(selected_candidate['filename'], 0))}" + ) + + return selected_candidate["filename"], json.loads( + selected_candidate["credential_json"] + ) # 根据模式构建查询 if mode == "geminicli": # geminicli 模式,需要处理 preview 状态 async with db.execute(f""" - SELECT filename, credential_data, model_cooldowns, preview + SELECT filename, credential_data, model_cooldowns, preview, error_codes, call_count FROM {table_name} WHERE disabled = 0 - ORDER BY RANDOM() + ORDER BY rotation_order ASC """) as cursor: rows = await cursor.fetchall() - if not model_name: - # 没有提供模型名,返回第一个可用凭证 - if rows: - filename, credential_json, _, _ = rows[0] - credential_data = json.loads(credential_json) - return filename, credential_data + if not rows: return None - # 检查模型是否为 preview 模型 - is_preview_model = "preview" in model_name.lower() - - # 分别收集 preview=False 和 preview=True 的可用凭证 - non_preview_creds = [] - preview_creds = [] + is_preview_model = bool( + model_name and "preview" in model_name.lower() + ) + + all_candidates: List[Dict[str, Any]] = [] + non_preview_candidates: List[Dict[str, Any]] = [] + preview_candidates: List[Dict[str, Any]] = [] + preview_cooling_count = 0 + non_preview_cooling_count = 0 + + for ( + filename, + credential_json, + model_cooldowns_json, + preview, + error_codes_json, + call_count, + ) in rows: + model_cooldowns = json.loads(model_cooldowns_json or "{}") + model_cooldown = ( + model_cooldowns.get(model_name) if model_name else None + ) + in_cooldown = bool( + model_name + and model_cooldown is not None + and current_time < model_cooldown + ) - for filename, credential_json, model_cooldowns_json, preview in rows: - model_cooldowns = json.loads(model_cooldowns_json or '{}') + candidate = { + "filename": filename, + "credential_json": credential_json, + "preview": bool(preview), + "has_429": has_error_code( + error_codes_json or "[]", 429 + ), + "error_count": get_error_count( + error_codes_json or "[]" + ), + "call_count": int(call_count or 0), + } - # 检查该模型是否在冷却中 - model_cooldown = model_cooldowns.get(model_name) - if model_cooldown is None or current_time >= model_cooldown: - # 该模型未冷却或冷却已过期 - if preview: - preview_creds.append((filename, credential_json)) + if in_cooldown: + if candidate["preview"]: + preview_cooling_count += 1 else: - non_preview_creds.append((filename, credential_json)) + non_preview_cooling_count += 1 + continue + + all_candidates.append(candidate) + if candidate["preview"]: + preview_candidates.append(candidate) + else: + non_preview_candidates.append(candidate) + + if not model_name: + ratio_429 = ( + sum(1 for item in all_candidates if item.get("has_429")) + / len(all_candidates) + if all_candidates + else 0.0 + ) + return await pick_from_candidates( + all_candidates, + { + "available": len(all_candidates), + "cooling": preview_cooling_count + + non_preview_cooling_count, + "ratio_429": ratio_429, + }, + ) - # 根据模型类型选择凭证 if is_preview_model: - # preview 模型只能使用 preview=True 的凭证 - if preview_creds: - filename, credential_json = preview_creds[0] - credential_data = json.loads(credential_json) - return filename, credential_data - else: - # 非 preview 模型 - # 除非没有 preview=False 的凭证,否则只使用 preview=False 的凭证 - if non_preview_creds: - # 存在 preview=False 的凭证,只使用它们 - filename, credential_json = non_preview_creds[0] - credential_data = json.loads(credential_json) - return filename, credential_data - elif preview_creds: - # 不存在 preview=False 的凭证,使用 preview=True 作为后备 - filename, credential_json = preview_creds[0] - credential_data = json.loads(credential_json) - return filename, credential_data + ratio_429 = ( + sum( + 1 + for item in preview_candidates + if item.get("has_429") + ) + / len(preview_candidates) + if preview_candidates + else 0.0 + ) + return await pick_from_candidates( + preview_candidates, + { + "available": len(preview_candidates), + "cooling": preview_cooling_count, + "ratio_429": ratio_429, + }, + ) - return None + # 非 preview 模型:优先 non-preview,不足时 fallback preview + primary_candidates = ( + non_preview_candidates or preview_candidates + ) + primary_cooling_count = ( + non_preview_cooling_count + if non_preview_candidates + else preview_cooling_count + ) + ratio_429 = ( + sum(1 for item in primary_candidates if item.get("has_429")) + / len(primary_candidates) + if primary_candidates + else 0.0 + ) + return await pick_from_candidates( + primary_candidates, + { + "available": len(primary_candidates), + "cooling": primary_cooling_count, + "ratio_429": ratio_429, + }, + ) else: # antigravity 模式,不需要处理 preview async with db.execute(f""" - SELECT filename, credential_data, model_cooldowns + SELECT filename, credential_data, model_cooldowns, error_codes, call_count FROM {table_name} WHERE disabled = 0 - ORDER BY RANDOM() + ORDER BY rotation_order ASC """) as cursor: rows = await cursor.fetchall() - # 如果没有提供 model_name,使用第一个可用凭证 - if not model_name: - if rows: - filename, credential_json, _ = rows[0] - credential_data = json.loads(credential_json) - return filename, credential_data + if not rows: return None - # 如果提供了 model_name,检查模型级冷却 - for filename, credential_json, model_cooldowns_json in rows: - model_cooldowns = json.loads(model_cooldowns_json or '{}') + candidates: List[Dict[str, Any]] = [] + cooling_count = 0 + + for ( + filename, + credential_json, + model_cooldowns_json, + error_codes_json, + call_count, + ) in rows: + in_cooldown = False + if model_name: + model_cooldowns = json.loads( + model_cooldowns_json or "{}" + ) + model_cooldown = model_cooldowns.get(model_name) + in_cooldown = bool( + model_cooldown is not None + and current_time < model_cooldown + ) + + if in_cooldown: + cooling_count += 1 + continue - # 检查该模型是否在冷却中 - model_cooldown = model_cooldowns.get(model_name) - if model_cooldown is None or current_time >= model_cooldown: - # 该模型未冷却或冷却已过期 - credential_data = json.loads(credential_json) - return filename, credential_data + candidates.append( + { + "filename": filename, + "credential_json": credential_json, + "has_429": has_error_code( + error_codes_json or "[]", 429 + ), + "error_count": get_error_count( + error_codes_json or "[]" + ), + "call_count": int(call_count or 0), + } + ) - return None + ratio_429 = ( + sum(1 for item in candidates if item.get("has_429")) + / len(candidates) + if candidates + else 0.0 + ) + return await pick_from_candidates( + candidates, + { + "available": len(candidates), + "cooling": cooling_count, + "ratio_429": ratio_429, + }, + ) except Exception as e: - log.error(f"Error getting next available credential (mode={mode}, model_name={model_name}): {e}") + log.error( + f"Error getting next available credential (mode={mode}, model_name={model_name}): {e}" + ) return None async def get_available_credentials_list(self) -> List[str]: @@ -509,7 +724,9 @@ async def get_available_credentials_list(self) -> List[str]: # ============ StorageBackend 协议方法 ============ - async def store_credential(self, filename: str, credential_data: Dict[str, Any], mode: str = "geminicli") -> bool: + async def store_credential( + self, filename: str, credential_data: Dict[str, Any], mode: str = "geminicli" + ) -> bool: """存储或更新凭证""" self._ensure_initialized() @@ -520,21 +737,27 @@ async def store_credential(self, filename: str, credential_data: Dict[str, Any], table_name = self._get_table_name(mode) async with aiosqlite.connect(self._db_path) as db: # 检查凭证是否存在 - async with db.execute(f""" + async with db.execute( + f""" SELECT disabled, error_codes, last_success, user_email, rotation_order, call_count FROM {table_name} WHERE filename = ? - """, (filename,)) as cursor: + """, + (filename,), + ) as cursor: existing = await cursor.fetchone() if existing: # 更新现有凭证(保留状态) - await db.execute(f""" + await db.execute( + f""" UPDATE {table_name} SET credential_data = ?, updated_at = unixepoch() WHERE filename = ? - """, (json.dumps(credential_data), filename)) + """, + (json.dumps(credential_data), filename), + ) else: # 插入新凭证 async with db.execute(f""" @@ -543,11 +766,19 @@ async def store_credential(self, filename: str, credential_data: Dict[str, Any], row = await cursor.fetchone() next_order = row[0] - await db.execute(f""" + await db.execute( + f""" INSERT INTO {table_name} (filename, credential_data, rotation_order, last_success) VALUES (?, ?, ?, ?) - """, (filename, json.dumps(credential_data), next_order, time.time())) + """, + ( + filename, + json.dumps(credential_data), + next_order, + time.time(), + ), + ) await db.commit() log.debug(f"Stored credential: {filename} (mode={mode})") @@ -557,7 +788,9 @@ async def store_credential(self, filename: str, credential_data: Dict[str, Any], log.error(f"Error storing credential {filename}: {e}") return False - async def get_credential(self, filename: str, mode: str = "geminicli") -> Optional[Dict[str, Any]]: + async def get_credential( + self, filename: str, mode: str = "geminicli" + ) -> Optional[Dict[str, Any]]: """获取凭证数据""" self._ensure_initialized() @@ -568,9 +801,12 @@ async def get_credential(self, filename: str, mode: str = "geminicli") -> Option table_name = self._get_table_name(mode) async with aiosqlite.connect(self._db_path) as db: # 精确匹配 - async with db.execute(f""" + async with db.execute( + f""" SELECT credential_data FROM {table_name} WHERE filename = ? - """, (filename,)) as cursor: + """, + (filename,), + ) as cursor: row = await cursor.fetchone() if row: return json.loads(row[0]) @@ -609,25 +845,34 @@ async def delete_credential(self, filename: str, mode: str = "geminicli") -> boo table_name = self._get_table_name(mode) async with aiosqlite.connect(self._db_path) as db: # 精确匹配删除 - result = await db.execute(f""" + result = await db.execute( + f""" DELETE FROM {table_name} WHERE filename = ? - """, (filename,)) + """, + (filename,), + ) deleted_count = result.rowcount await db.commit() if deleted_count > 0: - log.debug(f"Deleted {deleted_count} credential(s): {filename} (mode={mode})") + log.debug( + f"Deleted {deleted_count} credential(s): {filename} (mode={mode})" + ) return True else: - log.warning(f"No credential found to delete: {filename} (mode={mode})") + log.warning( + f"No credential found to delete: {filename} (mode={mode})" + ) return False except Exception as e: log.error(f"Error deleting credential {filename}: {e}") return False - async def update_credential_state(self, filename: str, state_updates: Dict[str, Any], mode: str = "geminicli") -> bool: + async def update_credential_state( + self, filename: str, state_updates: Dict[str, Any], mode: str = "geminicli" + ) -> bool: """更新凭证状态""" self._ensure_initialized() @@ -636,7 +881,9 @@ async def update_credential_state(self, filename: str, state_updates: Dict[str, try: table_name = self._get_table_name(mode) - log.debug(f"[DB] update_credential_state 开始: filename={filename}, state_updates={state_updates}, mode={mode}, table={table_name}") + log.debug( + f"[DB] update_credential_state 开始: filename={filename}, state_updates={state_updates}, mode={mode}, table={table_name}" + ) # 构建动态 SQL set_clauses = [] @@ -665,7 +912,7 @@ async def update_credential_state(self, filename: str, state_updates: Dict[str, # 精确匹配更新 sql_exact = f""" UPDATE {table_name} - SET {', '.join(set_clauses)} + SET {", ".join(set_clauses)} WHERE filename = ? """ log.debug(f"[DB] 执行精确匹配SQL: {sql_exact}") @@ -681,14 +928,18 @@ async def update_credential_state(self, filename: str, state_updates: Dict[str, log.debug(f"[DB] commit完成") success = updated_count > 0 - log.debug(f"[DB] update_credential_state 结束: success={success}, updated_count={updated_count}") + log.debug( + f"[DB] update_credential_state 结束: success={success}, updated_count={updated_count}" + ) return success except Exception as e: log.error(f"[DB] Error updating credential state {filename}: {e}") return False - async def get_credential_state(self, filename: str, mode: str = "geminicli") -> Dict[str, Any]: + async def get_credential_state( + self, filename: str, mode: str = "geminicli" + ) -> Dict[str, Any]: """获取凭证状态(不包含error_messages)""" self._ensure_initialized() @@ -700,15 +951,18 @@ async def get_credential_state(self, filename: str, mode: str = "geminicli") -> async with aiosqlite.connect(self._db_path) as db: # 精确匹配 if mode == "geminicli": - async with db.execute(f""" + async with db.execute( + f""" SELECT disabled, error_codes, last_success, user_email, model_cooldowns, preview FROM {table_name} WHERE filename = ? - """, (filename,)) as cursor: + """, + (filename,), + ) as cursor: row = await cursor.fetchone() if row: - error_codes_json = row[1] or '[]' - model_cooldowns_json = row[4] or '{}' + error_codes_json = row[1] or "[]" + model_cooldowns_json = row[4] or "{}" return { "disabled": bool(row[0]), "error_codes": json.loads(error_codes_json), @@ -729,15 +983,18 @@ async def get_credential_state(self, filename: str, mode: str = "geminicli") -> } else: # antigravity 模式 - async with db.execute(f""" + async with db.execute( + f""" SELECT disabled, error_codes, last_success, user_email, model_cooldowns FROM {table_name} WHERE filename = ? - """, (filename,)) as cursor: + """, + (filename,), + ) as cursor: row = await cursor.fetchone() if row: - error_codes_json = row[1] or '[]' - model_cooldowns_json = row[4] or '{}' + error_codes_json = row[1] or "[]" + model_cooldowns_json = row[4] or "{}" return { "disabled": bool(row[0]), "error_codes": json.loads(error_codes_json), @@ -759,7 +1016,9 @@ async def get_credential_state(self, filename: str, mode: str = "geminicli") -> log.error(f"Error getting credential state {filename}: {e}") return {} - async def get_all_credential_states(self, mode: str = "geminicli") -> Dict[str, Dict[str, Any]]: + async def get_all_credential_states( + self, mode: str = "geminicli" + ) -> Dict[str, Dict[str, Any]]: """获取所有凭证状态(不包含error_messages)""" self._ensure_initialized() @@ -779,14 +1038,15 @@ async def get_all_credential_states(self, mode: str = "geminicli") -> Dict[str, for row in rows: filename = row[0] - error_codes_json = row[2] or '[]' - model_cooldowns_json = row[5] or '{}' + error_codes_json = row[2] or "[]" + model_cooldowns_json = row[5] or "{}" model_cooldowns = json.loads(model_cooldowns_json) # 自动过滤掉已过期的模型CD if model_cooldowns: model_cooldowns = { - k: v for k, v in model_cooldowns.items() + k: v + for k, v in model_cooldowns.items() if v > current_time } @@ -814,14 +1074,15 @@ async def get_all_credential_states(self, mode: str = "geminicli") -> Dict[str, for row in rows: filename = row[0] - error_codes_json = row[2] or '[]' - model_cooldowns_json = row[5] or '{}' + error_codes_json = row[2] or "[]" + model_cooldowns_json = row[5] or "{}" model_cooldowns = json.loads(model_cooldowns_json) # 自动过滤掉已过期的模型CD if model_cooldowns: model_cooldowns = { - k: v for k, v in model_cooldowns.items() + k: v + for k, v in model_cooldowns.items() if v > current_time } @@ -847,7 +1108,7 @@ async def get_credentials_summary( mode: str = "geminicli", error_code_filter: Optional[str] = None, cooldown_filter: Optional[str] = None, - preview_filter: Optional[str] = None + preview_filter: Optional[str] = None, ) -> Dict[str, Any]: """ 获取凭证的摘要信息(不包含完整凭证数据)- 支持分页和状态筛选 @@ -895,7 +1156,10 @@ async def get_credentials_summary( filter_value = None filter_int = None - if error_code_filter and str(error_code_filter).strip().lower() != "all": + if ( + error_code_filter + and str(error_code_filter).strip().lower() != "all" + ): filter_value = str(error_code_filter).strip() try: filter_int = int(filter_value) @@ -933,15 +1197,16 @@ async def get_credentials_summary( for row in all_rows: filename = row[0] - error_codes_json = row[2] or '[]' - model_cooldowns_json = row[6] or '{}' + error_codes_json = row[2] or "[]" + model_cooldowns_json = row[6] or "{}" model_cooldowns = json.loads(model_cooldowns_json) # 自动过滤掉已过期的模型CD active_cooldowns = {} if model_cooldowns: active_cooldowns = { - k: v for k, v in model_cooldowns.items() + k: v + for k, v in model_cooldowns.items() if v > current_time } @@ -974,7 +1239,9 @@ async def get_credentials_summary( # preview状态只对geminicli模式有效 if mode == "geminicli": - summary["preview"] = bool(row[7]) if row[7] is not None else True + summary["preview"] = ( + bool(row[7]) if row[7] is not None else True + ) # 应用 preview 筛选(仅对 geminicli 模式) if mode == "geminicli" and preview_filter: @@ -1000,7 +1267,7 @@ async def get_credentials_summary( # 应用分页 total_count = len(all_summaries) if limit is not None: - summaries = all_summaries[offset:offset + limit] + summaries = all_summaries[offset : offset + limit] else: summaries = all_summaries[offset:] @@ -1022,7 +1289,9 @@ async def get_credentials_summary( "stats": {"total": 0, "normal": 0, "disabled": 0}, } - async def get_duplicate_credentials_by_email(self, mode: str = "geminicli") -> Dict[str, Any]: + async def get_duplicate_credentials_by_email( + self, mode: str = "geminicli" + ) -> Dict[str, Any]: """ 获取按邮箱分组的重复凭证信息(只查询邮箱和文件名,不加载完整凭证数据) 用于去重操作 @@ -1069,12 +1338,14 @@ async def get_duplicate_credentials_by_email(self, mode: str = "geminicli") -> D for email, files in email_to_files.items(): if len(files) > 1: # 保留第一个文件,其他为重复 - duplicate_groups.append({ - "email": email, - "kept_file": files[0], - "duplicate_files": files[1:], - "duplicate_count": len(files) - 1, - }) + duplicate_groups.append( + { + "email": email, + "kept_file": files[0], + "duplicate_files": files[1:], + "duplicate_count": len(files) - 1, + } + ) total_duplicate_count += len(files) - 1 return { @@ -1107,13 +1378,16 @@ async def set_config(self, key: str, value: Any) -> bool: try: async with aiosqlite.connect(self._db_path) as db: - await db.execute(""" + await db.execute( + """ INSERT INTO config (key, value, updated_at) VALUES (?, ?, unixepoch()) ON CONFLICT(key) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at - """, (key, json.dumps(value))) + """, + (key, json.dumps(value)), + ) await db.commit() # 更新内存缓存 @@ -1158,7 +1432,9 @@ async def delete_config(self, key: str) -> bool: log.error(f"Error deleting config {key}: {e}") return False - async def get_credential_errors(self, filename: str, mode: str = "geminicli") -> Dict[str, Any]: + async def get_credential_errors( + self, filename: str, mode: str = "geminicli" + ) -> Dict[str, Any]: """ 专门获取凭证的错误信息(包含 error_codes 和 error_messages) @@ -1178,14 +1454,17 @@ async def get_credential_errors(self, filename: str, mode: str = "geminicli") -> table_name = self._get_table_name(mode) async with aiosqlite.connect(self._db_path) as db: # 精确匹配 - async with db.execute(f""" + async with db.execute( + f""" SELECT error_codes, error_messages FROM {table_name} WHERE filename = ? - """, (filename,)) as cursor: + """, + (filename,), + ) as cursor: row = await cursor.fetchone() if row: - error_codes_json = row[0] or '[]' - error_messages_json = row[1] or '[]' + error_codes_json = row[0] or "[]" + error_messages_json = row[1] or "[]" return { "filename": filename, "error_codes": json.loads(error_codes_json), @@ -1205,7 +1484,7 @@ async def get_credential_errors(self, filename: str, mode: str = "geminicli") -> "filename": filename, "error_codes": [], "error_messages": [], - "error": str(e) + "error": str(e), } # ============ 模型级冷却管理 ============ @@ -1215,7 +1494,7 @@ async def set_model_cooldown( filename: str, model_name: str, cooldown_until: Optional[float], - mode: str = "geminicli" + mode: str = "geminicli", ) -> bool: """ 设置特定模型的冷却时间 @@ -1238,16 +1517,19 @@ async def set_model_cooldown( table_name = self._get_table_name(mode) async with aiosqlite.connect(self._db_path) as db: # 获取当前的 model_cooldowns - async with db.execute(f""" + async with db.execute( + f""" SELECT model_cooldowns FROM {table_name} WHERE filename = ? - """, (filename,)) as cursor: + """, + (filename,), + ) as cursor: row = await cursor.fetchone() if not row: log.warning(f"Credential {filename} not found") return False - model_cooldowns = json.loads(row[0] or '{}') + model_cooldowns = json.loads(row[0] or "{}") # 更新或删除指定模型的冷却时间 if cooldown_until is None: @@ -1256,15 +1538,20 @@ async def set_model_cooldown( model_cooldowns[model_name] = cooldown_until # 写回数据库 - await db.execute(f""" + await db.execute( + f""" UPDATE {table_name} SET model_cooldowns = ?, updated_at = unixepoch() WHERE filename = ? - """, (json.dumps(model_cooldowns), filename)) + """, + (json.dumps(model_cooldowns), filename), + ) await db.commit() - log.debug(f"Set model cooldown: {filename}, model_name={model_name}, cooldown_until={cooldown_until}") + log.debug( + f"Set model cooldown: {filename}, model_name={model_name}, cooldown_until={cooldown_until}" + ) return True except Exception as e: @@ -1272,10 +1559,7 @@ async def set_model_cooldown( return False async def record_success( - self, - filename: str, - model_name: Optional[str] = None, - mode: str = "geminicli" + self, filename: str, model_name: Optional[str] = None, mode: str = "geminicli" ) -> None: """ 成功调用后的条件写入: @@ -1290,7 +1574,8 @@ async def record_success( table_name = self._get_table_name(mode) async with aiosqlite.connect(self._db_path) as db: # 条件写入:只有 error_codes 非空时才触发 - await db.execute(f""" + await db.execute( + f""" UPDATE {table_name} SET last_success = unixepoch(), error_codes = '[]', @@ -1298,25 +1583,33 @@ async def record_success( updated_at = unixepoch() WHERE filename = ? AND (error_codes IS NOT NULL AND error_codes != '[]' AND error_codes != '') - """, (filename,)) + """, + (filename,), + ) # 条件删除模型冷却:只有模型键存在时才写入 if model_name: - async with db.execute(f""" + async with db.execute( + f""" SELECT model_cooldowns FROM {table_name} WHERE filename = ? - """, (filename,)) as cursor: + """, + (filename,), + ) as cursor: row = await cursor.fetchone() if row: - cooldowns = json.loads(row[0] or '{}') + cooldowns = json.loads(row[0] or "{}") if model_name in cooldowns: cooldowns.pop(model_name) - await db.execute(f""" + await db.execute( + f""" UPDATE {table_name} SET model_cooldowns = ?, updated_at = unixepoch() WHERE filename = ? - """, (json.dumps(cooldowns), filename)) + """, + (json.dumps(cooldowns), filename), + ) await db.commit() except Exception as e: - log.error(f"Error recording success for {filename}: {e}") \ No newline at end of file + log.error(f"Error recording success for {filename}: {e}") diff --git a/test_anthropic2gemini_tools_schema.py b/test_anthropic2gemini_tools_schema.py index 09bf1ccb7..eadfc5e1c 100644 --- a/test_anthropic2gemini_tools_schema.py +++ b/test_anthropic2gemini_tools_schema.py @@ -1,6 +1,9 @@ import asyncio -from src.converter.anthropic2gemini import anthropic_to_gemini_request +from src.converter.anthropic2gemini import ( + anthropic_to_gemini_request, + _can_use_text_only_fast_path, +) def test_anthropic_tools_schema_shorthand_object_is_normalized(): @@ -30,3 +33,142 @@ def test_anthropic_tools_schema_shorthand_object_is_normalized(): assert params["type"] == "object" assert params["properties"]["key"]["type"] == "string" assert params["properties"]["value"]["type"] == "object" + + +def test_fast_path_guard_accepts_simple_text_messages(): + payload = { + "messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": [{"type": "text", "text": "hi"}]}, + ], + "tools": [], + } + + assert _can_use_text_only_fast_path(payload) is True + + +def test_fast_path_guard_rejects_tool_image_and_thinking_paths(): + tool_payload = { + "messages": [{"role": "user", "content": "hello"}], + "tools": [{"name": "search", "input_schema": {"type": "object"}}], + } + image_payload = { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": "AAAA", + }, + } + ], + } + ] + } + thinking_payload = { + "messages": [{"role": "user", "content": "hello"}], + "thinking": {"type": "enabled", "budget_tokens": 256}, + } + + assert _can_use_text_only_fast_path(tool_payload) is False + assert _can_use_text_only_fast_path(image_payload) is False + assert _can_use_text_only_fast_path(thinking_payload) is False + + +def test_fast_path_text_conversion_keeps_expected_request_shape(): + payload = { + "model": "gemini-2.5-pro", + "max_tokens": 128, + "messages": [ + {"role": "system", "content": "be concise"}, + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": [{"type": "text", "text": "hi"}]}, + ], + "tools": [], + } + + gemini_request = asyncio.run(anthropic_to_gemini_request(payload)) + + assert gemini_request["contents"] == [ + {"role": "user", "parts": [{"text": "hello"}]}, + {"role": "model", "parts": [{"text": "hi"}]}, + ] + assert "tools" not in gemini_request + assert "toolConfig" not in gemini_request + assert gemini_request["systemInstruction"]["parts"][0]["text"] == "be concise" + + +def test_non_fast_path_still_preserves_tool_and_thinking_structure(): + payload = { + "model": "gemini-2.5-pro", + "max_tokens": 256, + "messages": [ + { + "role": "assistant", + "content": [ + { + "type": "thinking", + "thinking": "plan", + "thoughtSignature": "abcdefghijk", + } + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_1", + "name": "fetch_weather", + "input": {"city": "shenzhen"}, + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_1", + "name": "fetch_weather", + "content": [{"type": "text", "text": "sunny"}], + } + ], + }, + ], + "tools": [ + { + "name": "fetch_weather", + "description": "Fetch city weather", + "input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + } + ], + "tool_choice": {"type": "tool", "name": "fetch_weather"}, + "thinking": {"type": "enabled", "budget_tokens": 128}, + } + + gemini_request = asyncio.run(anthropic_to_gemini_request(payload)) + + assert "tools" in gemini_request + assert gemini_request["toolConfig"]["functionCallingConfig"][ + "allowedFunctionNames" + ] == ["fetch_weather"] + assert gemini_request["generationConfig"]["thinkingConfig"]["thinkingBudget"] == 128 + + all_parts = [ + part + for content in gemini_request["contents"] + for part in content.get("parts", []) + if isinstance(part, dict) + ] + assert any(part.get("thought") is True for part in all_parts) + assert any("functionCall" in part for part in all_parts) + assert any("functionResponse" in part for part in all_parts) diff --git a/test_geminicli_stream_request.py b/test_geminicli_stream_request.py index 864b66729..73551475e 100644 --- a/test_geminicli_stream_request.py +++ b/test_geminicli_stream_request.py @@ -5,6 +5,7 @@ from fastapi import Response from src.api import geminicli +from src.api import utils as api_utils from src.models import GeminiRequest from src.router.geminicli import gemini as gemini_router @@ -30,6 +31,10 @@ async def _collect(gen): return items +async def _retry_policy_v2_enabled(): + return True + + def test_stream_request_handles_disable_code_without_unbound_state(monkeypatch): called = {"record_error": 0, "handle_retry": 0} @@ -60,6 +65,7 @@ async def fake_stream_post_async(*args, **kwargs): ) monkeypatch.setattr(geminicli, "credential_manager", _DummyCredentialManager()) + monkeypatch.setattr(geminicli, "get_ff_retry_policy_v2", _retry_policy_v2_enabled) monkeypatch.setattr( geminicli, "get_code_assist_endpoint", fake_get_code_assist_endpoint ) @@ -110,6 +116,7 @@ async def fake_stream_post_async(*args, **kwargs): yield "data: test" monkeypatch.setattr(geminicli, "credential_manager", _DummyCredentialManager()) + monkeypatch.setattr(geminicli, "get_ff_retry_policy_v2", _retry_policy_v2_enabled) monkeypatch.setattr( geminicli, "get_code_assist_endpoint", fake_get_code_assist_endpoint ) @@ -137,6 +144,412 @@ async def fake_stream_post_async(*args, **kwargs): assert captured["timeout"].read is None +def test_stream_request_does_not_retry_after_first_chunk(monkeypatch): + called = {"stream_post": 0, "handle_retry": 0} + + async def fake_get_code_assist_endpoint(): + return "https://example.com" + + async def fake_get_retry_config(): + return {"retry_enabled": True, "max_retries": 1, "retry_interval": 0.0} + + async def fake_get_auto_ban_error_codes(): + return [] + + async def fake_get_proxy_config(): + return None + + async def fake_record_api_call_success(*args, **kwargs): + return None + + async def fake_record_api_call_error(*args, **kwargs): + return None + + async def fake_handle_error_with_retry(*args, **kwargs): + called["handle_retry"] += 1 + return True + + async def fake_stream_post_async(*args, **kwargs): + called["stream_post"] += 1 + yield "data: token" + yield Response( + content=b'{"error":"too many requests"}', + status_code=429, + media_type="application/json", + ) + + monkeypatch.setattr(geminicli, "credential_manager", _DummyCredentialManager()) + monkeypatch.setattr(geminicli, "get_ff_retry_policy_v2", _retry_policy_v2_enabled) + monkeypatch.setattr( + geminicli, "get_code_assist_endpoint", fake_get_code_assist_endpoint + ) + monkeypatch.setattr(geminicli, "get_retry_config", fake_get_retry_config) + monkeypatch.setattr( + geminicli, "get_auto_ban_error_codes", fake_get_auto_ban_error_codes + ) + monkeypatch.setattr(geminicli, "get_proxy_config", fake_get_proxy_config) + monkeypatch.setattr( + geminicli, "record_api_call_success", fake_record_api_call_success + ) + monkeypatch.setattr(geminicli, "record_api_call_error", fake_record_api_call_error) + monkeypatch.setattr( + geminicli, "handle_error_with_retry", fake_handle_error_with_retry + ) + monkeypatch.setattr(geminicli, "stream_post_async", fake_stream_post_async) + + body = { + "model": "gemini-2.5-pro", + "request": {"contents": [{"role": "user", "parts": [{"text": "hi"}]}]}, + } + chunks = asyncio.run(_collect(geminicli.stream_request(body=body, native=False))) + + assert called["stream_post"] == 1 + assert called["handle_retry"] == 0 + assert chunks[0] == "data: token" + assert isinstance(chunks[1], Response) + assert chunks[1].status_code == 429 + + +def test_non_stream_request_waits_once_per_retry_attempt(monkeypatch): + wait_calls = {"count": 0} + responses = [ + httpx.Response( + status_code=429, + json={"error": {"code": 429, "message": "too many requests"}}, + ), + httpx.Response(status_code=200, json={"ok": True}), + ] + + async def fake_get_code_assist_endpoint(): + return "https://example.com" + + async def fake_get_retry_config(): + return {"retry_enabled": True, "max_retries": 1, "retry_interval": 0.0} + + async def fake_get_auto_ban_error_codes(): + return [] + + async def fake_record_api_call_success(*args, **kwargs): + return None + + async def fake_record_api_call_error(*args, **kwargs): + return None + + async def fake_handle_error_with_retry(*args, **kwargs): + wait_calls["count"] += 1 + metrics_ctx = kwargs.get("metrics_ctx") + if isinstance(metrics_ctx, dict): + metrics_ctx["retry_count"] = int(metrics_ctx.get("retry_count", 0) or 0) + 1 + return True + + async def fake_sleep_with_observability(*args, **kwargs): + wait_calls["count"] += 1 + + async def fake_post_async(*args, **kwargs): + return responses.pop(0) + + monkeypatch.setattr(geminicli, "credential_manager", _DummyCredentialManager()) + monkeypatch.setattr(geminicli, "get_ff_retry_policy_v2", _retry_policy_v2_enabled) + monkeypatch.setattr( + geminicli, "get_code_assist_endpoint", fake_get_code_assist_endpoint + ) + monkeypatch.setattr(geminicli, "get_retry_config", fake_get_retry_config) + monkeypatch.setattr( + geminicli, "get_auto_ban_error_codes", fake_get_auto_ban_error_codes + ) + monkeypatch.setattr( + geminicli, "record_api_call_success", fake_record_api_call_success + ) + monkeypatch.setattr(geminicli, "record_api_call_error", fake_record_api_call_error) + monkeypatch.setattr( + geminicli, "handle_error_with_retry", fake_handle_error_with_retry + ) + monkeypatch.setattr( + geminicli, "_sleep_with_observability", fake_sleep_with_observability + ) + monkeypatch.setattr(geminicli, "post_async", fake_post_async) + + body = { + "model": "gemini-2.5-pro", + "request": {"contents": [{"role": "user", "parts": [{"text": "hi"}]}]}, + } + response = asyncio.run(geminicli.non_stream_request(body=body)) + + assert response.status_code == 200 + assert wait_calls["count"] == 1 + + +class _SequentialCredentialManager: + def __init__(self): + self._credentials = [ + ("cred-a.json", {"token": "token-a", "project_id": "project-a"}), + ("cred-b.json", {"token": "token-b", "project_id": "project-b"}), + ] + self._index = 0 + + async def get_valid_credential(self, mode=None, model_name=None): + idx = min(self._index, len(self._credentials) - 1) + self._index += 1 + return self._credentials[idx] + + async def update_credential_state(self, *args, **kwargs): + return None + + async def record_api_call_result(self, *args, **kwargs): + return None + + async def set_cred_disabled(self, *args, **kwargs): + return None + + +def test_stream_request_retry_policy_v2_on_keeps_current_credential(monkeypatch): + captured = {"auth": [], "flags": []} + + async def fake_get_code_assist_endpoint(): + return "https://example.com" + + async def fake_get_retry_config(): + return {"retry_enabled": True, "max_retries": 1, "retry_interval": 0.0} + + async def fake_get_auto_ban_error_codes(): + return [] + + async def fake_get_proxy_config(): + return None + + async def fake_record_api_call_success(*args, **kwargs): + return None + + async def fake_record_api_call_error(*args, **kwargs): + return None + + async def fake_handle_error_with_retry(*args, **kwargs): + captured["flags"].append(kwargs.get("retry_policy_v2_enabled")) + return True + + async def fake_stream_post_async(*args, **kwargs): + captured["auth"].append(kwargs.get("headers", {}).get("Authorization")) + if len(captured["auth"]) == 1: + yield Response( + content=b'{"error":{"message":"too many requests"}}', + status_code=429, + media_type="application/json", + ) + else: + yield "data: ok" + + monkeypatch.setattr(geminicli, "credential_manager", _SequentialCredentialManager()) + monkeypatch.setattr(geminicli, "get_ff_retry_policy_v2", _retry_policy_v2_enabled) + monkeypatch.setattr( + geminicli, "get_code_assist_endpoint", fake_get_code_assist_endpoint + ) + monkeypatch.setattr(geminicli, "get_retry_config", fake_get_retry_config) + monkeypatch.setattr( + geminicli, "get_auto_ban_error_codes", fake_get_auto_ban_error_codes + ) + monkeypatch.setattr(geminicli, "get_proxy_config", fake_get_proxy_config) + monkeypatch.setattr( + geminicli, "record_api_call_success", fake_record_api_call_success + ) + monkeypatch.setattr(geminicli, "record_api_call_error", fake_record_api_call_error) + monkeypatch.setattr( + geminicli, "handle_error_with_retry", fake_handle_error_with_retry + ) + monkeypatch.setattr(geminicli, "stream_post_async", fake_stream_post_async) + + body = { + "model": "gemini-2.5-pro", + "request": {"contents": [{"role": "user", "parts": [{"text": "hi"}]}]}, + } + chunks = asyncio.run(_collect(geminicli.stream_request(body=body, native=False))) + + assert chunks == ["data: ok"] + assert captured["flags"] == [True] + assert captured["auth"] == ["Bearer token-a", "Bearer token-a"] + + +def test_stream_request_retry_policy_v2_off_rotates_credential(monkeypatch): + captured = {"auth": [], "flags": []} + + async def fake_get_ff_retry_policy_v2(): + return False + + async def fake_get_code_assist_endpoint(): + return "https://example.com" + + async def fake_get_retry_config(): + return {"retry_enabled": True, "max_retries": 1, "retry_interval": 0.0} + + async def fake_get_auto_ban_error_codes(): + return [] + + async def fake_get_proxy_config(): + return None + + async def fake_record_api_call_success(*args, **kwargs): + return None + + async def fake_record_api_call_error(*args, **kwargs): + return None + + async def fake_handle_error_with_retry(*args, **kwargs): + captured["flags"].append(kwargs.get("retry_policy_v2_enabled")) + return True + + async def fake_stream_post_async(*args, **kwargs): + captured["auth"].append(kwargs.get("headers", {}).get("Authorization")) + if len(captured["auth"]) == 1: + yield Response( + content=b'{"error":{"message":"too many requests"}}', + status_code=429, + media_type="application/json", + ) + else: + yield "data: ok" + + monkeypatch.setattr( + geminicli, "get_ff_retry_policy_v2", fake_get_ff_retry_policy_v2 + ) + monkeypatch.setattr(geminicli, "credential_manager", _SequentialCredentialManager()) + monkeypatch.setattr( + geminicli, "get_code_assist_endpoint", fake_get_code_assist_endpoint + ) + monkeypatch.setattr(geminicli, "get_retry_config", fake_get_retry_config) + monkeypatch.setattr( + geminicli, "get_auto_ban_error_codes", fake_get_auto_ban_error_codes + ) + monkeypatch.setattr(geminicli, "get_proxy_config", fake_get_proxy_config) + monkeypatch.setattr( + geminicli, "record_api_call_success", fake_record_api_call_success + ) + monkeypatch.setattr(geminicli, "record_api_call_error", fake_record_api_call_error) + monkeypatch.setattr( + geminicli, "handle_error_with_retry", fake_handle_error_with_retry + ) + monkeypatch.setattr(geminicli, "stream_post_async", fake_stream_post_async) + + body = { + "model": "gemini-2.5-pro", + "request": {"contents": [{"role": "user", "parts": [{"text": "hi"}]}]}, + } + chunks = asyncio.run(_collect(geminicli.stream_request(body=body, native=False))) + + assert chunks == ["data: ok"] + assert captured["flags"] == [False] + assert captured["auth"] == ["Bearer token-a", "Bearer token-b"] + + +def test_non_stream_request_passes_retry_policy_v2_flag_to_helper(monkeypatch): + captured = {"flags": []} + responses = [ + httpx.Response( + status_code=429, + json={"error": {"code": 429, "message": "too many requests"}}, + ), + httpx.Response(status_code=200, json={"ok": True}), + ] + + async def fake_get_ff_retry_policy_v2(): + return False + + async def fake_get_code_assist_endpoint(): + return "https://example.com" + + async def fake_get_retry_config(): + return {"retry_enabled": True, "max_retries": 1, "retry_interval": 0.0} + + async def fake_get_auto_ban_error_codes(): + return [] + + async def fake_record_api_call_success(*args, **kwargs): + return None + + async def fake_record_api_call_error(*args, **kwargs): + return None + + async def fake_handle_error_with_retry(*args, **kwargs): + captured["flags"].append(kwargs.get("retry_policy_v2_enabled")) + return True + + async def fake_post_async(*args, **kwargs): + return responses.pop(0) + + monkeypatch.setattr( + geminicli, "get_ff_retry_policy_v2", fake_get_ff_retry_policy_v2 + ) + monkeypatch.setattr(geminicli, "credential_manager", _DummyCredentialManager()) + monkeypatch.setattr( + geminicli, "get_code_assist_endpoint", fake_get_code_assist_endpoint + ) + monkeypatch.setattr(geminicli, "get_retry_config", fake_get_retry_config) + monkeypatch.setattr( + geminicli, "get_auto_ban_error_codes", fake_get_auto_ban_error_codes + ) + monkeypatch.setattr( + geminicli, "record_api_call_success", fake_record_api_call_success + ) + monkeypatch.setattr(geminicli, "record_api_call_error", fake_record_api_call_error) + monkeypatch.setattr( + geminicli, "handle_error_with_retry", fake_handle_error_with_retry + ) + monkeypatch.setattr(geminicli, "post_async", fake_post_async) + + body = { + "model": "gemini-2.5-pro", + "request": {"contents": [{"role": "user", "parts": [{"text": "hi"}]}]}, + } + response = asyncio.run(geminicli.non_stream_request(body=body)) + + assert response.status_code == 200 + assert captured["flags"] == [False] + + +def test_handle_error_with_retry_v2_controls_internal_sleep(monkeypatch): + sleep_calls = [] + + async def fake_check_should_auto_ban(*args, **kwargs): + return False + + async def fake_sleep(seconds): + sleep_calls.append(seconds) + + monkeypatch.setattr(api_utils, "check_should_auto_ban", fake_check_should_auto_ban) + monkeypatch.setattr(api_utils.asyncio, "sleep", fake_sleep) + + should_retry = asyncio.run( + api_utils.handle_error_with_retry( + credential_manager=None, + status_code=429, + credential_name="dummy.json", + retry_enabled=True, + attempt=0, + max_retries=1, + retry_interval=0.3, + retry_policy_v2_enabled=True, + ) + ) + + assert should_retry is True + assert sleep_calls == [0.3] + + sleep_calls.clear() + should_retry = asyncio.run( + api_utils.handle_error_with_retry( + credential_manager=None, + status_code=429, + credential_name="dummy.json", + retry_enabled=True, + attempt=0, + max_retries=1, + retry_interval=0.3, + retry_policy_v2_enabled=False, + ) + ) + + assert should_retry is True + assert sleep_calls == [] + + def test_stream_router_injects_thought_signature_for_function_call(monkeypatch): captured: dict[str, Any] = {"body": None, "native": None} diff --git a/test_httpx_client.py b/test_httpx_client.py index e5b3b38fc..e014544e6 100644 --- a/test_httpx_client.py +++ b/test_httpx_client.py @@ -1,4 +1,7 @@ -from src.httpx_client import _encode_json_body_safely +from typing import Any + +from src import httpx_client +from src.httpx_client import HttpxClientManager, _encode_json_body_safely, post_async def test_encode_json_body_safely_keeps_valid_unicode(): @@ -11,3 +14,133 @@ def test_encode_json_body_safely_falls_back_on_lone_surrogate(): payload = {"x": "\ud83d"} encoded = _encode_json_body_safely(payload) assert encoded.decode("utf-8") == '{"x":"\\ud83d"}' + + +class _FakeAsyncClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.closed = False + self.post_calls = [] + + async def aclose(self): + self.closed = True + + async def post(self, *args, **kwargs): + self.post_calls.append((args, kwargs)) + + +async def test_httpx_manager_reuses_pooled_clients(monkeypatch): + created_clients = [] + + async def fake_get_proxy_config(): + return None + + async def fake_get_ff_http2_pool_tuning(): + return True + + def fake_async_client(**kwargs): + client = _FakeAsyncClient(**kwargs) + created_clients.append(client) + return client + + monkeypatch.setattr(httpx_client, "get_proxy_config", fake_get_proxy_config) + monkeypatch.setattr( + httpx_client, "get_ff_http2_pool_tuning", fake_get_ff_http2_pool_tuning + ) + monkeypatch.setattr(httpx_client.httpx, "AsyncClient", fake_async_client) + + manager = HttpxClientManager() + + first: Any + second: Any + async with manager.get_client() as first_client: + first = first_client + pass + async with manager.get_client() as second_client: + second = second_client + pass + + assert first is second + assert len(created_clients) == 1 + assert created_clients[0].kwargs["http2"] is True + + await manager.close_all_clients() + assert created_clients[0].closed is True + + +async def test_httpx_manager_refreshes_pool_when_proxy_changes(monkeypatch): + created_clients = [] + proxy_state = {"proxy": "http://proxy-1:8888"} + + async def fake_get_proxy_config(): + return proxy_state["proxy"] + + async def fake_get_ff_http2_pool_tuning(): + return True + + def fake_async_client(**kwargs): + client = _FakeAsyncClient(**kwargs) + created_clients.append(client) + return client + + monkeypatch.setattr(httpx_client, "get_proxy_config", fake_get_proxy_config) + monkeypatch.setattr( + httpx_client, "get_ff_http2_pool_tuning", fake_get_ff_http2_pool_tuning + ) + monkeypatch.setattr(httpx_client.httpx, "AsyncClient", fake_async_client) + + manager = HttpxClientManager() + + first: Any + second: Any + async with manager.get_client() as first_client: + first = first_client + pass + + proxy_state["proxy"] = "http://proxy-2:8888" + + async with manager.get_client() as second_client: + second = second_client + pass + + assert first is not second + assert first.closed is True + assert second.kwargs["proxy"] == "http://proxy-2:8888" + + await manager.close_all_clients() + + +async def test_post_async_applies_request_timeout_with_pooled_client(monkeypatch): + created_clients = [] + + async def fake_get_proxy_config(): + return None + + async def fake_get_ff_http2_pool_tuning(): + return True + + def fake_async_client(**kwargs): + client = _FakeAsyncClient(**kwargs) + created_clients.append(client) + return client + + monkeypatch.setattr(httpx_client, "get_proxy_config", fake_get_proxy_config) + monkeypatch.setattr( + httpx_client, "get_ff_http2_pool_tuning", fake_get_ff_http2_pool_tuning + ) + monkeypatch.setattr(httpx_client.httpx, "AsyncClient", fake_async_client) + + original_manager = httpx_client.http_client + manager = HttpxClientManager() + monkeypatch.setattr(httpx_client, "http_client", manager) + + try: + await post_async("https://example.com", data="hello", timeout=123.0) + finally: + await manager.close_all_clients() + monkeypatch.setattr(httpx_client, "http_client", original_manager) + + assert len(created_clients) == 1 + assert created_clients[0].post_calls + _, call_kwargs = created_clients[0].post_calls[0] + assert call_kwargs["timeout"] == 123.0 diff --git a/test_preview_scheduler.py b/test_preview_scheduler.py new file mode 100644 index 000000000..9916cce56 --- /dev/null +++ b/test_preview_scheduler.py @@ -0,0 +1,147 @@ +import asyncio +import time +from typing import Any, cast + +import src.credential_manager as credential_manager_module +from src.credential_manager import CredentialManager +from src.storage.sqlite_manager import SQLiteManager + + +def test_sqlite_preview_selection_keeps_preview_constraints(monkeypatch, tmp_path): + monkeypatch.setenv("CREDENTIALS_DIR", str(tmp_path)) + + async def _run(): + manager = SQLiteManager() + await manager.initialize() + + await manager.store_credential( + "non_preview.json", {"token": "t1", "project_id": "p1"}, mode="geminicli" + ) + await manager.store_credential( + "preview.json", {"token": "t2", "project_id": "p2"}, mode="geminicli" + ) + await manager.update_credential_state( + "non_preview.json", {"preview": False}, mode="geminicli" + ) + await manager.update_credential_state( + "preview.json", {"preview": True}, mode="geminicli" + ) + + preview_pick = await manager.get_next_available_credential( + mode="geminicli", + model_name="gemini-3-pro-preview", + scheduling_hints={"in_flight": {"preview.json": 5}}, + ) + assert preview_pick is not None + assert preview_pick[0] == "preview.json" + + non_preview_pick = await manager.get_next_available_credential( + mode="geminicli", + model_name="gemini-2.5-pro", + scheduling_hints={"in_flight": {"non_preview.json": 0}}, + ) + assert non_preview_pick is not None + assert non_preview_pick[0] == "non_preview.json" + + asyncio.run(_run()) + + +def test_sqlite_preview_health_scoring_uses_429_and_inflight(monkeypatch, tmp_path): + monkeypatch.setenv("CREDENTIALS_DIR", str(tmp_path)) + + async def _run(): + manager = SQLiteManager() + await manager.initialize() + + for name in ("p1.json", "p2.json", "p3.json"): + await manager.store_credential( + name, {"token": name, "project_id": name}, mode="geminicli" + ) + await manager.update_credential_state( + name, {"preview": True}, mode="geminicli" + ) + + await manager.update_credential_state( + "p1.json", {"error_codes": [429]}, mode="geminicli" + ) + + await manager.set_model_cooldown( + "p2.json", "gemini-3-pro-preview", time.time() + 60, mode="geminicli" + ) + + pick = await manager.get_next_available_credential( + mode="geminicli", + model_name="gemini-3-pro-preview", + scheduling_hints={"in_flight": {"p3.json": 0, "p1.json": 0}}, + ) + assert pick is not None + assert pick[0] == "p3.json" + + asyncio.run(_run()) + + +def test_credential_manager_passes_inflight_hints_to_scheduler(monkeypatch): + class _Backend: + def __init__(self): + self.hints = [] + + async def get_next_available_credential( + self, mode="geminicli", model_name=None, scheduling_hints=None + ): + self.hints.append(scheduling_hints or {}) + inflight = (scheduling_hints or {}).get("in_flight", {}) + if inflight.get("a.json", 0) > 0: + return "b.json", { + "token": "tb", + "project_id": "pb", + "expiry": "2999-01-01T00:00:00+00:00", + } + return "a.json", { + "token": "ta", + "project_id": "pa", + "expiry": "2999-01-01T00:00:00+00:00", + } + + class _Adapter: + def __init__(self): + self._backend = _Backend() + + manager = CredentialManager() + manager_any = cast(Any, manager) + manager_any._initialized = True + manager_any._storage_adapter = _Adapter() + + async def _flag_on(): + return True + + monkeypatch.setattr( + credential_manager_module, + "get_ff_preview_credential_scheduler_v2", + _flag_on, + ) + + async def _always_valid_token(_): + return False + + monkeypatch.setattr(manager, "_should_refresh_token", _always_valid_token) + + async def _run(): + first = await manager.get_valid_credential( + mode="geminicli", model_name="gemini-3-pro-preview" + ) + second = await manager.get_valid_credential( + mode="geminicli", model_name="gemini-3-pro-preview" + ) + + assert first is not None and second is not None + assert first[0] == "a.json" + assert second[0] == "b.json" + + hints = manager_any._storage_adapter._backend.hints + assert hints[0].get("in_flight", {}) == {} + assert hints[1].get("in_flight", {}).get("a.json") == 1 + + await manager._release_inflight("geminicli", "a.json") + await manager._release_inflight("geminicli", "b.json") + + asyncio.run(_run()) diff --git a/test_rollout_guard.py b/test_rollout_guard.py new file mode 100644 index 000000000..c42ef8d25 --- /dev/null +++ b/test_rollout_guard.py @@ -0,0 +1,347 @@ +import asyncio +import json +from pathlib import Path + +from scripts.perf import rollout_guard + + +def _write_json(path: Path, payload: dict) -> str: + path.write_text(json.dumps(payload, ensure_ascii=False), encoding="utf-8") + return str(path) + + +def _perf_report(*, full_latency_p95: float, reqps: float, tokensps: float) -> dict: + return { + "meta": {"duration": 30}, + "summary": { + "request_count": 12, + "full_latency_ms": {"p95": full_latency_p95}, + "reqps": reqps, + "tokensps": tokensps, + }, + "samples": [{"full_latency_ms": full_latency_p95, "total_tokens": 100}], + } + + +def _quality_report(*, quality_score: float) -> dict: + return { + "summary": { + "quality_score": quality_score, + } + } + + +def test_rollout_guard_all_pass_promotes_next_stage(tmp_path: Path): + before_perf = _write_json( + tmp_path / "before_perf.json", + _perf_report(full_latency_p95=1800, reqps=10.0, tokensps=100.0), + ) + after_perf = _write_json( + tmp_path / "after_perf.json", + _perf_report(full_latency_p95=1400, reqps=9.0, tokensps=92.0), + ) + baseline_quality = _write_json( + tmp_path / "baseline_quality.json", _quality_report(quality_score=90.0) + ) + candidate_quality = _write_json( + tmp_path / "candidate_quality.json", _quality_report(quality_score=85.0) + ) + + result = asyncio.run( + rollout_guard.evaluate_rollout_decision( + before_perf_path=before_perf, + after_perf_path=after_perf, + baseline_quality_path=baseline_quality, + candidate_quality_path=candidate_quality, + rollout_stage_percent=20, + rollback_trigger_latency_p95_ms=2000, + rollback_trigger_throughput_drop_pct=20, + rollback_trigger_quality_drop_pct=10, + apply=False, + ) + ) + + assert result["decision"] == "PROMOTE" + assert result["current_stage_percent"] == 20 + assert result["next_stage_percent"] == 50 + assert result["target_stage_percent"] == 50 + assert result["failed_gates"] == [] + assert result["blocked_gates"] == [] + assert ( + result["thresholds"]["latency_policy_mode"] + == rollout_guard.LATENCY_POLICY_MODE_ABSOLUTE_P95_CAP + ) + assert result["thresholds"]["rollback_trigger_latency_p95_improve_pct"] == 0.0 + assert ( + result["gates"]["latency"]["latency_policy_mode"] + == rollout_guard.LATENCY_POLICY_MODE_ABSOLUTE_P95_CAP + ) + assert result["dry_run"] is True + assert result["applied"] is False + + +def test_rollout_guard_gate_fail_rolls_back(tmp_path: Path): + before_perf = _write_json( + tmp_path / "before_perf.json", + _perf_report(full_latency_p95=1700, reqps=10.0, tokensps=100.0), + ) + after_perf = _write_json( + tmp_path / "after_perf.json", + _perf_report(full_latency_p95=2800, reqps=9.5, tokensps=99.0), + ) + baseline_quality = _write_json( + tmp_path / "baseline_quality.json", _quality_report(quality_score=90.0) + ) + candidate_quality = _write_json( + tmp_path / "candidate_quality.json", _quality_report(quality_score=89.0) + ) + + result = asyncio.run( + rollout_guard.evaluate_rollout_decision( + before_perf_path=before_perf, + after_perf_path=after_perf, + baseline_quality_path=baseline_quality, + candidate_quality_path=candidate_quality, + rollout_stage_percent=50, + rollback_trigger_latency_p95_ms=2500, + rollback_trigger_throughput_drop_pct=20, + rollback_trigger_quality_drop_pct=10, + apply=False, + ) + ) + + assert result["decision"] == "ROLLBACK" + assert result["current_stage_percent"] == 50 + assert result["rollback_stage_percent"] == 20 + assert result["target_stage_percent"] == 20 + assert "latency" in result["failed_gates"] + assert result["blocked_gates"] == [] + assert ( + result["gates"]["latency"]["latency_policy_mode"] + == rollout_guard.LATENCY_POLICY_MODE_ABSOLUTE_P95_CAP + ) + + +def test_rollout_guard_blocked_signal_holds_even_with_apply_requested(tmp_path: Path): + before_payload = _perf_report(full_latency_p95=1700, reqps=10.0, tokensps=100.0) + after_payload = _perf_report(full_latency_p95=1400, reqps=9.5, tokensps=95.0) + after_payload["meta"]["blocker"] = "artifact_incomplete" + + before_perf = _write_json(tmp_path / "before_perf.json", before_payload) + after_perf = _write_json(tmp_path / "after_perf.json", after_payload) + baseline_quality = _write_json( + tmp_path / "baseline_quality.json", _quality_report(quality_score=90.0) + ) + candidate_quality = _write_json( + tmp_path / "candidate_quality.json", _quality_report(quality_score=89.0) + ) + + result = asyncio.run( + rollout_guard.evaluate_rollout_decision( + before_perf_path=before_perf, + after_perf_path=after_perf, + baseline_quality_path=baseline_quality, + candidate_quality_path=candidate_quality, + rollout_stage_percent=20, + rollback_trigger_latency_p95_ms=2500, + rollback_trigger_throughput_drop_pct=20, + rollback_trigger_quality_drop_pct=10, + apply=True, + ) + ) + + assert result["decision"] == "HOLD_BLOCKED" + assert result["target_stage_percent"] == 20 + assert "latency" in result["blocked_gates"] + assert "throughput" in result["blocked_gates"] + assert result["applied"] is False + assert result["apply_skipped_reason"] == "decision_hold_blocked" + + +def test_rollout_guard_relative_latency_mode_promotes_on_required_improvement( + tmp_path: Path, +): + before_perf = _write_json( + tmp_path / "before_perf.json", + _perf_report(full_latency_p95=5000, reqps=10.0, tokensps=100.0), + ) + after_perf = _write_json( + tmp_path / "after_perf.json", + _perf_report(full_latency_p95=3500, reqps=9.5, tokensps=95.0), + ) + baseline_quality = _write_json( + tmp_path / "baseline_quality.json", _quality_report(quality_score=90.0) + ) + candidate_quality = _write_json( + tmp_path / "candidate_quality.json", _quality_report(quality_score=89.0) + ) + + result = asyncio.run( + rollout_guard.evaluate_rollout_decision( + before_perf_path=before_perf, + after_perf_path=after_perf, + baseline_quality_path=baseline_quality, + candidate_quality_path=candidate_quality, + rollout_stage_percent=20, + rollback_trigger_latency_p95_ms=2500, + latency_policy_mode=rollout_guard.LATENCY_POLICY_MODE_RELATIVE_FULL_P95_IMPROVE, + rollback_trigger_latency_p95_improve_pct=20, + rollback_trigger_throughput_drop_pct=20, + rollback_trigger_quality_drop_pct=10, + apply=False, + ) + ) + + assert result["decision"] == "PROMOTE" + assert result["failed_gates"] == [] + assert result["blocked_gates"] == [] + assert ( + result["thresholds"]["latency_policy_mode"] + == rollout_guard.LATENCY_POLICY_MODE_RELATIVE_FULL_P95_IMPROVE + ) + assert result["thresholds"]["rollback_trigger_latency_p95_improve_pct"] == 20.0 + assert ( + result["gates"]["latency"]["reason"] + == "full_latency_p95_improve_within_threshold" + ) + assert result["gates"]["latency"]["full_latency_p95_improve_percent"] == 30.0 + + +def test_rollout_guard_relative_latency_mode_fails_when_improvement_is_too_small( + tmp_path: Path, +): + before_perf = _write_json( + tmp_path / "before_perf.json", + _perf_report(full_latency_p95=5000, reqps=10.0, tokensps=100.0), + ) + after_perf = _write_json( + tmp_path / "after_perf.json", + _perf_report(full_latency_p95=4800, reqps=9.9, tokensps=99.0), + ) + baseline_quality = _write_json( + tmp_path / "baseline_quality.json", _quality_report(quality_score=90.0) + ) + candidate_quality = _write_json( + tmp_path / "candidate_quality.json", _quality_report(quality_score=89.0) + ) + + result = asyncio.run( + rollout_guard.evaluate_rollout_decision( + before_perf_path=before_perf, + after_perf_path=after_perf, + baseline_quality_path=baseline_quality, + candidate_quality_path=candidate_quality, + rollout_stage_percent=50, + rollback_trigger_latency_p95_ms=2500, + latency_policy_mode=rollout_guard.LATENCY_POLICY_MODE_RELATIVE_FULL_P95_IMPROVE, + rollback_trigger_latency_p95_improve_pct=10, + rollback_trigger_throughput_drop_pct=20, + rollback_trigger_quality_drop_pct=10, + apply=False, + ) + ) + + assert result["decision"] == "ROLLBACK" + assert result["target_stage_percent"] == 20 + assert "latency" in result["failed_gates"] + assert result["gates"]["latency"]["status"] == "FAIL" + assert ( + result["gates"]["latency"]["reason"] + == "full_latency_p95_improve_below_threshold" + ) + + +def test_stage_ladder_bounds_floor_and_ceiling(): + assert rollout_guard.get_previous_stage_percent(5) == 5 + assert rollout_guard.get_next_stage_percent(100) == 100 + + +def test_stage_percent_to_feature_flags_mapping(): + assert rollout_guard.stage_percent_to_feature_flags(5) == { + "ff_retry_policy_v2": True, + "ff_http2_pool_tuning": False, + "ff_converter_fast_path": False, + "ff_preview_credential_scheduler_v2": False, + } + assert rollout_guard.stage_percent_to_feature_flags(20) == { + "ff_retry_policy_v2": True, + "ff_http2_pool_tuning": True, + "ff_converter_fast_path": False, + "ff_preview_credential_scheduler_v2": False, + } + assert rollout_guard.stage_percent_to_feature_flags(50) == { + "ff_retry_policy_v2": True, + "ff_http2_pool_tuning": True, + "ff_converter_fast_path": True, + "ff_preview_credential_scheduler_v2": False, + } + assert rollout_guard.stage_percent_to_feature_flags(100) == { + "ff_retry_policy_v2": True, + "ff_http2_pool_tuning": True, + "ff_converter_fast_path": True, + "ff_preview_credential_scheduler_v2": True, + } + + +def test_apply_mode_persists_stage_and_feature_flags(monkeypatch, tmp_path: Path): + before_perf = _write_json( + tmp_path / "before_perf.json", + _perf_report(full_latency_p95=1800, reqps=10.0, tokensps=100.0), + ) + after_perf = _write_json( + tmp_path / "after_perf.json", + _perf_report(full_latency_p95=1400, reqps=9.0, tokensps=95.0), + ) + baseline_quality = _write_json( + tmp_path / "baseline_quality.json", _quality_report(quality_score=90.0) + ) + candidate_quality = _write_json( + tmp_path / "candidate_quality.json", _quality_report(quality_score=89.0) + ) + + class _FakeStorageAdapter: + def __init__(self): + self.calls = [] + + async def set_config(self, key, value): + self.calls.append((key, value)) + return True + + fake_storage = _FakeStorageAdapter() + + async def _fake_get_storage_adapter(): + return fake_storage + + reload_calls = {"count": 0} + + async def _fake_reload_config(): + reload_calls["count"] += 1 + + monkeypatch.setattr(rollout_guard, "get_storage_adapter", _fake_get_storage_adapter) + monkeypatch.setattr(rollout_guard, "reload_config", _fake_reload_config) + + result = asyncio.run( + rollout_guard.evaluate_rollout_decision( + before_perf_path=before_perf, + after_perf_path=after_perf, + baseline_quality_path=baseline_quality, + candidate_quality_path=candidate_quality, + rollout_stage_percent=5, + rollback_trigger_latency_p95_ms=2000, + rollback_trigger_throughput_drop_pct=20, + rollback_trigger_quality_drop_pct=10, + apply=True, + ) + ) + + assert result["decision"] == "PROMOTE" + assert result["target_stage_percent"] == 20 + assert result["applied"] is True + assert reload_calls["count"] == 1 + assert dict(fake_storage.calls) == { + "rollout_stage_percent": 20, + "ff_retry_policy_v2": True, + "ff_http2_pool_tuning": True, + "ff_converter_fast_path": False, + "ff_preview_credential_scheduler_v2": False, + } From de98f1158fb2aa01ca0e54d9888c63d5c7d4e531 Mon Sep 17 00:00:00 2001 From: CI User Date: Wed, 4 Mar 2026 00:41:15 +0800 Subject: [PATCH 42/47] fix: retry thinking-only stream interruptions and normalize anthropic stream errors --- src/api/geminicli.py | 218 +++++++++++++++++++++++--- src/converter/anthropic2gemini.py | 88 ++++++++++- src/router/geminicli/anthropic.py | 23 +-- test_anthropic2gemini_tools_schema.py | 45 ++++++ test_geminicli_stream_request.py | 119 +++++++++++++- 5 files changed, 440 insertions(+), 53 deletions(-) diff --git a/src/api/geminicli.py b/src/api/geminicli.py index 5e3d622dd..8c32b1130 100644 --- a/src/api/geminicli.py +++ b/src/api/geminicli.py @@ -95,6 +95,75 @@ def _compute_capacity_retry_delay(base_interval: float, attempt: int) -> float: return min(12.0, exp_backoff + jitter) +def _stream_chunk_has_committed_output(chunk: Any) -> bool: + """判断流式 chunk 是否已经包含对客户端可见且不可安全重试的输出。 + + 仅把“正文文本 / 工具调用 / 工具结果”视为已提交输出。 + thinking-only chunk 不视为提交,允许在传输异常时重试。 + """ + raw_line: Optional[str] + if isinstance(chunk, str): + raw_line = chunk + elif isinstance(chunk, (bytes, bytearray, memoryview)): + try: + raw_line = bytes(chunk).decode("utf-8", errors="ignore") + except Exception: + return False + else: + return False + + line = raw_line.strip() + if not line.startswith("data: "): + return False + + payload = line[6:].strip() + if not payload or payload == "[DONE]": + return False + + try: + data = json.loads(payload) + except Exception: + return False + + if not isinstance(data, dict): + return False + + response_obj: Dict[str, Any] + if isinstance(data.get("response"), dict): + response_obj = data["response"] + else: + response_obj = data + + candidates = response_obj.get("candidates") + if not isinstance(candidates, list): + return False + + for candidate in candidates: + if not isinstance(candidate, dict): + continue + content_obj = candidate.get("content") + if not isinstance(content_obj, dict): + continue + parts = content_obj.get("parts") + if not isinstance(parts, list): + continue + + for part in parts: + if not isinstance(part, dict): + continue + if "functionCall" in part or "functionResponse" in part: + return True + text = part.get("text") + if ( + isinstance(text, str) + and text.strip() + and part.get("thought") is not True + ): + return True + + return False + + _OBS_METRIC_KEYS = ( "t_req_in", "t_upstream_send", @@ -369,6 +438,15 @@ async def stream_request( except Exception as e: log.error(f"准备请求失败: {e}") + await record_api_call_error( + credential_manager, + current_file, + 500, + None, + mode="geminicli", + model_name=model_name, + error_message=str(e), + ) _mark_status(metrics_ctx, 500) _mark_timestamp(metrics_ctx, "t_done", overwrite=True) resp = Response( @@ -419,7 +497,8 @@ async def refresh_credential_fast() -> bool: return False for attempt in range(max_retries + 1): - success_recorded = False # 标记是否已记录成功 + success_recorded = False # 标记是否已收到首个上游chunk + committed_output = False # 标记是否已输出正文/工具块(不可安全重试) need_retry = False # 标记是否需要重试 keep_current_credential = False # 标记是否保留当前凭证 retry_wait_applied = False # 标记本次attempt是否已执行等待 @@ -441,16 +520,6 @@ async def refresh_credential_fast() -> bool: status_code = chunk.status_code last_error_response = chunk # 记录最后一次错误 - if success_recorded: - log.warning( - f"[GEMINICLI STREAM] 首token后收到错误响应,禁止重试 (status={status_code})" - ) - _mark_status(metrics_ctx, chunk.status_code) - _mark_timestamp(metrics_ctx, "t_done", overwrite=True) - _inject_observability_headers(chunk, metrics_ctx) - yield chunk - return - # 缓存错误解析结果,避免重复decode error_body = None try: @@ -462,6 +531,29 @@ async def refresh_credential_fast() -> bool: except Exception: error_body = "" + if success_recorded and committed_output: + log.warning( + f"[GEMINICLI STREAM] 首token后收到错误响应且已输出正文,禁止重试 (status={status_code})" + ) + await record_api_call_error( + credential_manager, + current_file, + status_code, + None, + mode="geminicli", + model_name=model_name, + error_message=error_body, + ) + _mark_status(metrics_ctx, chunk.status_code) + _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + _inject_observability_headers(chunk, metrics_ctx) + yield chunk + return + if success_recorded and not committed_output: + log.warning( + "[GEMINICLI STREAM] 首token后仅收到thinking阶段内容后报错,允许继续重试" + ) + # 如果错误码是429、503或者在禁用码当中,做好记录后进行重试 if ( status_code == 429 @@ -627,24 +719,37 @@ async def refresh_credential_fast() -> bool: return else: # 不是Response,说明是真流,直接yield返回 - # 只在第一个chunk时记录成功 + # 只在第一个chunk时打点首token if not success_recorded: _mark_timestamp(metrics_ctx, "t_first_token") + success_recorded = True + log.debug( + f"[GEMINICLI STREAM] 开始接收流式响应,模型: {model_name}" + ) + + # 仅当出现正文/工具块时才记为成功(thinking-only 不算提交输出) + if not committed_output and _stream_chunk_has_committed_output( + chunk + ): await record_api_call_success( credential_manager, current_file, mode="geminicli", model_name=model_name, ) - success_recorded = True - log.debug( - f"[GEMINICLI STREAM] 开始接收流式响应,模型: {model_name}" - ) + committed_output = True yield chunk # 流式请求完成,检查结果 if success_recorded: + if not committed_output: + await record_api_call_success( + credential_manager, + current_file, + mode="geminicli", + model_name=model_name, + ) log.debug(f"[GEMINICLI STREAM] 流式响应完成,模型: {model_name}") _mark_status(metrics_ctx, 200) _mark_timestamp(metrics_ctx, "t_done", overwrite=True) @@ -713,19 +818,49 @@ async def refresh_credential_fast() -> bool: except Exception as e: elapsed = time.time() - attempt_started_at + error_detail = repr(e) log.error( f"[GEMINICLI STREAM] 流式请求异常: type={type(e).__name__}, " - f"detail={repr(e)}, 模型: {model_name}, 尝试: {attempt + 1}/{max_retries + 1}, " + f"detail={error_detail}, 模型: {model_name}, 尝试: {attempt + 1}/{max_retries + 1}, " f"耗时: {elapsed:.2f}s, url: {target_url}, proxy={'on' if proxy_enabled else 'off'}, " f"http2=True, timeout(connect/write/pool=30s, read=None), 凭证: {current_file}" ) - if success_recorded: - log.warning("[GEMINICLI STREAM] 首token后发生异常,停止流且不重试") - _mark_status(metrics_ctx, 500) + if success_recorded and committed_output: + log.warning( + "[GEMINICLI STREAM] 首token后发生异常且已输出正文,停止流且不重试" + ) + await record_api_call_error( + credential_manager, + current_file, + 503, + None, + mode="geminicli", + model_name=model_name, + error_message=error_detail, + ) + _mark_status(metrics_ctx, 502) _mark_timestamp(metrics_ctx, "t_done", overwrite=True) + resp = Response( + content=json.dumps({"error": "流式响应中断", "detail": str(e)}), + status_code=502, + media_type="application/json", + ) + _inject_observability_headers(resp, metrics_ctx) + yield resp return + if success_recorded and not committed_output: + log.warning("[GEMINICLI STREAM] 首token后仅thinking阶段异常,允许重试") if attempt < max_retries: delay = _compute_capacity_retry_delay(retry_interval, attempt) + await record_api_call_error( + credential_manager, + current_file, + 503, + None, + mode="geminicli", + model_name=model_name, + error_message=error_detail, + ) log.info( f"[GEMINICLI STREAM] 异常后重试 (attempt {attempt + 2}/{max_retries + 1}), " f"等待 {delay:.1f}s..." @@ -736,7 +871,16 @@ async def refresh_credential_fast() -> bool: # 所有重试都失败,返回最后一次的错误(如果有) log.error( f"[GEMINICLI STREAM] 所有重试均失败,最后异常: type={type(e).__name__}, " - f"detail={repr(e)}" + f"detail={error_detail}" + ) + await record_api_call_error( + credential_manager, + current_file, + 503, + None, + mode="geminicli", + model_name=model_name, + error_message=error_detail, ) if last_error_response: _mark_status(metrics_ctx, last_error_response.status_code) @@ -819,6 +963,15 @@ async def non_stream_request( except Exception as e: log.error(f"准备请求失败: {e}") + await record_api_call_error( + credential_manager, + current_file, + 500, + None, + mode="geminicli", + model_name=model_name, + error_message=str(e), + ) _mark_status(metrics_ctx, 500) _mark_timestamp(metrics_ctx, "t_done", overwrite=True) resp = Response( @@ -1173,12 +1326,22 @@ async def refresh_credential_fast() -> bool: return last_error_response except Exception as e: + error_detail = repr(e) log.error( - f"非流式请求异常: type={type(e).__name__}, detail={repr(e)}, " + f"非流式请求异常: type={type(e).__name__}, detail={error_detail}, " f"凭证: {current_file}" ) if attempt < max_retries: delay = _compute_capacity_retry_delay(retry_interval, attempt) + await record_api_call_error( + credential_manager, + current_file, + 503, + None, + mode="geminicli", + model_name=model_name, + error_message=error_detail, + ) log.info( f"[NON-STREAM] 异常后重试 (attempt {attempt + 2}/{max_retries + 1}), " f"等待 {delay:.1f}s..." @@ -1189,7 +1352,16 @@ async def refresh_credential_fast() -> bool: # 所有重试都失败,返回最后一次的错误(如果有)或500错误 log.error( f"[NON-STREAM] 所有重试均失败,最后异常: type={type(e).__name__}, " - f"detail={repr(e)}" + f"detail={error_detail}" + ) + await record_api_call_error( + credential_manager, + current_file, + 503, + None, + mode="geminicli", + model_name=model_name, + error_message=error_detail, ) if last_error_response: _mark_status(metrics_ctx, last_error_response.status_code) diff --git a/src/converter/anthropic2gemini.py b/src/converter/anthropic2gemini.py index cf8ff1d9e..0db0fd17d 100644 --- a/src/converter/anthropic2gemini.py +++ b/src/converter/anthropic2gemini.py @@ -1101,7 +1101,7 @@ def gemini_to_anthropic_response( async def gemini_stream_to_anthropic_stream( - gemini_stream: AsyncIterator[bytes], model: str, status_code: int = 200 + gemini_stream: AsyncIterator[Any], model: str, status_code: int = 200 ) -> AsyncIterator[bytes]: """ 将 Gemini 格式流式响应转换为 Anthropic SSE 格式流式响应 @@ -1109,7 +1109,7 @@ async def gemini_stream_to_anthropic_stream( 注意: 如果收到的不是 200 开头的响应体,不做任何处理,直接转发 Args: - gemini_stream: Gemini 格式的流式响应 (bytes 迭代器) + gemini_stream: Gemini 格式的流式响应(通常为 bytes,也可能包含 Response 错误对象) model: 模型名称 status_code: HTTP 状态码 (默认 200) @@ -1132,6 +1132,7 @@ async def gemini_stream_to_anthropic_stream( input_tokens = 0 output_tokens = 0 finish_reason: Optional[str] = None + received_done_marker = False debug_logging_enabled = log.get_current_level() == "debug" def _sse_event(event: str, data: Dict[str, Any]) -> bytes: @@ -1151,21 +1152,77 @@ def _close_block() -> Optional[bytes]: current_block_type = None return event + def _build_error_event_payload(raw_error: Any) -> Dict[str, Any]: + """将上游错误载荷规范化为 Anthropic error 事件。""" + if isinstance(raw_error, dict): + if raw_error.get("type") == "error" and isinstance( + raw_error.get("error"), dict + ): + message_value = raw_error["error"].get("message") + if isinstance(message_value, str) and message_value: + return raw_error + + error_obj = raw_error.get("error") + if isinstance(error_obj, dict): + message_value = error_obj.get("message") + if isinstance(message_value, str) and message_value: + return { + "type": "error", + "error": {"type": "api_error", "message": message_value}, + } + + return { + "type": "error", + "error": {"type": "api_error", "message": str(raw_error)}, + } + # 处理流式数据 try: async for chunk in gemini_stream: # 检查是否是 Response 对象(错误情况) if isinstance(chunk, Response): log.warning( - f"[GEMINI_TO_ANTHROPIC] 收到 Response 对象,状态码: {chunk.status_code},直接转发错误" + f"[GEMINI_TO_ANTHROPIC] 收到 Response 对象,状态码: {chunk.status_code},按错误事件返回" ) - # 直接转发错误响应内容,不做格式转换 + close_evt = _close_block() + if close_evt: + yield close_evt + body = chunk.body if isinstance(body, (bytes, bytearray, memoryview)): - error_content = bytes(body) + raw_text = bytes(body).decode("utf-8", errors="ignore") else: - error_content = str(body or "").encode("utf-8") - yield error_content + raw_text = str(body or "") + + parsed_error: Any + try: + parsed_error = json.loads(raw_text) + except Exception: + parsed_error = raw_text + + if not message_start_sent: + message_start_sent = True + yield _sse_event( + "message_start", + { + "type": "message_start", + "message": { + "id": message_id, + "type": "message", + "role": "assistant", + "model": model, + "content": [], + "stop_reason": None, + "stop_sequence": None, + "usage": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + }, + }, + }, + ) + + yield _sse_event("error", _build_error_event_payload(parsed_error)) return # 记录接收到的原始chunk @@ -1184,6 +1241,7 @@ def _close_block() -> Optional[bytes]: raw = chunk[6:].rstrip() if raw == b"[DONE]": + received_done_marker = True if debug_logging_enabled: log.debug("[GEMINI_TO_ANTHROPIC] Received [DONE] marker") break @@ -1429,6 +1487,22 @@ def _close_block() -> Optional[bytes]: if close_evt: yield close_evt + if message_start_sent and finish_reason is None and not received_done_marker: + log.warning( + "[ANTHROPIC][stream_end] 上游流在无finish_reason情况下结束,按错误事件返回" + ) + yield _sse_event( + "error", + { + "type": "error", + "error": { + "type": "api_error", + "message": "upstream stream ended before finishReason", + }, + }, + ) + return + # 确定停止原因 # 只有在正常停止(STOP)且有工具调用时才设为 tool_use # 避免在 SAFETY、MAX_TOKENS 等情况下仍然返回 tool_use 导致循环 diff --git a/src/router/geminicli/anthropic.py b/src/router/geminicli/anthropic.py index c54f7c321..8bcac4220 100644 --- a/src/router/geminicli/anthropic.py +++ b/src/router/geminicli/anthropic.py @@ -368,27 +368,8 @@ async def gemini_chunk_wrapper(): async for chunk in stream_gen: # 检查是否是Response对象(错误情况) if isinstance(chunk, Response): - # 错误响应,不进行转换,直接传递 - try: - raw_body = chunk.body - if isinstance(raw_body, (bytes, bytearray, memoryview)): - error_content = bytes(raw_body) - else: - error_content = str(raw_body or "").encode("utf-8") - gemini_error = json.loads(error_content.decode("utf-8")) - from src.converter.anthropic2gemini import ( - gemini_to_anthropic_response, - ) - - anthropic_error = gemini_to_anthropic_response( - gemini_error, real_model, chunk.status_code - ) - yield f"data: {json.dumps(anthropic_error)}\n\n".encode("utf-8") - except Exception: - yield f"data: {json.dumps({'type': 'error', 'error': {'type': 'api_error', 'message': 'Stream error'}})}\n\n".encode( - "utf-8" - ) - yield b"data: [DONE]\n\n" + # 让转换器统一处理错误,不在此处伪造 [DONE] + yield chunk return else: # 确保是bytes类型 diff --git a/test_anthropic2gemini_tools_schema.py b/test_anthropic2gemini_tools_schema.py index eadfc5e1c..37a9bd6dc 100644 --- a/test_anthropic2gemini_tools_schema.py +++ b/test_anthropic2gemini_tools_schema.py @@ -1,8 +1,10 @@ import asyncio +from fastapi import Response from src.converter.anthropic2gemini import ( anthropic_to_gemini_request, _can_use_text_only_fast_path, + gemini_stream_to_anthropic_stream, ) @@ -172,3 +174,46 @@ def test_non_fast_path_still_preserves_tool_and_thinking_structure(): assert any(part.get("thought") is True for part in all_parts) assert any("functionCall" in part for part in all_parts) assert any("functionResponse" in part for part in all_parts) + + +def test_gemini_stream_interrupted_without_finish_reason_emits_error_event(): + async def fake_stream(): + yield b'data: {"response": {"candidates": [{"content": {"parts": [{"text": "partial answer"}]}}]}}\n\n' + + async def collect_stream_output(): + chunks = [] + async for chunk in gemini_stream_to_anthropic_stream( + fake_stream(), "gemini-3-flash-preview", 200 + ): + chunks.append(chunk.decode("utf-8")) + return "".join(chunks) + + output = asyncio.run(collect_stream_output()) + + assert "event: error" in output + assert "upstream stream ended before finishReason" in output + assert "event: message_delta" not in output + + +def test_gemini_stream_response_error_chunk_emits_error_event(): + async def fake_stream(): + yield Response( + content=b'{"error":{"message":"stream broken"}}', + status_code=502, + media_type="application/json", + ) + + async def collect_stream_output(): + chunks = [] + async for chunk in gemini_stream_to_anthropic_stream( + fake_stream(), "gemini-3-flash-preview", 200 + ): + chunks.append(chunk.decode("utf-8")) + return "".join(chunks) + + output = asyncio.run(collect_stream_output()) + + assert "event: message_start" in output + assert "event: error" in output + assert "stream broken" in output + assert "event: message_delta" not in output diff --git a/test_geminicli_stream_request.py b/test_geminicli_stream_request.py index 73551475e..60097f7f6 100644 --- a/test_geminicli_stream_request.py +++ b/test_geminicli_stream_request.py @@ -1,4 +1,5 @@ import asyncio +import json from typing import Any import httpx @@ -171,7 +172,23 @@ async def fake_handle_error_with_retry(*args, **kwargs): async def fake_stream_post_async(*args, **kwargs): called["stream_post"] += 1 - yield "data: token" + yield ( + "data: " + + json.dumps( + { + "response": { + "candidates": [ + { + "content": { + "parts": [{"text": "final output"}], + } + } + ] + } + }, + ensure_ascii=False, + ) + ) yield Response( content=b'{"error":"too many requests"}', status_code=429, @@ -205,11 +222,109 @@ async def fake_stream_post_async(*args, **kwargs): assert called["stream_post"] == 1 assert called["handle_retry"] == 0 - assert chunks[0] == "data: token" + assert "final output" in chunks[0] assert isinstance(chunks[1], Response) assert chunks[1].status_code == 429 +def test_stream_request_retries_when_exception_after_thinking_only(monkeypatch): + called = {"stream_post": 0, "record_error": 0, "record_success": 0} + + async def fake_get_code_assist_endpoint(): + return "https://example.com" + + async def fake_get_retry_config(): + return {"retry_enabled": True, "max_retries": 1, "retry_interval": 0.0} + + async def fake_get_auto_ban_error_codes(): + return [] + + async def fake_get_proxy_config(): + return None + + async def fake_record_api_call_success(*args, **kwargs): + called["record_success"] += 1 + + async def fake_record_api_call_error(*args, **kwargs): + called["record_error"] += 1 + + async def fake_sleep_with_observability(*args, **kwargs): + return None + + async def fake_stream_post_async(*args, **kwargs): + called["stream_post"] += 1 + if called["stream_post"] == 1: + yield ( + "data: " + + json.dumps( + { + "response": { + "candidates": [ + { + "content": { + "parts": [ + {"text": "thinking-only", "thought": True} + ] + } + } + ] + } + }, + ensure_ascii=False, + ) + ) + raise httpx.ReadError("stream interrupted") + + yield ( + "data: " + + json.dumps( + { + "response": { + "candidates": [ + { + "content": { + "parts": [{"text": "final answer"}], + }, + "finishReason": "STOP", + } + ] + } + }, + ensure_ascii=False, + ) + ) + + monkeypatch.setattr(geminicli, "credential_manager", _DummyCredentialManager()) + monkeypatch.setattr(geminicli, "get_ff_retry_policy_v2", _retry_policy_v2_enabled) + monkeypatch.setattr( + geminicli, "get_code_assist_endpoint", fake_get_code_assist_endpoint + ) + monkeypatch.setattr(geminicli, "get_retry_config", fake_get_retry_config) + monkeypatch.setattr( + geminicli, "get_auto_ban_error_codes", fake_get_auto_ban_error_codes + ) + monkeypatch.setattr(geminicli, "get_proxy_config", fake_get_proxy_config) + monkeypatch.setattr( + geminicli, "record_api_call_success", fake_record_api_call_success + ) + monkeypatch.setattr(geminicli, "record_api_call_error", fake_record_api_call_error) + monkeypatch.setattr(geminicli, "stream_post_async", fake_stream_post_async) + monkeypatch.setattr( + geminicli, "_sleep_with_observability", fake_sleep_with_observability + ) + + body = { + "model": "gemini-2.5-pro", + "request": {"contents": [{"role": "user", "parts": [{"text": "hi"}]}]}, + } + chunks = asyncio.run(_collect(geminicli.stream_request(body=body, native=False))) + + assert called["stream_post"] == 2 + assert called["record_error"] == 1 + assert called["record_success"] == 1 + assert any("final answer" in str(item) for item in chunks) + + def test_non_stream_request_waits_once_per_retry_attempt(monkeypatch): wait_calls = {"count": 0} responses = [ From 366ac5823f7101429c37de64e7e4d81e1b793b00 Mon Sep 17 00:00:00 2001 From: CI User Date: Wed, 11 Mar 2026 21:18:41 +0800 Subject: [PATCH 43/47] fix(converter): group parallel tool calls into single turn pair Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode) Co-authored-by: Sisyphus --- src/converter/anthropic2gemini.py | 92 +++++++++---- test_anthropic2gemini_tools_schema.py | 184 ++++++++++++++++++++++++++ 2 files changed, 248 insertions(+), 28 deletions(-) diff --git a/src/converter/anthropic2gemini.py b/src/converter/anthropic2gemini.py index 0db0fd17d..c6c450cf4 100644 --- a/src/converter/anthropic2gemini.py +++ b/src/converter/anthropic2gemini.py @@ -695,45 +695,81 @@ def reorganize_tool_messages(contents: List[Dict[str, Any]]) -> List[Dict[str, A """ 重新组织消息,满足 tool_use/tool_result 约束。 """ + # Pass 1: 收集所有 functionResponse,建立 id -> response part 映射 tool_results: Dict[str, Dict[str, Any]] = {} - for msg in contents: for part in msg.get("parts", []) or []: - if isinstance(part, dict) and "functionResponse" in part: - tool_id = (part.get("functionResponse") or {}).get("id") - if tool_id: - tool_results[str(tool_id)] = part + if not isinstance(part, dict) or "functionResponse" not in part: + continue + tool_id = (part.get("functionResponse") or {}).get("id") + if tool_id is not None: + tool_results[str(tool_id)] = part + + def _emit_call_group( + out: List[Dict[str, Any]], call_parts: List[Dict[str, Any]] + ) -> None: + """输出一组 functionCall + 对应的单个 user functionResponse 回合。""" + if not call_parts: + return - flattened: List[Dict[str, Any]] = [] - for msg in contents: - role = msg.get("role") - for part in msg.get("parts", []) or []: - flattened.append({"role": role, "parts": [part]}) + out.append({"role": "model", "parts": list(call_parts)}) + + response_parts: List[Dict[str, Any]] = [] + for call_part in call_parts: + fc = (call_part or {}).get("functionCall") or {} + tool_id = fc.get("id") + tool_name = fc.get("name") + + matched = None + if tool_id is not None: + matched = tool_results.get(str(tool_id)) + + if matched is not None: + response_parts.append(matched) + else: + response_parts.append( + { + "functionResponse": { + "id": tool_id, + "name": tool_name or "unknown_function", + "response": {"result": "no response"}, + } + } + ) + + out.append({"role": "user", "parts": response_parts}) + # Pass 2: 逐条消息重组,保持非 tool part 的既有行为(逐 part 输出) new_contents: List[Dict[str, Any]] = [] - i = 0 - while i < len(flattened): - msg = flattened[i] - part = msg["parts"][0] + for msg in contents: + role = msg.get("role") + pending_calls: List[Dict[str, Any]] = [] - if isinstance(part, dict) and "functionResponse" in part: - i += 1 - continue + for part in msg.get("parts", []) or []: + if not isinstance(part, dict): + # 与旧逻辑一致:非 dict part 当作普通文本 part 保留 + if pending_calls: + _emit_call_group(new_contents, pending_calls) + pending_calls = [] + new_contents.append({"role": role, "parts": [part]}) + continue - if isinstance(part, dict) and "functionCall" in part: - tool_id = (part.get("functionCall") or {}).get("id") - new_contents.append({"role": "model", "parts": [part]}) + if "functionResponse" in part: + # functionResponse 统一在 call group 后输出 + continue - if tool_id is not None and str(tool_id) in tool_results: - new_contents.append( - {"role": "user", "parts": [tool_results[str(tool_id)]]} - ) + if "functionCall" in part: + pending_calls.append(part) + continue - i += 1 - continue + # 普通 part(text/thinking/image/...)维持逐 part 输出 + if pending_calls: + _emit_call_group(new_contents, pending_calls) + pending_calls = [] + new_contents.append({"role": role, "parts": [part]}) - new_contents.append(msg) - i += 1 + if pending_calls: + _emit_call_group(new_contents, pending_calls) return new_contents diff --git a/test_anthropic2gemini_tools_schema.py b/test_anthropic2gemini_tools_schema.py index 37a9bd6dc..7e891cf21 100644 --- a/test_anthropic2gemini_tools_schema.py +++ b/test_anthropic2gemini_tools_schema.py @@ -5,6 +5,7 @@ anthropic_to_gemini_request, _can_use_text_only_fast_path, gemini_stream_to_anthropic_stream, + reorganize_tool_messages, ) @@ -176,6 +177,189 @@ def test_non_fast_path_still_preserves_tool_and_thinking_structure(): assert any("functionResponse" in part for part in all_parts) +def test_reorganize_tool_messages_groups_parallel_calls_into_single_turn_pair(): + contents = [ + { + "role": "model", + "parts": [ + {"functionCall": {"id": "call_1", "name": "a", "args": {}}}, + {"functionCall": {"id": "call_2", "name": "b", "args": {}}}, + {"functionCall": {"id": "call_3", "name": "c", "args": {}}}, + ], + }, + { + "role": "user", + "parts": [ + { + "functionResponse": { + "id": "call_1", + "name": "a", + "response": {"output": "r1"}, + } + }, + { + "functionResponse": { + "id": "call_2", + "name": "b", + "response": {"output": "r2"}, + } + }, + { + "functionResponse": { + "id": "call_3", + "name": "c", + "response": {"output": "r3"}, + } + }, + ], + }, + ] + + out = reorganize_tool_messages(contents) + + assert len(out) == 2 + assert out[0]["role"] == "model" + assert len(out[0]["parts"]) == 3 + assert all("functionCall" in p for p in out[0]["parts"]) + assert out[1]["role"] == "user" + assert len(out[1]["parts"]) == 3 + assert [p["functionResponse"]["id"] for p in out[1]["parts"]] == [ + "call_1", + "call_2", + "call_3", + ] + + +def test_reorganize_tool_messages_keeps_single_call_single_response_pairing(): + contents = [ + { + "role": "model", + "parts": [ + { + "functionCall": { + "id": "single_1", + "name": "fetch_weather", + "args": {"city": "shenzhen"}, + } + } + ], + }, + { + "role": "user", + "parts": [ + { + "functionResponse": { + "id": "single_1", + "name": "fetch_weather", + "response": {"output": "sunny"}, + } + } + ], + }, + ] + + out = reorganize_tool_messages(contents) + + assert len(out) == 2 + assert out[0] == { + "role": "model", + "parts": [ + { + "functionCall": { + "id": "single_1", + "name": "fetch_weather", + "args": {"city": "shenzhen"}, + } + } + ], + } + assert out[1] == { + "role": "user", + "parts": [ + { + "functionResponse": { + "id": "single_1", + "name": "fetch_weather", + "response": {"output": "sunny"}, + } + } + ], + } + + +def test_reorganize_tool_messages_synthesizes_missing_response_with_fallback(): + contents = [ + { + "role": "model", + "parts": [ + {"functionCall": {"id": "missing_1", "name": "tool_a", "args": {}}}, + {"functionCall": {"id": "missing_2", "name": "tool_b", "args": {}}}, + ], + }, + { + "role": "user", + "parts": [ + { + "functionResponse": { + "id": "missing_1", + "name": "tool_a", + "response": {"output": "ok"}, + } + } + ], + }, + ] + + out = reorganize_tool_messages(contents) + + assert len(out) == 2 + assert len(out[0]["parts"]) == 2 + assert len(out[1]["parts"]) == 2 + assert out[1]["parts"][0]["functionResponse"]["id"] == "missing_1" + assert out[1]["parts"][1] == { + "functionResponse": { + "id": "missing_2", + "name": "tool_b", + "response": {"result": "no response"}, + } + } + + +def test_reorganize_tool_messages_separates_text_turn_from_grouped_calls(): + contents = [ + { + "role": "model", + "parts": [ + {"text": "planning"}, + {"functionCall": {"id": "mix_1", "name": "lookup", "args": {"q": "x"}}}, + ], + }, + { + "role": "user", + "parts": [ + { + "functionResponse": { + "id": "mix_1", + "name": "lookup", + "response": {"output": "done"}, + } + } + ], + }, + ] + + out = reorganize_tool_messages(contents) + + assert len(out) == 3 + assert out[0] == {"role": "model", "parts": [{"text": "planning"}]} + assert out[1]["role"] == "model" + assert len(out[1]["parts"]) == 1 + assert "functionCall" in out[1]["parts"][0] + assert out[2]["role"] == "user" + assert len(out[2]["parts"]) == 1 + assert "functionResponse" in out[2]["parts"][0] + + def test_gemini_stream_interrupted_without_finish_reason_emits_error_event(): async def fake_stream(): yield b'data: {"response": {"candidates": [{"content": {"parts": [{"text": "partial answer"}]}}]}}\n\n' From 9d54bc95cba32aee21b6d841d08dffd04e4458dd Mon Sep 17 00:00:00 2001 From: CI User Date: Wed, 11 Mar 2026 21:18:53 +0800 Subject: [PATCH 44/47] fix(converter): enforce function call/response parity and add regression coverage Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode) Co-authored-by: Sisyphus --- src/converter/gemini_fix.py | 144 ++++++++++++++++++-- test_function_call_mismatch.py | 198 +++++++++++++++++++++++++++ test_gemini_fix.py | 241 +++++++++++++++++++++++++++++++++ 3 files changed, 570 insertions(+), 13 deletions(-) create mode 100644 test_function_call_mismatch.py diff --git a/src/converter/gemini_fix.py b/src/converter/gemini_fix.py index 4deecefb8..bfa63ed5c 100644 --- a/src/converter/gemini_fix.py +++ b/src/converter/gemini_fix.py @@ -4,8 +4,7 @@ ──────────────────────────────────────────────────────────────── """ -from math import e -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from log import log from src.utils import DEFAULT_SAFETY_SETTINGS @@ -187,6 +186,114 @@ def is_thinking_model(model_name: str) -> bool: return "think" in model_name or "pro" in model_name.lower() +def validate_function_call_pairs(contents: List[Any]) -> List[Any]: + """确保 functionCall turn 后紧跟数量匹配的 functionResponse turn。""" + validated_contents = list(contents) + index = 0 + + while index < len(validated_contents): + content = validated_contents[index] + if not isinstance(content, dict) or content.get("role") != "model": + index += 1 + continue + + parts = content.get("parts") or [] + if not isinstance(parts, list): + index += 1 + continue + + function_calls = [ + part + for part in parts + if isinstance(part, dict) and isinstance(part.get("functionCall"), dict) + ] + call_count = len(function_calls) + if call_count == 0: + index += 1 + continue + + def _synthesize_response(call_part: Dict[str, Any]) -> Dict[str, Any]: + call = call_part.get("functionCall", {}) + synthesized = { + "name": call.get("name") or "unknown_function", + "response": {"result": "no response"}, + } + if call.get("id"): + synthesized["id"] = call["id"] + return {"functionResponse": synthesized} + + next_index = index + 1 + next_turn = ( + validated_contents[next_index] + if next_index < len(validated_contents) + else None + ) + + if not isinstance(next_turn, dict) or next_turn.get("role") != "user": + synthesized_parts = [ + _synthesize_response(call_part) for call_part in function_calls + ] + validated_contents.insert( + next_index, + { + "role": "user", + "parts": synthesized_parts, + }, + ) + log.warning( + "[GEMINI_FIX] functionCall turn 后缺少 user/functionResponse," + f"已插入 user turn 并补齐 {call_count} 个 response" + ) + index += 2 + continue + + user_parts = next_turn.get("parts") or [] + if not isinstance(user_parts, list): + user_parts = [] + + function_responses = [ + part + for part in user_parts + if isinstance(part, dict) and isinstance(part.get("functionResponse"), dict) + ] + response_count = len(function_responses) + + if response_count == call_count: + index += 1 + continue + + fixed_responses = function_responses[:call_count] + if response_count < call_count: + missing_count = call_count - response_count + for call_part in function_calls[response_count:]: + fixed_responses.append(_synthesize_response(call_part)) + log.warning( + "[GEMINI_FIX] functionCall/functionResponse 数量不匹配," + f"call={call_count}, response={response_count},已补齐 {missing_count} 个 response" + ) + else: + removed_count = response_count - call_count + log.warning( + "[GEMINI_FIX] functionCall/functionResponse 数量不匹配," + f"call={call_count}, response={response_count},已移除 {removed_count} 个多余 response" + ) + + non_function_response_parts = [ + part + for part in user_parts + if not ( + isinstance(part, dict) + and isinstance(part.get("functionResponse"), dict) + ) + ] + updated_user_turn = next_turn.copy() + updated_user_turn["parts"] = non_function_response_parts + fixed_responses + validated_contents[next_index] = updated_user_turn + index += 1 + + return validated_contents + + async def normalize_gemini_request( request: Dict[str, Any], mode: str = "geminicli" ) -> Dict[str, Any]: @@ -213,10 +320,16 @@ async def normalize_gemini_request( generation_config = ( result.get("generationConfig") or {} ).copy() # 创建副本避免修改原对象 - tools = result.get("tools") - system_instruction = result.get("systemInstruction") or result.get( - "system_instructions" - ) + system_instruction = result.get("systemInstruction") + for alias_key in ("system_instruction", "system_instructions"): + alias_value = result.pop(alias_key, None) + if (not system_instruction) and alias_value: + system_instruction = alias_value + + if system_instruction: + result["systemInstruction"] = system_instruction + else: + result.pop("systemInstruction", None) # 记录原始请求 log.debug( @@ -352,7 +465,7 @@ async def normalize_gemini_request( if has_tool_calls: # MCP 场景:检测到工具调用,移除 thinkingConfig log.warning( - f"[ANTIGRAVITY] 检测到工具调用(MCP场景),移除 thinkingConfig 避免失效" + "[ANTIGRAVITY] 检测到工具调用(MCP场景),移除 thinkingConfig 避免失效" ) generation_config.pop("thinkingConfig", None) else: @@ -383,7 +496,7 @@ async def normalize_gemini_request( ): content["parts"] = [thinking_part] + parts log.debug( - f"[ANTIGRAVITY] 已在最后一个 assistant 消息开头插入思考块(含跳过验证签名)" + "[ANTIGRAVITY] 已在最后一个 assistant 消息开头插入思考块(含跳过验证签名)" ) break @@ -478,11 +591,15 @@ async def normalize_gemini_request( # 检查 part 是否有有效的非空值 # 过滤掉空字典或所有值都为空的 part - has_valid_value = any( - value not in (None, "", {}, []) - for key, value in part.items() - if key != "thought" # thought 字段可以为空 - ) + # functionCall/functionResponse 豁免空值过滤 + if "functionCall" in part or "functionResponse" in part: + has_valid_value = True + else: + has_valid_value = any( + value not in (None, "", {}, []) + for key, value in part.items() + if key != "thought" # thought 字段可以为空 + ) if has_valid_value: # 修复 text 字段:确保是字符串而不是列表 @@ -521,6 +638,7 @@ async def normalize_gemini_request( cleaned_contents.append(content) result["contents"] = cleaned_contents + result["contents"] = validate_function_call_pairs(result["contents"]) if generation_config: result["generationConfig"] = generation_config diff --git a/test_function_call_mismatch.py b/test_function_call_mismatch.py new file mode 100644 index 000000000..e0bb27674 --- /dev/null +++ b/test_function_call_mismatch.py @@ -0,0 +1,198 @@ +import asyncio + +from src.converter.anthropic2gemini import anthropic_to_gemini_request +from src.converter.gemini_fix import normalize_gemini_request + + +def _convert_then_normalize(payload: dict) -> dict: + gemini_request = asyncio.run(anthropic_to_gemini_request(payload)) + return asyncio.run(normalize_gemini_request(gemini_request, mode="geminicli")) + + +def _assert_model_user_function_parity(contents: list[dict]) -> None: + """Every model turn with functionCall must be followed by user with same number of functionResponse.""" + assert isinstance(contents, list) + assert contents, "contents must not be empty" + + for i, turn in enumerate(contents): + if not isinstance(turn, dict) or turn.get("role") != "model": + continue + + parts = turn.get("parts") or [] + call_parts = [ + p + for p in parts + if isinstance(p, dict) and isinstance(p.get("functionCall"), dict) + ] + if not call_parts: + continue + + assert i + 1 < len(contents), ( + f"model tool-call turn at index {i} has no next turn" + ) + next_turn = contents[i + 1] + assert next_turn.get("role") == "user", ( + f"model tool-call turn at index {i} must be followed by user turn" + ) + + response_parts = [ + p + for p in (next_turn.get("parts") or []) + if isinstance(p, dict) and isinstance(p.get("functionResponse"), dict) + ] + assert len(response_parts) == len(call_parts), ( + f"call/response count mismatch at index {i}: " + f"{len(call_parts)} calls vs {len(response_parts)} responses" + ) + + for call_part, response_part in zip(call_parts, response_parts): + call = call_part["functionCall"] + response = response_part["functionResponse"] + + assert response.get("id") == call.get("id") + assert response.get("name") == call.get("name") + assert "response" in response + + +def test_e2e_parallel_tool_use_tool_result_keeps_parity(): + payload = { + "model": "gemini-2.5-flash", + "max_tokens": 128, + "messages": [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_parallel_1", + "name": "get_weather", + "input": {"city": "shenzhen"}, + }, + { + "type": "tool_use", + "id": "toolu_parallel_2", + "name": "get_news", + "input": {"topic": "ai"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_parallel_1", + "name": "get_weather", + "content": [{"type": "text", "text": "sunny"}], + }, + { + "type": "tool_result", + "tool_use_id": "toolu_parallel_2", + "name": "get_news", + "content": [{"type": "text", "text": "headline"}], + }, + ], + }, + ], + } + + normalized = _convert_then_normalize(payload) + contents = normalized["contents"] + + _assert_model_user_function_parity(contents) + assert len(contents) == 2 + + +def test_e2e_missing_response_synthesizes_no_response_and_restores_parity(): + payload = { + "model": "gemini-2.5-flash", + "messages": [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_missing_1", + "name": "lookup_a", + "input": {}, + }, + { + "type": "tool_use", + "id": "toolu_missing_2", + "name": "lookup_b", + "input": {}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_missing_1", + "name": "lookup_a", + "content": [{"type": "text", "text": "ok"}], + } + ], + }, + ], + } + + normalized = _convert_then_normalize(payload) + contents = normalized["contents"] + + _assert_model_user_function_parity(contents) + + response_parts = [ + p + for p in contents[1]["parts"] + if isinstance(p, dict) and "functionResponse" in p + ] + assert len(response_parts) == 2 + assert response_parts[1]["functionResponse"] == { + "id": "toolu_missing_2", + "name": "lookup_b", + "response": {"result": "no response"}, + } + + +def test_e2e_empty_response_object_is_preserved_and_parity_holds(): + payload = { + "model": "gemini-2.5-flash", + "messages": [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_empty_1", + "name": "store_data", + "input": {}, + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_empty_1", + "name": "store_data", + "content": [{"type": "text", "text": ""}], + } + ], + }, + ], + } + + gemini_request = asyncio.run(anthropic_to_gemini_request(payload)) + # Simulate empty-response edge case entering normalize stage. + gemini_request["contents"][1]["parts"][0]["functionResponse"]["response"] = {} + + normalized = asyncio.run(normalize_gemini_request(gemini_request, mode="geminicli")) + contents = normalized["contents"] + + _assert_model_user_function_parity(contents) + + response = contents[1]["parts"][0]["functionResponse"]["response"] + assert response == {} diff --git a/test_gemini_fix.py b/test_gemini_fix.py index 9f294620e..16cd8d3a9 100644 --- a/test_gemini_fix.py +++ b/test_gemini_fix.py @@ -1,4 +1,5 @@ import asyncio +from unittest.mock import patch from src.converter.gemini_fix import normalize_gemini_request @@ -54,3 +55,243 @@ def test_normalize_gemini_request_preserves_existing_signature_formats(): assert part["thoughtSignature"] == "sig_from_client" assert "thought_signature" not in part + + +def test_normalize_gemini_request_unifies_system_instruction_aliases(): + request = { + "model": "gemini-2.5-flash", + "contents": [{"role": "user", "parts": [{"text": "hello"}]}], + "system_instruction": {"parts": [{"text": "base"}]}, + "system_instructions": {"parts": [{"text": "fallback"}]}, + } + + normalized = asyncio.run(normalize_gemini_request(request, mode="geminicli")) + + assert "system_instruction" not in normalized + assert "system_instructions" not in normalized + assert normalized["systemInstruction"]["parts"][0]["text"] == "base" + + +def test_normalize_gemini_request_exemption_for_empty_function_parts(): + request = { + "model": "gemini-2.5-flash", + "contents": [ + { + "role": "model", + "parts": [ + {"functionResponse": {"name": "foo", "response": {}}}, + {"functionCall": {"name": "bar", "args": {}}}, + {"text": ""}, # Should be filtered + {"text": "valid"}, + ], + } + ], + } + + normalized = asyncio.run(normalize_gemini_request(request, mode="geminicli")) + parts = normalized["contents"][0]["parts"] + + # Only 3 parts should remain: functionResponse, functionCall, and valid text + assert len(parts) == 3 + + assert "functionResponse" in parts[0] + assert parts[0]["functionResponse"]["response"] == {} + + assert "functionCall" in parts[1] + assert parts[1]["functionCall"]["args"] == {} + # Check if thoughtSignature was added (from existing logic) + assert parts[1]["thoughtSignature"] == "skip_thought_signature_validator" + + assert "text" in parts[2] + assert parts[2]["text"] == "valid" + + +def test_validate_function_call_pairs_repairs_missing_response(): + request = { + "model": "gemini-2.5-flash", + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"id": "call-1", "name": "tool_a", "args": {}}}, + {"functionCall": {"id": "call-2", "name": "tool_b", "args": {}}}, + {"functionCall": {"id": "call-3", "name": "tool_c", "args": {}}}, + ], + }, + { + "role": "user", + "parts": [ + { + "functionResponse": { + "id": "call-1", + "name": "tool_a", + "response": {"ok": True}, + } + }, + { + "functionResponse": { + "id": "call-2", + "name": "tool_b", + "response": {"ok": True}, + } + }, + ], + }, + ], + } + + with patch("src.converter.gemini_fix.log.warning") as warning_mock: + normalized = asyncio.run(normalize_gemini_request(request, mode="geminicli")) + + response_parts = [ + part + for part in normalized["contents"][1]["parts"] + if "functionResponse" in part + ] + assert len(response_parts) == 3 + assert response_parts[2]["functionResponse"]["id"] == "call-3" + assert response_parts[2]["functionResponse"]["name"] == "tool_c" + assert response_parts[2]["functionResponse"]["response"] == { + "result": "no response" + } + + warning_texts = [str(call.args[0]) for call in warning_mock.call_args_list] + assert any("已补齐 1 个 response" in text for text in warning_texts) + + +def test_validate_function_call_pairs_removes_extra_response(): + request = { + "model": "gemini-2.5-flash", + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"id": "call-1", "name": "tool_a", "args": {}}}, + {"functionCall": {"id": "call-2", "name": "tool_b", "args": {}}}, + ], + }, + { + "role": "user", + "parts": [ + { + "functionResponse": { + "id": "call-1", + "name": "tool_a", + "response": {"ok": True}, + } + }, + { + "functionResponse": { + "id": "call-2", + "name": "tool_b", + "response": {"ok": True}, + } + }, + { + "functionResponse": { + "id": "call-3", + "name": "tool_c", + "response": {"ok": True}, + } + }, + ], + }, + ], + } + + with patch("src.converter.gemini_fix.log.warning") as warning_mock: + normalized = asyncio.run(normalize_gemini_request(request, mode="geminicli")) + + response_parts = [ + part + for part in normalized["contents"][1]["parts"] + if "functionResponse" in part + ] + assert len(response_parts) == 2 + assert [part["functionResponse"]["id"] for part in response_parts] == [ + "call-1", + "call-2", + ] + + warning_texts = [str(call.args[0]) for call in warning_mock.call_args_list] + assert any("已移除 1 个多余 response" in text for text in warning_texts) + + +def test_validate_function_call_pairs_unchanged_when_counts_match(): + request = { + "model": "gemini-2.5-flash", + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"id": "call-1", "name": "tool_a", "args": {}}}, + {"functionCall": {"id": "call-2", "name": "tool_b", "args": {}}}, + ], + }, + { + "role": "user", + "parts": [ + { + "functionResponse": { + "id": "call-1", + "name": "tool_a", + "response": {"ok": True}, + } + }, + { + "functionResponse": { + "id": "call-2", + "name": "tool_b", + "response": {"ok": True}, + } + }, + ], + }, + ], + } + + with patch("src.converter.gemini_fix.log.warning") as warning_mock: + normalized = asyncio.run(normalize_gemini_request(request, mode="geminicli")) + + response_parts = [ + part + for part in normalized["contents"][1]["parts"] + if "functionResponse" in part + ] + assert len(response_parts) == 2 + + warning_texts = [str(call.args[0]) for call in warning_mock.call_args_list] + assert not any("数量不匹配" in text for text in warning_texts) + assert not any("已插入 user turn" in text for text in warning_texts) + + +def test_validate_function_call_pairs_inserts_user_turn_when_missing(): + request = { + "model": "gemini-2.5-flash", + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"id": "call-1", "name": "tool_a", "args": {}}} + ], + }, + {"role": "model", "parts": [{"text": "next model turn"}]}, + ], + } + + with patch("src.converter.gemini_fix.log.warning") as warning_mock: + normalized = asyncio.run(normalize_gemini_request(request, mode="geminicli")) + + assert normalized["contents"][1]["role"] == "user" + inserted_parts = normalized["contents"][1]["parts"] + assert len(inserted_parts) == 1 + assert inserted_parts[0]["functionResponse"]["id"] == "call-1" + assert inserted_parts[0]["functionResponse"]["name"] == "tool_a" + assert inserted_parts[0]["functionResponse"]["response"] == { + "result": "no response" + } + + warning_texts = [str(call.args[0]) for call in warning_mock.call_args_list] + assert any( + "已插入 user turn 并补齐 1 个 response" in text for text in warning_texts + ) From 750391b02da993f3236b453b15be5b0c9958af87 Mon Sep 17 00:00:00 2001 From: CI User Date: Wed, 11 Mar 2026 21:35:03 +0800 Subject: [PATCH 45/47] fix(converter): normalize system instruction field names Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode) Co-authored-by: Sisyphus --- src/converter/anti_truncation.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/converter/anti_truncation.py b/src/converter/anti_truncation.py index b8543023a..6f49df4a2 100644 --- a/src/converter/anti_truncation.py +++ b/src/converter/anti_truncation.py @@ -135,8 +135,18 @@ def apply_anti_truncation(payload: Dict[str, Any]) -> Dict[str, Any]: modified_payload = apply_regex_replacements_to_payload(payload) request_data = modified_payload.get("request", {}) + # 统一 systemInstruction 字段命名,避免 oneof 字段重复写入 + system_instruction = request_data.get("systemInstruction") + for alias_key in ("system_instruction", "system_instructions"): + alias_value = request_data.pop(alias_key, None) + if (not system_instruction) and alias_value: + system_instruction = alias_value + # 获取或创建systemInstruction - system_instruction = request_data.get("systemInstruction", {}) + if system_instruction: + request_data["systemInstruction"] = system_instruction + else: + system_instruction = {} if not system_instruction: system_instruction = {"parts": []} elif "parts" not in system_instruction: @@ -711,4 +721,4 @@ def is_anti_truncation_enabled(request_data: Dict[str, Any]) -> bool: Returns: 是否启用反截断 """ - return request_data.get("enable_anti_truncation", False) \ No newline at end of file + return request_data.get("enable_anti_truncation", False) From e15a63fdfe538902a30272107f906605e167787b Mon Sep 17 00:00:00 2001 From: CI User Date: Wed, 11 Mar 2026 21:35:11 +0800 Subject: [PATCH 46/47] fix(gemini): implement daily quota detection and UTC midnight backoff Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode) Co-authored-by: Sisyphus --- src/api/geminicli.py | 64 +++++++++++++++++++++++++----- src/api/utils.py | 94 +++++++++++++++++++++++++++++++++++++++----- 2 files changed, 138 insertions(+), 20 deletions(-) diff --git a/src/api/geminicli.py b/src/api/geminicli.py index 8c32b1130..b59f4b0af 100644 --- a/src/api/geminicli.py +++ b/src/api/geminicli.py @@ -17,6 +17,7 @@ import json import random import time +from datetime import datetime, timezone, timedelta from typing import Any, Dict, Optional import httpx @@ -45,14 +46,14 @@ def _extract_gemini_error_info( error_text: Optional[str], -) -> tuple[Optional[str], Optional[str]]: +) -> tuple[Any, Optional[str], bool]: """提取 Gemini 错误状态与 ErrorInfo.reason,兼容非标准错误结构。""" if not error_text: - return None, None + return None, None, False try: error_data = json.loads(error_text) except Exception: - return None, None + return None, None, False error_obj = error_data.get("error", {}) if isinstance(error_data, dict) else {} error_status = error_obj.get("status") if isinstance(error_obj, dict) else None @@ -68,13 +69,25 @@ def _extract_gemini_error_info( merged_message = " ".join([m for m in message_candidates if m]).lower() reason = None + is_daily_quota = False if isinstance(details, list): for detail in details: if not isinstance(detail, dict): continue if detail.get("@type") == "type.googleapis.com/google.rpc.ErrorInfo": - reason = detail.get("reason") - if reason: + if not reason: + reason = detail.get("reason") + + metadata = detail.get("metadata") + if isinstance(metadata, dict): + quota_unit = metadata.get("quota_unit") + quota_limit = metadata.get("quota_limit") + if isinstance(quota_unit, str) and "/d/" in quota_unit: + is_daily_quota = True + if isinstance(quota_limit, str) and "PerDay" in quota_limit: + is_daily_quota = True + + if reason and is_daily_quota: break # 文本兜底:识别容量耗尽的非标准消息 @@ -85,7 +98,7 @@ def _extract_gemini_error_info( ): reason = "MODEL_CAPACITY_EXHAUSTED" - return error_status, reason + return error_status, reason, is_daily_quota def _compute_capacity_retry_delay(base_interval: float, attempt: int) -> float: @@ -568,6 +581,7 @@ async def refresh_credential_fast() -> bool: cooldown_until = None error_status = None error_reason = None + is_daily_quota = False if (status_code == 429 or status_code == 503) and error_body: try: cooldown_until = await parse_and_log_cooldown( @@ -575,9 +589,11 @@ async def refresh_credential_fast() -> bool: ) except Exception: pass - error_status, error_reason = _extract_gemini_error_info( - error_body - ) + ( + error_status, + error_reason, + is_daily_quota, + ) = _extract_gemini_error_info(error_body) # 429 MODEL_CAPACITY_EXHAUSTED 视为容量拥塞,不保留同一凭证立即重试 is_model_capacity_exhausted = ( @@ -591,6 +607,7 @@ async def refresh_credential_fast() -> bool: and (status_code == 429 or status_code == 503) and cooldown_until is None and not is_model_capacity_exhausted + and not is_daily_quota ): keep_current_credential = True elif next_cred_task is None and attempt < max_retries: @@ -607,6 +624,17 @@ async def refresh_credential_fast() -> bool: + _compute_capacity_retry_delay(retry_interval, attempt) ) + # 日配额错误缺少冷却时间时,回退到下一个 UTC 零点 + if is_daily_quota and cooldown_until is None: + next_midnight_utc = datetime.now(timezone.utc).replace( + hour=0, minute=0, second=0, microsecond=0 + ) + timedelta(days=1) + cooldown_until = next_midnight_utc.timestamp() + log.warning( + "[GEMINICLI STREAM] 日配额429缺少冷却时间,回退到下一个UTC午夜: " + f"{next_midnight_utc.isoformat()}" + ) + # 无cd的429/503保留当前凭证重试,无需记录错误 if not keep_current_credential: await record_api_call_error( @@ -1091,6 +1119,7 @@ async def refresh_credential_fast() -> bool: cooldown_until = None error_status = None error_reason = None + is_daily_quota = False if (status_code == 429 or status_code == 503) and error_text: try: cooldown_until = await parse_and_log_cooldown( @@ -1098,7 +1127,11 @@ async def refresh_credential_fast() -> bool: ) except Exception: pass - error_status, error_reason = _extract_gemini_error_info(error_text) + ( + error_status, + error_reason, + is_daily_quota, + ) = _extract_gemini_error_info(error_text) is_model_capacity_exhausted = ( status_code == 429 and error_reason == "MODEL_CAPACITY_EXHAUSTED" @@ -1109,6 +1142,7 @@ async def refresh_credential_fast() -> bool: and (status_code == 429 or status_code == 503) and cooldown_until is None and not is_model_capacity_exhausted + and not is_daily_quota ) # 对于没有触发cd且非容量耗尽的429错误,不预热新凭证 @@ -1127,6 +1161,16 @@ async def refresh_credential_fast() -> bool: retry_interval, attempt ) + # 日配额错误缺少冷却时间时,回退到下一个 UTC 零点 + if is_daily_quota and cooldown_until is None: + next_midnight_utc = datetime.now(timezone.utc).replace( + hour=0, minute=0, second=0, microsecond=0 + ) + timedelta(days=1) + cooldown_until = next_midnight_utc.timestamp() + log.warning( + f"[NON-STREAM] 日配额429缺少冷却时间,回退到下一个UTC午夜: {next_midnight_utc.isoformat()}" + ) + # 无cd的429/503保留当前凭证重试,无需记录错误 if not keep_current_credential: await record_api_call_error( diff --git a/src/api/utils.py b/src/api/utils.py index 7d683a0b4..07b5b1bc1 100644 --- a/src/api/utils.py +++ b/src/api/utils.py @@ -5,7 +5,9 @@ import asyncio import json -from datetime import datetime, timezone +import re +import time +from datetime import datetime, timedelta, timezone from typing import Any, Dict, Optional from fastapi import Response @@ -503,6 +505,33 @@ async def collect_streaming_response(stream_generator) -> Response: ) +def _parse_duration_string(duration_str: str) -> Optional[float]: + """Parse Go-style duration strings (e.g. 13h19m1.209s) to seconds.""" + if not isinstance(duration_str, str): + return None + + value = duration_str.strip() + if not value: + return None + + match = re.fullmatch(r"(?:(\d+)h)?(?:(\d+)m)?(?:([\d.]+)s)?", value) + if not match: + return None + + hours_str, minutes_str, seconds_str = match.groups() + if hours_str is None and minutes_str is None and seconds_str is None: + return None + + try: + hours = int(hours_str) if hours_str is not None else 0 + minutes = int(minutes_str) if minutes_str is not None else 0 + seconds = float(seconds_str) if seconds_str is not None else 0.0 + except (TypeError, ValueError): + return None + + return float(hours * 3600 + minutes * 60) + seconds + + def parse_quota_reset_timestamp(error_response: dict) -> Optional[float]: """ 从Google API错误响应中提取quota重置时间戳 @@ -534,22 +563,67 @@ def parse_quota_reset_timestamp(error_response: dict) -> Optional[float]: """ try: details = error_response.get("error", {}).get("details", []) + if not isinstance(details, list): + return None + + parsed_delay_seconds: Optional[float] = None + has_daily_quota = False for detail in details: - if detail.get("@type") == "type.googleapis.com/google.rpc.ErrorInfo": - reset_timestamp_str = detail.get("metadata", {}).get( - "quotaResetTimeStamp" - ) + if not isinstance(detail, dict): + continue + if detail.get("@type") != "type.googleapis.com/google.rpc.ErrorInfo": + continue + + metadata = detail.get("metadata", {}) + if not isinstance(metadata, dict): + continue - if reset_timestamp_str: - if reset_timestamp_str.endswith("Z"): - reset_timestamp_str = reset_timestamp_str.replace("Z", "+00:00") + # Priority 1: explicit quota reset timestamp + reset_timestamp_str = metadata.get("quotaResetTimeStamp") + if isinstance(reset_timestamp_str, str) and reset_timestamp_str.strip(): + ts_value = reset_timestamp_str.strip() + if ts_value.endswith("Z"): + ts_value = ts_value.replace("Z", "+00:00") - reset_dt = datetime.fromisoformat(reset_timestamp_str) + try: + reset_dt = datetime.fromisoformat(ts_value) if reset_dt.tzinfo is None: reset_dt = reset_dt.replace(tzinfo=timezone.utc) - return reset_dt.astimezone(timezone.utc).timestamp() + except Exception: + # Ignore malformed timestamp and keep trying fallbacks + pass + + # Priority 2: relative reset delay + if parsed_delay_seconds is None: + reset_delay = metadata.get("quotaResetDelay") + if isinstance(reset_delay, str): + parsed_delay_seconds = _parse_duration_string(reset_delay) + + # Priority 3: daily quota detection + quota_unit = metadata.get("quota_unit") + quota_limit = metadata.get("quota_limit") + if ( + isinstance(quota_unit, str) + and "/d/" in quota_unit + or isinstance(quota_limit, str) + and "PerDay" in quota_limit + ): + has_daily_quota = True + + if parsed_delay_seconds is not None: + return time.time() + parsed_delay_seconds + + if has_daily_quota: + now_utc = datetime.now(timezone.utc) + next_midnight_utc = (now_utc + timedelta(days=1)).replace( + hour=0, + minute=0, + second=0, + microsecond=0, + ) + return next_midnight_utc.timestamp() return None From ddecdb3231278929f0a1c7e87c810befb62440e3 Mon Sep 17 00:00:00 2001 From: CI User Date: Wed, 11 Mar 2026 21:35:20 +0800 Subject: [PATCH 47/47] test: add comprehensive streaming and retry regression tests Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode) Co-authored-by: Sisyphus --- pyproject.toml | 2 + test_geminicli_stream_request.py | 350 +++++++++++++++++++++++++++++++ 2 files changed, 352 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 377461494..3be119bf2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,8 @@ asyncio_mode = "auto" addopts = [ "-v", "--strict-markers", + "--ignore=gpt4free_source", + "--ignore=temp", ] [tool.black] diff --git a/test_geminicli_stream_request.py b/test_geminicli_stream_request.py index 60097f7f6..15a0fd46c 100644 --- a/test_geminicli_stream_request.py +++ b/test_geminicli_stream_request.py @@ -1,5 +1,6 @@ import asyncio import json +from datetime import datetime, timezone from typing import Any import httpx @@ -483,6 +484,189 @@ async def fake_stream_post_async(*args, **kwargs): assert captured["auth"] == ["Bearer token-a", "Bearer token-a"] +def test_stream_request_daily_quota_does_not_keep_current_credential(monkeypatch): + captured = {"auth": [], "flags": []} + + async def fake_get_code_assist_endpoint(): + return "https://example.com" + + async def fake_get_retry_config(): + return {"retry_enabled": True, "max_retries": 1, "retry_interval": 0.0} + + async def fake_get_auto_ban_error_codes(): + return [] + + async def fake_get_proxy_config(): + return None + + async def fake_parse_and_log_cooldown(*args, **kwargs): + return None + + async def fake_record_api_call_success(*args, **kwargs): + return None + + async def fake_record_api_call_error(*args, **kwargs): + return None + + async def fake_handle_error_with_retry(*args, **kwargs): + captured["flags"].append(kwargs.get("retry_policy_v2_enabled")) + return True + + async def fake_stream_post_async(*args, **kwargs): + captured["auth"].append(kwargs.get("headers", {}).get("Authorization")) + if len(captured["auth"]) == 1: + yield Response( + content=json.dumps( + { + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "RESOURCE_EXHAUSTED", + "metadata": { + "quota_unit": "1/d/{project}", + }, + } + ], + } + }, + ensure_ascii=False, + ).encode("utf-8"), + status_code=429, + media_type="application/json", + ) + else: + yield "data: ok" + + monkeypatch.setattr(geminicli, "credential_manager", _SequentialCredentialManager()) + monkeypatch.setattr(geminicli, "get_ff_retry_policy_v2", _retry_policy_v2_enabled) + monkeypatch.setattr( + geminicli, "get_code_assist_endpoint", fake_get_code_assist_endpoint + ) + monkeypatch.setattr(geminicli, "get_retry_config", fake_get_retry_config) + monkeypatch.setattr( + geminicli, "get_auto_ban_error_codes", fake_get_auto_ban_error_codes + ) + monkeypatch.setattr(geminicli, "get_proxy_config", fake_get_proxy_config) + monkeypatch.setattr( + geminicli, "parse_and_log_cooldown", fake_parse_and_log_cooldown + ) + monkeypatch.setattr( + geminicli, "record_api_call_success", fake_record_api_call_success + ) + monkeypatch.setattr(geminicli, "record_api_call_error", fake_record_api_call_error) + monkeypatch.setattr( + geminicli, "handle_error_with_retry", fake_handle_error_with_retry + ) + monkeypatch.setattr(geminicli, "stream_post_async", fake_stream_post_async) + + body = { + "model": "gemini-2.5-pro", + "request": {"contents": [{"role": "user", "parts": [{"text": "hi"}]}]}, + } + chunks = asyncio.run(_collect(geminicli.stream_request(body=body, native=False))) + + assert chunks == ["data: ok"] + assert captured["flags"] == [True] + assert captured["auth"] == ["Bearer token-a", "Bearer token-b"] + + +def test_stream_request_daily_quota_missing_cooldown_falls_back_to_utc_midnight( + monkeypatch, +): + captured = {"cooldown_until": None, "warnings": []} + fixed_now = datetime(2026, 3, 11, 15, 30, 0, tzinfo=timezone.utc) + + class _FixedDateTime(datetime): + @classmethod + def now(cls, tz=None): + if tz is None: + return fixed_now.replace(tzinfo=None) + return fixed_now.astimezone(tz) + + async def fake_get_code_assist_endpoint(): + return "https://example.com" + + async def fake_get_retry_config(): + return {"retry_enabled": True, "max_retries": 0, "retry_interval": 0.0} + + async def fake_get_auto_ban_error_codes(): + return [] + + async def fake_get_proxy_config(): + return None + + async def fake_parse_and_log_cooldown(*args, **kwargs): + return None + + async def fake_record_api_call_error(*args, **kwargs): + captured["cooldown_until"] = args[3] + + async def fake_handle_error_with_retry(*args, **kwargs): + return False + + async def fake_stream_post_async(*args, **kwargs): + yield Response( + content=json.dumps( + { + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "RESOURCE_EXHAUSTED", + "metadata": { + "quota_limit": "GenerateContentRequestsPerDayPerProjectPerModel-FreeTier", + }, + } + ], + } + }, + ensure_ascii=False, + ).encode("utf-8"), + status_code=429, + media_type="application/json", + ) + + def fake_warning(message): + captured["warnings"].append(message) + + monkeypatch.setattr(geminicli, "datetime", _FixedDateTime) + monkeypatch.setattr(geminicli, "credential_manager", _DummyCredentialManager()) + monkeypatch.setattr(geminicli, "get_ff_retry_policy_v2", _retry_policy_v2_enabled) + monkeypatch.setattr( + geminicli, "get_code_assist_endpoint", fake_get_code_assist_endpoint + ) + monkeypatch.setattr(geminicli, "get_retry_config", fake_get_retry_config) + monkeypatch.setattr( + geminicli, "get_auto_ban_error_codes", fake_get_auto_ban_error_codes + ) + monkeypatch.setattr(geminicli, "get_proxy_config", fake_get_proxy_config) + monkeypatch.setattr( + geminicli, "parse_and_log_cooldown", fake_parse_and_log_cooldown + ) + monkeypatch.setattr(geminicli, "record_api_call_error", fake_record_api_call_error) + monkeypatch.setattr( + geminicli, "handle_error_with_retry", fake_handle_error_with_retry + ) + monkeypatch.setattr(geminicli, "stream_post_async", fake_stream_post_async) + monkeypatch.setattr(geminicli.log, "warning", fake_warning) + + body = { + "model": "gemini-2.5-pro", + "request": {"contents": [{"role": "user", "parts": [{"text": "hi"}]}]}, + } + chunks = asyncio.run(_collect(geminicli.stream_request(body=body, native=False))) + + expected_cooldown = datetime(2026, 3, 12, 0, 0, 0, tzinfo=timezone.utc).timestamp() + assert len(chunks) == 1 + assert isinstance(chunks[0], Response) + assert chunks[0].status_code == 429 + assert captured["cooldown_until"] == expected_cooldown + assert any("回退到下一个UTC午夜" in message for message in captured["warnings"]) + + def test_stream_request_retry_policy_v2_off_rotates_credential(monkeypatch): captured = {"auth": [], "flags": []} @@ -619,6 +803,172 @@ async def fake_post_async(*args, **kwargs): assert captured["flags"] == [False] +def test_non_stream_request_daily_quota_does_not_keep_current_credential(monkeypatch): + captured = {"auth": [], "flags": []} + responses = [ + httpx.Response( + status_code=429, + json={ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "RESOURCE_EXHAUSTED", + "metadata": { + "quota_unit": "1/d/{project}", + }, + } + ], + } + }, + ), + httpx.Response(status_code=200, json={"ok": True}), + ] + + async def fake_get_code_assist_endpoint(): + return "https://example.com" + + async def fake_get_retry_config(): + return {"retry_enabled": True, "max_retries": 1, "retry_interval": 0.0} + + async def fake_get_auto_ban_error_codes(): + return [] + + async def fake_parse_and_log_cooldown(*args, **kwargs): + return None + + async def fake_record_api_call_success(*args, **kwargs): + return None + + async def fake_record_api_call_error(*args, **kwargs): + return None + + async def fake_handle_error_with_retry(*args, **kwargs): + captured["flags"].append(kwargs.get("retry_policy_v2_enabled")) + return True + + async def fake_post_async(*args, **kwargs): + captured["auth"].append(kwargs.get("headers", {}).get("Authorization")) + return responses.pop(0) + + monkeypatch.setattr(geminicli, "credential_manager", _SequentialCredentialManager()) + monkeypatch.setattr(geminicli, "get_ff_retry_policy_v2", _retry_policy_v2_enabled) + monkeypatch.setattr( + geminicli, "get_code_assist_endpoint", fake_get_code_assist_endpoint + ) + monkeypatch.setattr(geminicli, "get_retry_config", fake_get_retry_config) + monkeypatch.setattr( + geminicli, "get_auto_ban_error_codes", fake_get_auto_ban_error_codes + ) + monkeypatch.setattr( + geminicli, "parse_and_log_cooldown", fake_parse_and_log_cooldown + ) + monkeypatch.setattr( + geminicli, "record_api_call_success", fake_record_api_call_success + ) + monkeypatch.setattr(geminicli, "record_api_call_error", fake_record_api_call_error) + monkeypatch.setattr( + geminicli, "handle_error_with_retry", fake_handle_error_with_retry + ) + monkeypatch.setattr(geminicli, "post_async", fake_post_async) + + body = { + "model": "gemini-2.5-pro", + "request": {"contents": [{"role": "user", "parts": [{"text": "hi"}]}]}, + } + response = asyncio.run(geminicli.non_stream_request(body=body)) + + assert response.status_code == 200 + assert captured["flags"] == [True] + assert captured["auth"] == ["Bearer token-a", "Bearer token-b"] + + +def test_non_stream_request_daily_quota_missing_cooldown_falls_back_to_utc_midnight( + monkeypatch, +): + captured = {"cooldown_until": None, "warnings": []} + fixed_now = datetime(2026, 3, 11, 15, 30, 0, tzinfo=timezone.utc) + + class _FixedDateTime(datetime): + @classmethod + def now(cls, tz=None): + if tz is None: + return fixed_now.replace(tzinfo=None) + return fixed_now.astimezone(tz) + + async def fake_get_code_assist_endpoint(): + return "https://example.com" + + async def fake_get_retry_config(): + return {"retry_enabled": True, "max_retries": 0, "retry_interval": 0.0} + + async def fake_get_auto_ban_error_codes(): + return [] + + async def fake_parse_and_log_cooldown(*args, **kwargs): + return None + + async def fake_record_api_call_error(*args, **kwargs): + captured["cooldown_until"] = args[3] + + async def fake_handle_error_with_retry(*args, **kwargs): + return False + + async def fake_post_async(*args, **kwargs): + return httpx.Response( + status_code=429, + json={ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "RESOURCE_EXHAUSTED", + "metadata": { + "quota_limit": "GenerateContentRequestsPerDayPerProjectPerModel-FreeTier", + }, + } + ], + } + }, + ) + + def fake_warning(message): + captured["warnings"].append(message) + + monkeypatch.setattr(geminicli, "datetime", _FixedDateTime) + monkeypatch.setattr(geminicli, "credential_manager", _DummyCredentialManager()) + monkeypatch.setattr(geminicli, "get_ff_retry_policy_v2", _retry_policy_v2_enabled) + monkeypatch.setattr( + geminicli, "get_code_assist_endpoint", fake_get_code_assist_endpoint + ) + monkeypatch.setattr(geminicli, "get_retry_config", fake_get_retry_config) + monkeypatch.setattr( + geminicli, "get_auto_ban_error_codes", fake_get_auto_ban_error_codes + ) + monkeypatch.setattr( + geminicli, "parse_and_log_cooldown", fake_parse_and_log_cooldown + ) + monkeypatch.setattr(geminicli, "record_api_call_error", fake_record_api_call_error) + monkeypatch.setattr( + geminicli, "handle_error_with_retry", fake_handle_error_with_retry + ) + monkeypatch.setattr(geminicli, "post_async", fake_post_async) + monkeypatch.setattr(geminicli.log, "warning", fake_warning) + + body = { + "model": "gemini-2.5-pro", + "request": {"contents": [{"role": "user", "parts": [{"text": "hi"}]}]}, + } + response = asyncio.run(geminicli.non_stream_request(body=body)) + + expected_cooldown = datetime(2026, 3, 12, 0, 0, 0, tzinfo=timezone.utc).timestamp() + assert response.status_code == 429 + assert captured["cooldown_until"] == expected_cooldown + assert any("回退到下一个UTC午夜" in message for message in captured["warnings"]) + + def test_handle_error_with_retry_v2_controls_internal_sleep(monkeypatch): sleep_calls = []