Skip to content
Merged
144 changes: 143 additions & 1 deletion src/api/antigravity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import asyncio
import hashlib
import json
import uuid
from datetime import datetime, timezone
Expand Down Expand Up @@ -71,6 +72,145 @@ def build_antigravity_headers(access_token: str, model_name: str = "") -> Dict[s
return headers


def _generate_stable_session_id(request_payload: Dict[str, Any]) -> str:
contents = request_payload.get("contents")
if isinstance(contents, list):
for content in contents:
if not isinstance(content, dict) or content.get("role") != "user":
continue
parts = content.get("parts")
if not isinstance(parts, list) or not parts:
continue
first_part = parts[0]
if not isinstance(first_part, dict):
continue
text = first_part.get("text")
if isinstance(text, str) and text:
digest = hashlib.sha256(text.encode("utf-8")).digest()
value = int.from_bytes(digest[:8], "big") & 0x7FFFFFFFFFFFFFFF
return f"-{value}"

value = uuid.uuid4().int % 9_000_000_000_000_000_000
return f"-{value}"


def _ensure_antigravity_session_id(payload: Dict[str, Any], model_name: str) -> None:
if "image" in (model_name or "").lower():
return

request_payload = payload.get("request")
if not isinstance(request_payload, dict):
return

if request_payload.get("sessionId"):
return

request_payload["sessionId"] = _generate_stable_session_id(request_payload)


def _empty_object_schema() -> Dict[str, Any]:
return {"type": "object", "properties": {}}


def _prepare_antigravity_tool(tool: Any, is_claude: bool) -> Any:
if not isinstance(tool, dict):
return tool

normalized_tool = tool.copy()

custom_tool = normalized_tool.get("custom")
if isinstance(custom_tool, dict):
normalized_custom = custom_tool.copy()
if "input_schema" not in normalized_custom:
schema = (
normalized_custom.pop("parametersJsonSchema", None)
or normalized_custom.pop("parameters_json_schema", None)
or normalized_custom.get("parameters")
)
normalized_custom["input_schema"] = schema or _empty_object_schema()
normalized_tool["custom"] = normalized_custom

declarations_key = None
declarations = None
if isinstance(normalized_tool.get("functionDeclarations"), list):
declarations_key = "functionDeclarations"
declarations = normalized_tool.get("functionDeclarations")
elif isinstance(normalized_tool.get("function_declarations"), list):
declarations_key = "function_declarations"
declarations = normalized_tool.get("function_declarations")

if isinstance(declarations, list) and declarations_key:
normalized_declarations = []
for declaration in declarations:
if not isinstance(declaration, dict):
normalized_declarations.append(declaration)
continue

normalized_declaration = declaration.copy()
schema = None
if "parametersJsonSchema" in normalized_declaration:
schema = normalized_declaration.pop("parametersJsonSchema")
elif "parameters_json_schema" in normalized_declaration:
schema = normalized_declaration.pop("parameters_json_schema")
elif "parameters" in normalized_declaration:
schema = normalized_declaration.get("parameters")

if schema not in (None, {}, []):
normalized_declaration["parameters"] = schema
elif is_claude or "parameters" not in normalized_declaration:
normalized_declaration["parameters"] = _empty_object_schema()

normalized_declarations.append(normalized_declaration)

normalized_tool[declarations_key] = normalized_declarations

return normalized_tool


def _prepare_antigravity_payload(payload: Dict[str, Any], model_name: str) -> Dict[str, Any]:
"""Match Antigravity's upstream payload quirks before the HTTP request."""
payload["userAgent"] = "antigravity"
if "image" in (model_name or "").lower():
payload["requestType"] = "image_gen"
payload.setdefault(
"requestId",
f"image_gen/{int(datetime.now(timezone.utc).timestamp() * 1000)}/{uuid.uuid4()}/12",
)
else:
payload["requestType"] = "agent"
payload.setdefault("requestId", f"agent-{uuid.uuid4()}")

request_payload = payload.get("request")
if not isinstance(request_payload, dict):
return payload

_ensure_antigravity_session_id(payload, model_name)
request_payload.pop("safetySettings", None)

is_claude = "claude" in (model_name or "").lower()
tools = request_payload.get("tools")
if isinstance(tools, list):
request_payload["tools"] = [
_prepare_antigravity_tool(tool, is_claude)
for tool in tools
]

if is_claude:
tool_config = request_payload.get("toolConfig")
if not isinstance(tool_config, dict):
tool_config = {}
request_payload["toolConfig"] = tool_config

function_config = tool_config.get("functionCallingConfig")
if not isinstance(function_config, dict):
function_config = {}
tool_config["functionCallingConfig"] = function_config

function_config["mode"] = "VALIDATED"

return payload


def _is_retryable_status(status_code: int, disable_error_codes: List[int]) -> bool:
"""统一判断是否属于可重试状态码。"""
return status_code in (429, 503) or status_code in disable_error_codes
Expand Down Expand Up @@ -167,6 +307,7 @@ async def stream_request(
"project": project_id,
"request": body.get("request", {}),
}
_prepare_antigravity_payload(final_payload, model_name)

# 仅当凭证明确开启积分消耗时注入 enabledCreditTypes
def apply_enabled_credit_types(cred_data: Dict[str, Any]) -> None:
Expand Down Expand Up @@ -460,6 +601,7 @@ async def non_stream_request(
"project": project_id,
"request": body.get("request", {}),
}
_prepare_antigravity_payload(final_payload, model_name)

# 仅当凭证明确开启积分消耗时注入 enabledCreditTypes
def apply_enabled_credit_types(cred_data: Dict[str, Any]) -> None:
Expand Down Expand Up @@ -842,4 +984,4 @@ async def fetch_quota_info(access_token: str) -> Dict[str, Any]:
return {
"success": False,
"error": str(e)
}
}
Loading