diff --git a/.env.example b/.env.example index 7235142..caf2bd0 100644 --- a/.env.example +++ b/.env.example @@ -292,6 +292,41 @@ # CUSTOM_CAP_ANTIGRAVITY_T2_3_G25_FLASH=80% # CUSTOM_CAP_COOLDOWN_ANTIGRAVITY_T2_3_G25_FLASH=offset:1800 +# ------------------------------------------------------------------------------ +# | [ADVANCED] Cross-Provider Model Fallback Groups | +# ------------------------------------------------------------------------------ +# +# Pool credentials from multiple providers for equivalent models. When one +# provider's credentials are exhausted, automatically fall back to the next. +# +# Key features: +# - Sequential provider rotation: Each provider is tried completely (with its +# internal tier rotation) before moving to the next provider +# - Target promotion: The requested provider is always moved to the front +# - Different model names: Each provider can use a different model name +# +# Format: JSON array of arrays, each inner array is a fallback group +# +# MODEL_FALLBACK_GROUPS='[ +# ["gemini/gemini-2.5-pro", "gemini_cli/gemini-2.5-pro", "openrouter/google/gemini-2.5-pro"], +# ["antigravity/claude-sonnet-4.5", "openrouter/anthropic/claude-3.5-sonnet"] +# ]' +# +# Behavior example: +# Request: "gemini_cli/gemini-2.5-pro" +# 1. Find group containing "gemini_cli/gemini-2.5-pro" +# 2. Reorder with target first: ["gemini_cli/gemini-2.5-pro", "gemini/gemini-2.5-pro", "openrouter/google/gemini-2.5-pro"] +# 3. Try gemini_cli completely (tier-2 → tier-1 credentials) +# 4. If exhausted, try gemini completely (tier-2 → tier-1 credentials) +# 5. 
If exhausted, try openrouter completely (tier-2 → tier-1 credentials) +# +# Notes: +# - Same provider can appear multiple times with different models +# - Request for a model NOT in any group uses single-provider behavior +# - Group name is not needed; entries are matched by exact "provider/model" +# +# MODEL_FALLBACK_GROUPS='[["gemini/gemini-2.5-pro", "gemini_cli/gemini-2.5-pro", "antigravity/gemini-2.5-pro"]]' + # ------------------------------------------------------------------------------ # | [ADVANCED] Proxy Configuration | # ------------------------------------------------------------------------------ diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index f7ebbde..74607f8 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -60,6 +60,7 @@ client = RotatingClient( - `enable_request_logging` (`bool`, default: `False`): If `True`, enables detailed per-request file logging. - `max_concurrent_requests_per_key` (`Optional[Dict[str, int]]`, default: `None`): Max concurrent requests allowed for a single API key per provider. - `rotation_tolerance` (`float`, default: `3.0`): Controls the credential rotation strategy. See Section 2.2 for details. +- `model_fallback_groups` (`Optional[List[List[str]]]`, default: `None`): Cross-provider fallback groups. See Section 2.22 for details. #### Core Responsibilities @@ -919,6 +920,83 @@ The proxy accepts both Anthropic and OpenAI authentication styles: - `x-api-key` header (Anthropic style) - `Authorization: Bearer` header (OpenAI style) +### 2.22. Cross-Provider Fallback Groups (`fallback_groups.py`) + +The `FallbackGroupManager` enables pooling credentials from multiple providers for equivalent models, with automatic fallback when one provider's credentials are exhausted. 
+ +#### Configuration + +**Environment Variable:** +```bash +MODEL_FALLBACK_GROUPS='[ + ["gemini/gemini-2.5-pro", "gemini_cli/gemini-2.5-pro", "openrouter/google/gemini-2.5-pro"], + ["antigravity/claude-sonnet-4.5", "openrouter/anthropic/claude-3.5-sonnet"] +]' +``` + +**Constructor Parameter:** +```python +client = RotatingClient( + model_fallback_groups=[ + ["gemini/gemini-2.5-pro", "gemini_cli/gemini-2.5-pro", "openrouter/google/gemini-2.5-pro"], + ] +) +``` + +#### Key Concepts + +- **Fallback Group**: A list of `provider/model` combinations that are considered equivalent +- **Target Promotion**: When a request matches an entry in a group, that entry is moved to the front +- **Sequential Provider Rotation**: Each provider is tried completely (with its internal tier rotation) before moving to the next +- **Provider Priority**: Providers are tried in the order specified in the configuration + +#### Algorithm + +Given a request for `gemini_cli/gemini-2.5-pro` with fallback group: +`["gemini/gemini-2.5-pro", "gemini_cli/gemini-2.5-pro", "openrouter/google/gemini-2.5-pro"]` + +1. **Find matching group**: Scan all groups for exact match `gemini_cli/gemini-2.5-pro` +2. **Reorder with target first**: `["gemini_cli/gemini-2.5-pro", "gemini/gemini-2.5-pro", "openrouter/google/gemini-2.5-pro"]` +3. **Try each provider sequentially**: + - Try `gemini_cli/gemini-2.5-pro` with all its credentials (tier-2 first, then tier-1 - internal rotation) + - If exhausted, try `gemini/gemini-2.5-pro` with all its credentials + - If exhausted, try `openrouter/google/gemini-2.5-pro` with all its credentials +4. 
**Success stops iteration**: First successful response is returned immediately + +#### Implementation Details + +The fallback system reuses existing retry logic completely: + +``` +For each entry in fallback group (in order): + Call existing _execute_with_retry() with entry's provider/model + If success: return response + If exhausted: continue to next entry +``` + +**Key Benefits:** +- Simple, predictable provider ordering +- Each provider uses its own internal tier rotation +- Existing cooldown, fair cycle, and rotation logic remains intact +- Provider-specific settings (like concurrency limits) are respected +- No changes needed to `UsageManager` - orchestration is at client level + +#### Edge Cases + +| Scenario | Behavior | +|----------|----------| +| Request not in any group | Single-provider mode (existing behavior) | +| Same provider, different models | Both entries tried in order | +| Provider has no credentials | Entry skipped silently | +| All entries exhausted | Returns error response with details | + +#### Logging + +- `DEBUG`: When fallback activated, which entry being tried +- `INFO`: When entry exhausted and moving to next +- `INFO`: When fallback succeeds with a specific entry +- `WARNING`: When all entries exhausted + ### 3.5. Antigravity (`antigravity_provider.py`) The most sophisticated provider implementation, supporting Google's internal Antigravity API for Gemini 3 and Claude models (including **Claude Opus 4.5**, Anthropic's most powerful model). 
diff --git a/README.md b/README.md index a7c3c43..5d17192 100644 --- a/README.md +++ b/README.md @@ -294,6 +294,7 @@ The proxy is powered by a standalone Python library that you can use directly in - **Intelligent key selection** with tiered, model-aware locking - **Deadline-driven requests** with configurable global timeout - **Automatic failover** between keys on errors +- **Cross-provider fallback** — pool credentials from multiple providers for the same model - **OAuth support** for Gemini CLI, Antigravity, Qwen, iFlow - **Stateless deployment ready** — load credentials from environment variables diff --git a/src/rotator_library/README.md b/src/rotator_library/README.md index 22d2bf6..3ce7473 100644 --- a/src/rotator_library/README.md +++ b/src/rotator_library/README.md @@ -31,6 +31,7 @@ A robust, asynchronous, and thread-safe Python library for managing a pool of AP - **Shared OAuth Base**: Refactored OAuth implementation with reusable [`GoogleOAuthBase`](providers/google_oauth_base.py) for multiple providers. - **Fair Cycle Rotation**: Ensures each credential exhausts at least once before any can be reused within a tier. Prevents a single credential from being repeatedly used while others sit idle. Configurable per provider with tracking modes and cross-tier support. - **Custom Usage Caps**: Set custom limits per tier, per model/group that are more restrictive than actual API limits. Supports percentages (e.g., "80%") and multiple cooldown modes (`quota_reset`, `offset`, `fixed`). Credentials go on cooldown before hitting actual API limits. +- **Cross-Provider Fallback Groups**: Pool credentials from multiple providers for equivalent models. When one provider's credentials are exhausted, automatically fall back to the next provider in the configured order. Each provider uses its own internal tier rotation. - **Centralized Defaults**: All tunable defaults are defined in [`config/defaults.py`](config/defaults.py) for easy customization and visibility. 
## Installation @@ -82,7 +83,10 @@ client = RotatingClient( whitelist_models={}, enable_request_logging=False, max_concurrent_requests_per_key={}, - rotation_tolerance=2.0 # 0.0=deterministic, 2.0=recommended random + rotation_tolerance=2.0, # 0.0=deterministic, 2.0=recommended random + model_fallback_groups=[ # Cross-provider fallback groups + ["gemini/gemini-2.5-pro", "gemini_cli/gemini-2.5-pro", "openrouter/google/gemini-2.5-pro"], + ], ) ``` diff --git a/src/rotator_library/client.py b/src/rotator_library/client.py index 098fff0..4ad148b 100644 --- a/src/rotator_library/client.py +++ b/src/rotator_library/client.py @@ -41,6 +41,7 @@ from .background_refresher import BackgroundRefresher from .model_definitions import ModelDefinitions from .transaction_logger import TransactionLogger +from .fallback_groups import FallbackGroupManager from .utils.paths import get_default_root, get_logs_dir, get_oauth_dir, get_data_file from .utils.suppress_litellm_warnings import suppress_litellm_serialization_warnings from .config import ( @@ -83,6 +84,7 @@ def __init__( max_concurrent_requests_per_key: Optional[Dict[str, int]] = None, rotation_tolerance: float = DEFAULT_ROTATION_TOLERANCE, data_dir: Optional[Union[str, Path]] = None, + model_fallback_groups: Optional[List[List[str]]] = None, ): """ Initialize the RotatingClient with intelligent credential rotation. @@ -106,6 +108,11 @@ def __init__( - 5.0+: High randomness, more unpredictable selection patterns data_dir: Root directory for all data files (logs, cache, oauth_creds, key_usage.json). If None, auto-detects: EXE directory if frozen, else current working directory. + model_fallback_groups: List of fallback groups for cross-provider model pooling. + Each group is a list of "provider/model" strings. When a request matches + any entry in a group, all entries become fallback candidates. + Order matters: earlier entries have higher priority within each tier. 
+ Can also be configured via MODEL_FALLBACK_GROUPS environment variable. """ # Resolve data_dir early - this becomes the root for all file operations if data_dir is not None: @@ -526,6 +533,9 @@ def __init__( ) self.max_concurrent_requests_per_key[provider] = 1 + # Initialize cross-provider fallback groups + self.fallback_manager = FallbackGroupManager(model_fallback_groups) + def _parse_custom_cap_env_key( self, remainder: str ) -> Tuple[Optional[Union[int, Tuple[int, ...], str]], Optional[str]]: @@ -1175,7 +1185,14 @@ async def _execute_with_retry( pre_request_callback: Optional[callable] = None, **kwargs, ) -> Any: - """A generic retry mechanism for non-streaming API calls.""" + """A generic retry mechanism for non-streaming API calls. + + Args: + api_call: The API call function to execute + request: The request object for disconnect detection + pre_request_callback: Optional callback before each request + **kwargs: Additional arguments for the API call + """ model = kwargs.get("model") if not model: raise ValueError("'model' is a required parameter.") @@ -1941,13 +1958,211 @@ async def _execute_with_retry( ) return None + async def _execute_with_fallback( + self, + api_call: callable, + request: Optional[Any], + pre_request_callback: Optional[callable], + fallback_entries: List[str], + **kwargs, + ) -> Any: + """ + Execute request with cross-provider fallback. + + Tries each provider in the fallback group sequentially. Each provider + uses its own tier rotation internally (via _execute_with_retry). + When one provider is exhausted, moves to the next. 
+ + Args: + api_call: The API call function (litellm.acompletion) + request: The request object for disconnect detection + pre_request_callback: Optional callback before each request + fallback_entries: Ordered list of "provider/model" to try + **kwargs: Additional arguments for the API call + + Returns: + API response on success, error response on failure + """ + original_model = kwargs.get("model", fallback_entries[0]) + lib_logger.debug( + f"Cross-provider fallback activated for {original_model}. " + f"Group: {', '.join(fallback_entries[:3])}{'...' if len(fallback_entries) > 3 else ''}" + ) + + last_result = None + + for entry in fallback_entries: + parts = entry.split("/", 1) + if len(parts) != 2: + lib_logger.warning(f"Invalid fallback entry format: {entry}") + continue + + provider = parts[0] + + # Check if provider has credentials + if ( + provider not in self.all_credentials + or not self.all_credentials[provider] + ): + lib_logger.debug(f"No credentials for provider '{provider}', skipping") + continue + + lib_logger.debug(f"Fallback: trying {entry}") + + # Update kwargs with this entry's model + entry_kwargs = kwargs.copy() + entry_kwargs["model"] = entry + + # Call existing retry logic - it handles tier rotation internally + result = await self._execute_with_retry( + api_call, + request, + pre_request_callback, + **entry_kwargs, + ) + + # Check if result is successful (not an error response) + if result is not None: + if isinstance(result, dict) and "error" in result: + last_result = result + lib_logger.info(f"Fallback entry {entry} exhausted, trying next") + continue + else: + lib_logger.info(f"Fallback succeeded with {entry}") + return result + + # All entries exhausted + lib_logger.warning(f"All fallback entries exhausted for {original_model}") + + if last_result is not None: + return last_result + + return { + "error": { + "message": f"All fallback entries exhausted for {original_model}", + "type": "no_available_keys", + "code": 503, + } + } + + async 
def _streaming_with_fallback( + self, + request: Optional[Any], + pre_request_callback: Optional[callable], + fallback_entries: List[str], + **kwargs, + ) -> AsyncGenerator[str, None]: + """ + Execute streaming request with cross-provider fallback. + + Tries each provider in the fallback group sequentially. Each provider + uses its own tier rotation internally (via _streaming_acompletion_with_retry). + When one provider is exhausted, moves to the next. + + Args: + request: The request object for disconnect detection + pre_request_callback: Optional callback before each request + fallback_entries: Ordered list of "provider/model" to try + **kwargs: Additional arguments for the API call + + Yields: + SSE formatted strings + """ + original_model = kwargs.get("model", fallback_entries[0]) + lib_logger.debug( + f"Cross-provider fallback (streaming) activated for {original_model}. " + f"Group: {', '.join(fallback_entries[:3])}{'...' if len(fallback_entries) > 3 else ''}" + ) + + last_error = None + + for entry in fallback_entries: + parts = entry.split("/", 1) + if len(parts) != 2: + lib_logger.warning(f"Invalid fallback entry format: {entry}") + continue + + provider = parts[0] + + # Check if provider has credentials + if ( + provider not in self.all_credentials + or not self.all_credentials[provider] + ): + lib_logger.debug(f"No credentials for provider '{provider}', skipping") + continue + + lib_logger.debug(f"Fallback (streaming): trying {entry}") + + # Update kwargs with this entry's model + entry_kwargs = kwargs.copy() + entry_kwargs["model"] = entry + + # Call existing streaming logic - it handles tier rotation internally + stream = self._streaming_acompletion_with_retry( + request, + pre_request_callback, + **entry_kwargs, + ) + + # Consume the stream + chunks_yielded = 0 + async for chunk in stream: + # Check if this is an error response + if chunk.startswith("data: ") and chunks_yielded == 0: + try: + content = chunk[6:].strip() + if content and content != 
"[DONE]": + data = json.loads(content) + if isinstance(data, dict) and "error" in data: + # This is an error - don't yield, try next entry + last_error = data + lib_logger.info( + f"Fallback entry {entry} (streaming) returned error, trying next" + ) + break + except json.JSONDecodeError: + pass + + # Not an error, yield the chunk + yield chunk + chunks_yielded += 1 + + # Check if stream completed successfully + if chunks_yielded > 0: + lib_logger.info(f"Fallback (streaming) succeeded with {entry}") + return + + # All entries exhausted + lib_logger.warning( + f"All fallback entries exhausted for {original_model} (streaming)" + ) + + if last_error: + yield f"data: {json.dumps(last_error)}\n\n" + else: + error_data = { + "error": { + "message": f"All fallback entries exhausted for {original_model}", + "type": "no_available_keys", + } + } + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + async def _streaming_acompletion_with_retry( self, request: Optional[Any], pre_request_callback: Optional[callable] = None, **kwargs, ) -> AsyncGenerator[str, None]: - """A dedicated generator for retrying streaming completions with full request preparation and per-key retries.""" + """A dedicated generator for retrying streaming completions with full request preparation and per-key retries. 
+ + Args: + request: The request object for disconnect detection + pre_request_callback: Optional callback before each request + **kwargs: Additional arguments for the API call + """ model = kwargs.get("model") provider = model.split("/")[0] @@ -2805,6 +3020,9 @@ def acompletion( ) kwargs.pop("stream_options", None) + # Check for cross-provider fallback group + fallback_entries = self.fallback_manager.get_fallback_group(model) + if kwargs.get("stream"): # Only add stream_options for providers that support it (excluding iflow) if provider != "iflow": @@ -2813,16 +3031,33 @@ def acompletion( if "include_usage" not in kwargs["stream_options"]: kwargs["stream_options"]["include_usage"] = True - return self._streaming_acompletion_with_retry( - request=request, pre_request_callback=pre_request_callback, **kwargs - ) + if fallback_entries: + return self._streaming_with_fallback( + request=request, + pre_request_callback=pre_request_callback, + fallback_entries=fallback_entries, + **kwargs, + ) + else: + return self._streaming_acompletion_with_retry( + request=request, pre_request_callback=pre_request_callback, **kwargs + ) else: - return self._execute_with_retry( - litellm.acompletion, - request=request, - pre_request_callback=pre_request_callback, - **kwargs, - ) + if fallback_entries: + return self._execute_with_fallback( + litellm.acompletion, + request=request, + pre_request_callback=pre_request_callback, + fallback_entries=fallback_entries, + **kwargs, + ) + else: + return self._execute_with_retry( + litellm.acompletion, + request=request, + pre_request_callback=pre_request_callback, + **kwargs, + ) def aembedding( self, diff --git a/src/rotator_library/fallback_groups.py b/src/rotator_library/fallback_groups.py new file mode 100644 index 0000000..f1c46d8 --- /dev/null +++ b/src/rotator_library/fallback_groups.py @@ -0,0 +1,160 @@ +""" +Cross-provider fallback group management. 
+ +Allows pooling credentials from multiple providers for equivalent models, +with tier-aware rotation that respects entry priority ordering. +""" + +import json +import logging +import os +from typing import Dict, List, Optional + +lib_logger = logging.getLogger("rotator_library") + + +class FallbackGroupManager: + """ + Manages cross-provider fallback groups for models. + + Configuration format (JSON array of arrays): + [ + ["2/gemini-3", "3/gemini-3", "4/gemini-2.5-pro"], + ["antigravity/claude-sonnet-4.5", "openrouter/claude-3.5-sonnet"] + ] + + Each inner array is a fallback group. When a request matches any entry + in a group, all entries in that group become fallback candidates. + + Order matters: + - Entries earlier in the list have higher priority + - The requested entry is always promoted to first position + - Same provider can appear multiple times with different models + """ + + def __init__(self, config: Optional[List[List[str]]] = None): + """ + Initialize the FallbackGroupManager. + + Args: + config: List of fallback groups. If None, loads from + MODEL_FALLBACK_GROUPS environment variable. 
+ """ + self._groups: List[List[str]] = [] + self._entry_to_group_index: Dict[str, int] = {} + + if config is not None: + self._load_from_config(config) + else: + self._load_from_env() + + self._build_index() + + def _load_from_config(self, config: List[List[str]]) -> None: + """Load groups from provided config.""" + if not isinstance(config, list): + lib_logger.warning( + f"model_fallback_groups must be a list, got {type(config).__name__}" + ) + return + + for i, group in enumerate(config): + if not isinstance(group, list): + lib_logger.warning( + f"Fallback group {i} must be a list, got {type(group).__name__}" + ) + continue + + # Validate entries + valid_entries = [] + for entry in group: + if not isinstance(entry, str): + lib_logger.warning( + f"Fallback entry must be a string, got {type(entry).__name__}: {entry}" + ) + continue + if "/" not in entry: + lib_logger.warning( + f"Fallback entry must be 'provider/model' format: {entry}" + ) + continue + valid_entries.append(entry) + + if len(valid_entries) >= 2: + self._groups.append(valid_entries) + lib_logger.info( + f"Loaded fallback group with {len(valid_entries)} entries: " + f"{', '.join(valid_entries[:3])}{'...' 
if len(valid_entries) > 3 else ''}" + ) + elif valid_entries: + lib_logger.warning( + f"Fallback group needs at least 2 entries, got {len(valid_entries)}: {valid_entries}" + ) + + def _load_from_env(self) -> None: + """Load groups from MODEL_FALLBACK_GROUPS environment variable.""" + env_value = os.environ.get("MODEL_FALLBACK_GROUPS", "").strip() + if not env_value: + return + + try: + config = json.loads(env_value) + self._load_from_config(config) + except json.JSONDecodeError as e: + lib_logger.warning(f"Invalid JSON in MODEL_FALLBACK_GROUPS: {e}") + + def _build_index(self) -> None: + """Build lookup index from entry to group index.""" + self._entry_to_group_index.clear() + for group_idx, group in enumerate(self._groups): + for entry in group: + if entry in self._entry_to_group_index: + lib_logger.warning( + f"Entry '{entry}' appears in multiple fallback groups. " + f"Using first occurrence (group {self._entry_to_group_index[entry]})." + ) + continue + self._entry_to_group_index[entry] = group_idx + + def get_fallback_group(self, provider_model: str) -> Optional[List[str]]: + """ + Get fallback group for a provider/model, with target promoted to first. + + Args: + provider_model: Full "provider/model" string (e.g., "3/gemini-3") + + Returns: + Reordered list with target first, or None if not in any group. 
Example: ["gemini_cli/gemini-2.5-pro", "gemini/gemini-2.5-pro", "openrouter/google/gemini-2.5-pro"]