diff --git a/models.py b/models.py
index fbc2694dfd..714281de29 100644
--- a/models.py
+++ b/models.py
@@ -307,6 +307,17 @@ def __init__(
         model_config: Optional[ModelConfig] = None,
         **kwargs: Any,
     ):
+        # Override model name if specific provider env vars are set
+        if provider == "azure":
+            deployment_name = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME")
+            if deployment_name:
+                model = deployment_name
+
+        if provider == "bedrock":
+            aws_model = os.getenv("AWS_MODEL_NAME")
+            if aws_model:
+                model = aws_model
+
         model_value = f"{provider}/{model}"
         super().__init__(model_name=model_value, provider=provider, kwargs=kwargs)  # type: ignore
         # Set A0 model config as instance attribute after parent init
@@ -435,12 +446,16 @@ async def _astream(
         result = ChatGenerationResult()
 
+        # Prepare call kwargs
+        call_kwargs = {**self.kwargs, **kwargs}
+        _adjust_call_args(self.provider, self.model_name, call_kwargs)
+
         response = await acompletion(
             model=self.model_name,
             messages=msgs,
             stream=True,
             stop=stop,
-            **{**self.kwargs, **kwargs},
+            **call_kwargs,
         )
         async for chunk in response:  # type: ignore
             # parse chunk
@@ -487,6 +502,10 @@ async def unified_call(
         # Prepare call kwargs and retry config (strip A0-only params before calling LiteLLM)
         call_kwargs: dict[str, Any] = {**self.kwargs, **kwargs}
+
+        # Adjust call args (inject env vars, headers etc.)
+        _adjust_call_args(self.provider, self.model_name, call_kwargs)
+
         max_retries: int = int(call_kwargs.pop("a0_retry_attempts", 2))
         retry_delay_s: float = float(call_kwargs.pop("a0_retry_delay_seconds", 1.5))
         stream = reasoning_callback is not None or response_callback is not None or tokens_callback is not None
@@ -626,6 +645,9 @@ async def _acall(
         if "response_format" in kwrgs and "json_schema" in kwrgs["response_format"] and model.startswith("gemini/"):
             kwrgs["response_format"]["json_schema"] = ChatGoogle("")._fix_gemini_schema(kwrgs["response_format"]["json_schema"])
+        # Adjust call args (inject env vars, headers etc.)
+        _adjust_call_args(self._wrapper.provider, self._wrapper.model_name, kwrgs)
+
         resp = await acompletion(
             model=self._wrapper.model_name,
             messages=messages,
@@ -840,6 +862,31 @@ def _adjust_call_args(provider_name: str, model_name: str, kwargs: dict):
     if provider_name == "other":
         provider_name = "openai"
 
+    # Azure OpenAI / AI Foundry Support
+    # Automatically inject Base URL and Version from env if not provided
+    if provider_name == "azure":
+        # Check standard variables
+        if "api_base" not in kwargs:
+            api_base = os.getenv("AZURE_API_BASE") or os.getenv("AZURE_OPENAI_ENDPOINT")
+            if api_base:
+                kwargs["api_base"] = api_base
+        if "api_version" not in kwargs:
+            api_version = os.getenv("AZURE_API_VERSION") or os.getenv("AZURE_OPENAI_API_VERSION")
+            if api_version:
+                kwargs["api_version"] = api_version
+        if "api_key" not in kwargs:
+            api_key = os.getenv("AZURE_OPENAI_API_KEY")
+            if api_key:
+                kwargs["api_key"] = api_key
+
+    # AWS Bedrock Support
+    # Automatically inject Region from env if not provided
+    if provider_name == "bedrock":
+        if "aws_region_name" not in kwargs:
+            region = os.getenv("AWS_REGION_NAME") or os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION")
+            if region:
+                kwargs["aws_region_name"] = region
+
     return provider_name, model_name, kwargs
diff --git a/python/helpers/settings.py b/python/helpers/settings.py
index 9e71b7956f..fff441e407 100644
--- a/python/helpers/settings.py
+++ b/python/helpers/settings.py
@@ -178,15 +178,17 @@ def convert_out(settings: Settings) -> SettingsOutput:
             "type": "select",
             "value": settings["chat_model_provider"],
             "options": cast(list[FieldOption], get_providers("chat")),
+            "disabled_by": {"field": "cloud_provider", "not_value": "none"}
         }
     )
     chat_model_fields.append(
         {
             "id": "chat_model_name",
             "title": "Chat model name",
-            "description": "Exact name of model from selected provider",
+            "description": "Exact name of model from selected provider. For Azure, this is overridden by 'Azure OpenAI Deployment Name' in Provider Configuration. For AWS Bedrock, use 'AWS Bedrock Model ID'.",
             "type": "text",
             "value": settings["chat_model_name"],
+            "disabled_by": {"field": "cloud_provider", "not_value": "none"}
         }
     )
@@ -194,9 +196,10 @@ def convert_out(settings: Settings) -> SettingsOutput:
         {
             "id": "chat_model_api_base",
             "title": "Chat model API base URL",
-            "description": "API base URL for main chat model. Leave empty for default. Only relevant for Azure, local and custom (other) providers.",
+            "description": "API base URL for main chat model. For Azure, this is overridden by 'Azure OpenAI Endpoint' in Provider Configuration. Leave empty for default.",
             "type": "text",
             "value": settings["chat_model_api_base"],
+            "disabled_by": {"field": "cloud_provider", "not_value": "none"}
         }
     )
@@ -291,15 +294,17 @@ def convert_out(settings: Settings) -> SettingsOutput:
             "type": "select",
             "value": settings["util_model_provider"],
             "options": cast(list[FieldOption], get_providers("chat")),
+            "disabled_by": {"field": "cloud_provider", "not_value": "none"}
         }
     )
     util_model_fields.append(
         {
             "id": "util_model_name",
             "title": "Utility model name",
-            "description": "Exact name of model from selected provider",
+            "description": "Exact name of model from selected provider. For Azure, this is overridden by 'Azure OpenAI Deployment Name' in Provider Configuration. For AWS Bedrock, use 'AWS Bedrock Model ID'.",
             "type": "text",
             "value": settings["util_model_name"],
+            "disabled_by": {"field": "cloud_provider", "not_value": "none"}
         }
     )
@@ -307,9 +312,10 @@ def convert_out(settings: Settings) -> SettingsOutput:
         {
             "id": "util_model_api_base",
             "title": "Utility model API base URL",
-            "description": "API base URL for utility model. Leave empty for default. Only relevant for Azure, local and custom (other) providers.",
+            "description": "API base URL for utility model. For Azure, this is overridden by 'Azure OpenAI Endpoint' in Provider Configuration. Leave empty for default.",
             "type": "text",
             "value": settings["util_model_api_base"],
+            "disabled_by": {"field": "cloud_provider", "not_value": "none"}
         }
     )
@@ -600,6 +606,114 @@ def convert_out(settings: Settings) -> SettingsOutput:
         "tab": "external",
     }
 
+    # Provider specific config section
+    provider_config_fields: list[SettingsField] = []
+
+    # Cloud Provider Selector
+    provider_config_fields.append({
+        "id": "cloud_provider",
+        "title": "Cloud Provider",
+        "description": "Select a cloud provider to use for all models. When selected, Chat/Utility model providers will be auto-configured.",
+        "type": "select",
+        "value": settings["api_keys"].get("CLOUD_PROVIDER", "none"),
+        "options": [
+            {"value": "none", "label": "None (Use individual providers)"},
+            {"value": "azure", "label": "Azure OpenAI"},
+            {"value": "bedrock", "label": "AWS Bedrock"}
+        ]
+    })
+
+    # Azure OpenAI
+    provider_config_fields.append({
+        "id": "azure_openai_api_key",
+        "title": "Azure OpenAI API Key",
+        "description": "Set AZURE_OPENAI_API_KEY environment variable.",
+        "type": "password",
+        "value": (API_KEY_PLACEHOLDER if settings["api_keys"].get("AZURE_OPENAI_API_KEY") else ""),
+        "depends_on": {"field": "cloud_provider", "value": "azure"}
+    })
+
+    provider_config_fields.append({
+        "id": "azure_openai_endpoint",
+        "title": "Azure OpenAI Endpoint",
+        "description": "Set AZURE_OPENAI_ENDPOINT environment variable.",
+        "type": "text",
+        "value": settings["api_keys"].get("AZURE_OPENAI_ENDPOINT", ""),
+        "depends_on": {"field": "cloud_provider", "value": "azure"}
+    })
+
+    provider_config_fields.append({
+        "id": "azure_openai_api_version",
+        "title": "Azure OpenAI API Version",
+        "description": "Set AZURE_OPENAI_API_VERSION environment variable.",
+        "type": "text",
+        "value": settings["api_keys"].get("AZURE_OPENAI_API_VERSION", ""),
+        "depends_on": {"field": "cloud_provider", "value": "azure"}
+    })
+
+    provider_config_fields.append({
+        "id": "azure_openai_chat_deployment_name",
+        "title": "Azure OpenAI Deployment Name",
+        "description": "Set AZURE_OPENAI_CHAT_DEPLOYMENT_NAME (e.g. gpt-4o, gpt-4o-mini).",
+        "type": "text",
+        "value": settings["api_keys"].get("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME", ""),
+        "depends_on": {"field": "cloud_provider", "value": "azure"}
+    })
+
+    # AWS Bedrock
+    provider_config_fields.append({
+        "id": "aws_access_key_id",
+        "title": "AWS Access Key ID",
+        "description": "Set AWS_ACCESS_KEY_ID.",
+        "type": "text",
+        "value": settings["api_keys"].get("AWS_ACCESS_KEY_ID", ""),
+        "depends_on": {"field": "cloud_provider", "value": "bedrock"}
+    })
+
+    provider_config_fields.append({
+        "id": "aws_secret_access_key",
+        "title": "AWS Secret Access Key",
+        "description": "Set AWS_SECRET_ACCESS_KEY.",
+        "type": "password",
+        "value": (PASSWORD_PLACEHOLDER if settings["api_keys"].get("AWS_SECRET_ACCESS_KEY") else ""),
+        "depends_on": {"field": "cloud_provider", "value": "bedrock"}
+    })
+
+    provider_config_fields.append({
+        "id": "aws_session_token",
+        "title": "AWS Session Token",
+        "description": "Optional: Set AWS_SESSION_TOKEN.",
+        "type": "password",
+        "value": (PASSWORD_PLACEHOLDER if settings["api_keys"].get("AWS_SESSION_TOKEN") else ""),
+        "depends_on": {"field": "cloud_provider", "value": "bedrock"}
+    })
+
+    provider_config_fields.append({
+        "id": "aws_default_region",
+        "title": "AWS Default Region",
+        "description": "Set AWS_DEFAULT_REGION (e.g. us-east-1).",
+        "type": "text",
+        "value": settings["api_keys"].get("AWS_DEFAULT_REGION", ""),
+        "depends_on": {"field": "cloud_provider", "value": "bedrock"}
+    })
+
+    provider_config_fields.append({
+        "id": "aws_model_name",
+        "title": "AWS Bedrock Model ID",
+        "description": "Set AWS_MODEL_NAME (e.g. anthropic.claude-3-sonnet-20240229-v1:0). This overrides the main Model Name when using AWS.",
+        "type": "text",
+        "value": settings["api_keys"].get("AWS_MODEL_NAME", ""),
+        "depends_on": {"field": "cloud_provider", "value": "bedrock"}
+    })
+
+    provider_config_section: SettingsSection = {
+        "id": "provider_config",
+        "title": "Provider Configuration",
+        "description": "Advanced configuration for specific model providers (Azure, AWS, etc.).",
+        "fields": provider_config_fields,
+        "tab": "external",
+    }
+
     # LiteLLM global config section
     litellm_fields: list[SettingsField] = []
@@ -1285,6 +1399,7 @@ def convert_out(settings: Settings) -> SettingsOutput:
             memory_section,
             speech_section,
             api_keys_section,
+            provider_config_section,
             litellm_section,
             secrets_section,
             auth_section,
@@ -1329,6 +1444,41 @@ def convert_in(settings: dict) -> Settings:
                     current[field["id"]] = _env_to_dict(field["value"])
                 elif field["id"].startswith("api_key_"):
                     current["api_keys"][field["id"]] = field["value"]
+
+                # Manual mapping for Azure extra fields
+                elif field["id"] == "azure_openai_api_key":
+                    current["api_keys"]["AZURE_OPENAI_API_KEY"] = field["value"]
+                elif field["id"] == "azure_openai_endpoint":
+                    current["api_keys"]["AZURE_OPENAI_ENDPOINT"] = field["value"]
+                elif field["id"] == "azure_openai_api_version":
+                    current["api_keys"]["AZURE_OPENAI_API_VERSION"] = field["value"]
+                elif field["id"] == "azure_openai_chat_deployment_name":
+                    current["api_keys"]["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"] = field["value"]
+
+                # Manual mapping for AWS Bedrock extra fields
+                elif field["id"] == "aws_access_key_id":
+                    current["api_keys"]["AWS_ACCESS_KEY_ID"] = field["value"]
+                elif field["id"] == "aws_secret_access_key":
+                    current["api_keys"]["AWS_SECRET_ACCESS_KEY"] = field["value"]
+                elif field["id"] == "aws_session_token":
+                    current["api_keys"]["AWS_SESSION_TOKEN"] = field["value"]
+                elif field["id"] == "aws_default_region":
+                    current["api_keys"]["AWS_DEFAULT_REGION"] = field["value"]
+                    current["api_keys"]["AWS_REGION_NAME"] = field["value"]  # Keep legacy/litellm compatible var too
+                elif field["id"] == "aws_model_name":
+                    current["api_keys"]["AWS_MODEL_NAME"] = field["value"]
+
+                # Cloud Provider handling
+                elif field["id"] == "cloud_provider":
+                    current["api_keys"]["CLOUD_PROVIDER"] = field["value"]
+                    # Auto-sync Chat and Utility model providers
+                    if field["value"] == "azure":
+                        current["chat_model_provider"] = "azure"
+                        current["util_model_provider"] = "azure"
+                    elif field["value"] == "bedrock":
+                        current["chat_model_provider"] = "bedrock"
+                        current["util_model_provider"] = "bedrock"
+
                 else:
                     current[field["id"]] = field["value"]
     return current
diff --git a/requirements.txt b/requirements.txt
index 07be99756a..96fc879e6a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -46,4 +46,5 @@ imapclient>=3.0.1
 html2text>=2024.2.26
 beautifulsoup4>=4.12.3
 exchangelib>=5.4.3
-pywinpty==3.0.2; sys_platform == "win32"
\ No newline at end of file
+pywinpty==3.0.2; sys_platform == "win32"
+boto3>=1.34.0
\ No newline at end of file
diff --git a/webui/css/settings.css b/webui/css/settings.css
index 871b364016..f78f58aaf5 100644
--- a/webui/css/settings.css
+++ b/webui/css/settings.css
@@ -13,6 +13,11 @@
     grid-template-columns: 1fr;
 }
 
+.field.field-disabled {
+    opacity: 0.5;
+    pointer-events: none;
+}
+
 /* Field Labels */
 .field-label {
     display: flex;
diff --git a/webui/index.html b/webui/index.html
index 82abf4b591..a7257e2a0b 100644
--- a/webui/index.html
+++ b/webui/index.html
@@ -238,9 +238,9 @@
- -