49 changes: 48 additions & 1 deletion models.py
@@ -307,6 +307,17 @@ def __init__(
model_config: Optional[ModelConfig] = None,
**kwargs: Any,
):
# Override model name if specific provider env vars are set
if provider == "azure":
deployment_name = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME")
if deployment_name:
model = deployment_name

if provider == "bedrock":
aws_model = os.getenv("AWS_MODEL_NAME")
if aws_model:
model = aws_model

model_value = f"{provider}/{model}"
super().__init__(model_name=model_value, provider=provider, kwargs=kwargs) # type: ignore
# Set A0 model config as instance attribute after parent init
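As a reading aid, here is a minimal standalone sketch of the override behaviour this hunk introduces (the helper name `resolve_model_name` is hypothetical; only the env-var precedence mirrors the change):

```python
import os

def resolve_model_name(provider: str, model: str) -> str:
    # Provider-specific env vars take precedence over the configured model name
    if provider == "azure":
        # Azure routes requests by deployment name
        model = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME") or model
    elif provider == "bedrock":
        # Bedrock expects a model ID such as anthropic.claude-3-sonnet-20240229-v1:0
        model = os.getenv("AWS_MODEL_NAME") or model
    return f"{provider}/{model}"

# With AZURE_OPENAI_CHAT_DEPLOYMENT_NAME=gpt-4o-prod set:
#   resolve_model_name("azure", "gpt-4o")  ->  "azure/gpt-4o-prod"
```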
@@ -435,12 +446,16 @@ async def _astream(

result = ChatGenerationResult()

# Prepare call kwargs
call_kwargs = {**self.kwargs, **kwargs}
_adjust_call_args(self.provider, self.model_name, call_kwargs)

response = await acompletion(
model=self.model_name,
messages=msgs,
stream=True,
stop=stop,
**{**self.kwargs, **kwargs},
**call_kwargs,
)
async for chunk in response: # type: ignore
# parse chunk
@@ -487,6 +502,10 @@ async def unified_call(

# Prepare call kwargs and retry config (strip A0-only params before calling LiteLLM)
call_kwargs: dict[str, Any] = {**self.kwargs, **kwargs}

# Adjust call args (inject env vars, headers etc.)
_adjust_call_args(self.provider, self.model_name, call_kwargs)

max_retries: int = int(call_kwargs.pop("a0_retry_attempts", 2))
retry_delay_s: float = float(call_kwargs.pop("a0_retry_delay_seconds", 1.5))
stream = reasoning_callback is not None or response_callback is not None or tokens_callback is not None
@@ -626,6 +645,9 @@ async def _acall(
if "response_format" in kwrgs and "json_schema" in kwrgs["response_format"] and model.startswith("gemini/"):
kwrgs["response_format"]["json_schema"] = ChatGoogle("")._fix_gemini_schema(kwrgs["response_format"]["json_schema"])

# Adjust call args (inject env vars, headers etc.)
_adjust_call_args(self._wrapper.provider, self._wrapper.model_name, kwrgs)

resp = await acompletion(
model=self._wrapper.model_name,
messages=messages,
@@ -840,6 +862,31 @@ def _adjust_call_args(provider_name: str, model_name: str, kwargs: dict):
if provider_name == "other":
provider_name = "openai"

# Azure OpenAI / AI Foundry Support
# Automatically inject base URL, API version and API key from env if not provided
if provider_name == "azure":
# Check standard variables
if "api_base" not in kwargs:
api_base = os.getenv("AZURE_API_BASE") or os.getenv("AZURE_OPENAI_ENDPOINT")
if api_base:
kwargs["api_base"] = api_base
if "api_version" not in kwargs:
api_version = os.getenv("AZURE_API_VERSION") or os.getenv("AZURE_OPENAI_API_VERSION")
if api_version:
kwargs["api_version"] = api_version
if "api_key" not in kwargs:
api_key = os.getenv("AZURE_OPENAI_API_KEY")
if api_key:
kwargs["api_key"] = api_key

# AWS Bedrock Support
# Automatically inject Region from env if not provided
if provider_name == "bedrock":
if "aws_region_name" not in kwargs:
region = os.getenv("AWS_REGION_NAME") or os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION")
if region:
kwargs["aws_region_name"] = region

return provider_name, model_name, kwargs
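
A rough usage sketch of the injection above (illustrative only; the endpoint and version values are placeholders, and _adjust_call_args mutates the kwargs dict in place, as it does at the call sites above):

```python
import os

# Hypothetical environment, e.g. populated from the Provider Configuration settings
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://my-resource.openai.azure.com"
os.environ["AZURE_OPENAI_API_VERSION"] = "2024-02-01"

call_kwargs: dict = {}
_adjust_call_args("azure", "azure/gpt-4o", call_kwargs)

# call_kwargs now carries the Azure connection details, so the LiteLLM
# acompletion() call does not need them passed explicitly:
# {"api_base": "https://my-resource.openai.azure.com", "api_version": "2024-02-01"}
```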


158 changes: 154 additions & 4 deletions python/helpers/settings.py
@@ -178,25 +178,28 @@ def convert_out(settings: Settings) -> SettingsOutput:
"type": "select",
"value": settings["chat_model_provider"],
"options": cast(list[FieldOption], get_providers("chat")),
"disabled_by": {"field": "cloud_provider", "not_value": "none"}
}
)
chat_model_fields.append(
{
"id": "chat_model_name",
"title": "Chat model name",
"description": "Exact name of model from selected provider",
"description": "Exact name of model from selected provider. For Azure, this is overridden by 'Azure OpenAI Deployment Name' in Provider Configuration. For AWS Bedrock, use 'AWS Bedrock Model ID'.",
"type": "text",
"value": settings["chat_model_name"],
"disabled_by": {"field": "cloud_provider", "not_value": "none"}
}
)

chat_model_fields.append(
{
"id": "chat_model_api_base",
"title": "Chat model API base URL",
"description": "API base URL for main chat model. Leave empty for default. Only relevant for Azure, local and custom (other) providers.",
"description": "API base URL for main chat model. For Azure, this is overridden by 'Azure OpenAI Endpoint' in Provider Configuration. Leave empty for default.",
"type": "text",
"value": settings["chat_model_api_base"],
"disabled_by": {"field": "cloud_provider", "not_value": "none"}
}
)

@@ -291,25 +294,28 @@ def convert_out(settings: Settings) -> SettingsOutput:
"type": "select",
"value": settings["util_model_provider"],
"options": cast(list[FieldOption], get_providers("chat")),
"disabled_by": {"field": "cloud_provider", "not_value": "none"}
}
)
util_model_fields.append(
{
"id": "util_model_name",
"title": "Utility model name",
"description": "Exact name of model from selected provider",
"description": "Exact name of model from selected provider. For Azure, this is overridden by 'Azure OpenAI Deployment Name' in Provider Configuration. For AWS Bedrock, use 'AWS Bedrock Model ID'.",
"type": "text",
"value": settings["util_model_name"],
"disabled_by": {"field": "cloud_provider", "not_value": "none"}
}
)

util_model_fields.append(
{
"id": "util_model_api_base",
"title": "Utility model API base URL",
"description": "API base URL for utility model. Leave empty for default. Only relevant for Azure, local and custom (other) providers.",
"description": "API base URL for utility model. For Azure, this is overridden by 'Azure OpenAI Endpoint' in Provider Configuration. Leave empty for default.",
"type": "text",
"value": settings["util_model_api_base"],
"disabled_by": {"field": "cloud_provider", "not_value": "none"}
}
)

@@ -600,6 +606,114 @@ def convert_out(settings: Settings) -> SettingsOutput:
"tab": "external",
}

# Provider specific config section
provider_config_fields: list[SettingsField] = []

# Cloud Provider Selector
provider_config_fields.append({
"id": "cloud_provider",
"title": "Cloud Provider",
"description": "Select a cloud provider to use for all models. When selected, Chat/Utility model providers will be auto-configured.",
"type": "select",
"value": settings["api_keys"].get("CLOUD_PROVIDER", "none"),
"options": [
{"value": "none", "label": "None (Use individual providers)"},
{"value": "azure", "label": "Azure OpenAI"},
{"value": "bedrock", "label": "AWS Bedrock"}
]
})

# Azure OpenAI
provider_config_fields.append({
"id": "azure_openai_api_key",
"title": "Azure OpenAI API Key",
"description": "Set AZURE_OPENAI_API_KEY environment variable.",
"type": "password",
"value": (API_KEY_PLACEHOLDER if settings["api_keys"].get("AZURE_OPENAI_API_KEY") else ""),
"depends_on": {"field": "cloud_provider", "value": "azure"}
})

provider_config_fields.append({
"id": "azure_openai_endpoint",
"title": "Azure OpenAI Endpoint",
"description": "Set AZURE_OPENAI_ENDPOINT environment variable.",
"type": "text",
"value": settings["api_keys"].get("AZURE_OPENAI_ENDPOINT", ""),
"depends_on": {"field": "cloud_provider", "value": "azure"}
})

provider_config_fields.append({
"id": "azure_openai_api_version",
"title": "Azure OpenAI API Version",
"description": "Set AZURE_OPENAI_API_VERSION environment variable.",
"type": "text",
"value": settings["api_keys"].get("AZURE_OPENAI_API_VERSION", ""),
"depends_on": {"field": "cloud_provider", "value": "azure"}
})

provider_config_fields.append({
"id": "azure_openai_chat_deployment_name",
"title": "Azure OpenAI Deployment Name",
"description": "Set AZURE_OPENAI_CHAT_DEPLOYMENT_NAME (e.g. gpt-4o, gpt-4o-mini).",
"type": "text",
"value": settings["api_keys"].get("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME", ""),
"depends_on": {"field": "cloud_provider", "value": "azure"}
})

# AWS Bedrock
provider_config_fields.append({
"id": "aws_access_key_id",
"title": "AWS Access Key ID",
"description": "Set AWS_ACCESS_KEY_ID.",
"type": "text",
"value": settings["api_keys"].get("AWS_ACCESS_KEY_ID", ""),
"depends_on": {"field": "cloud_provider", "value": "bedrock"}
})

provider_config_fields.append({
"id": "aws_secret_access_key",
"title": "AWS Secret Access Key",
"description": "Set AWS_SECRET_ACCESS_KEY.",
"type": "password",
"value": (PASSWORD_PLACEHOLDER if settings["api_keys"].get("AWS_SECRET_ACCESS_KEY") else ""),
"depends_on": {"field": "cloud_provider", "value": "bedrock"}
})

provider_config_fields.append({
"id": "aws_session_token",
"title": "AWS Session Token",
"description": "Optional: Set AWS_SESSION_TOKEN.",
"type": "password",
"value": (PASSWORD_PLACEHOLDER if settings["api_keys"].get("AWS_SESSION_TOKEN") else ""),
"depends_on": {"field": "cloud_provider", "value": "bedrock"}
})

provider_config_fields.append({
"id": "aws_default_region",
"title": "AWS Default Region",
"description": "Set AWS_DEFAULT_REGION (e.g. us-east-1).",
"type": "text",
"value": settings["api_keys"].get("AWS_DEFAULT_REGION", ""),
"depends_on": {"field": "cloud_provider", "value": "bedrock"}
})

provider_config_fields.append({
"id": "aws_model_name",
"title": "AWS Bedrock Model ID",
"description": "Set AWS_MODEL_NAME (e.g. anthropic.claude-3-sonnet-20240229-v1:0). This overrides the main Model Name when using AWS.",
"type": "text",
"value": settings["api_keys"].get("AWS_MODEL_NAME", ""),
"depends_on": {"field": "cloud_provider", "value": "bedrock"}
})

provider_config_section: SettingsSection = {
"id": "provider_config",
"title": "Provider Configuration",
"description": "Advanced configuration for specific model providers (Azure, AWS, etc.).",
"fields": provider_config_fields,
"tab": "external",
}

# LiteLLM global config section
litellm_fields: list[SettingsField] = []

@@ -1285,6 +1399,7 @@ def convert_out(settings: Settings) -> SettingsOutput:
memory_section,
speech_section,
api_keys_section,
provider_config_section,
litellm_section,
secrets_section,
auth_section,
@@ -1329,6 +1444,41 @@ def convert_in(settings: dict) -> Settings:
current[field["id"]] = _env_to_dict(field["value"])
elif field["id"].startswith("api_key_"):
current["api_keys"][field["id"]] = field["value"]

# Manual mapping for Azure extra fields
elif field["id"] == "azure_openai_api_key":
current["api_keys"]["AZURE_OPENAI_API_KEY"] = field["value"]
elif field["id"] == "azure_openai_endpoint":
current["api_keys"]["AZURE_OPENAI_ENDPOINT"] = field["value"]
elif field["id"] == "azure_openai_api_version":
current["api_keys"]["AZURE_OPENAI_API_VERSION"] = field["value"]
elif field["id"] == "azure_openai_chat_deployment_name":
current["api_keys"]["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"] = field["value"]

# Manual mapping for AWS Bedrock extra fields
elif field["id"] == "aws_access_key_id":
current["api_keys"]["AWS_ACCESS_KEY_ID"] = field["value"]
elif field["id"] == "aws_secret_access_key":
current["api_keys"]["AWS_SECRET_ACCESS_KEY"] = field["value"]
elif field["id"] == "aws_session_token":
current["api_keys"]["AWS_SESSION_TOKEN"] = field["value"]
elif field["id"] == "aws_default_region":
current["api_keys"]["AWS_DEFAULT_REGION"] = field["value"]
current["api_keys"]["AWS_REGION_NAME"] = field["value"] # Keep legacy/litellm compatible var too
elif field["id"] == "aws_model_name":
current["api_keys"]["AWS_MODEL_NAME"] = field["value"]

# Cloud Provider handling
elif field["id"] == "cloud_provider":
current["api_keys"]["CLOUD_PROVIDER"] = field["value"]
# Auto-sync Chat and Utility model providers
if field["value"] == "azure":
current["chat_model_provider"] = "azure"
current["util_model_provider"] = "azure"
elif field["value"] == "bedrock":
current["chat_model_provider"] = "bedrock"
current["util_model_provider"] = "bedrock"

else:
current[field["id"]] = field["value"]
return current
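
For reviewers, an illustrative sketch of what convert_in() would return after saving the form with Cloud Provider set to Azure (values are placeholders; the keys follow the mapping above):

```python
current = {
    "chat_model_provider": "azure",   # auto-synced by the cloud_provider branch
    "util_model_provider": "azure",
    "api_keys": {
        "CLOUD_PROVIDER": "azure",
        "AZURE_OPENAI_API_KEY": "<secret>",
        "AZURE_OPENAI_ENDPOINT": "https://my-resource.openai.azure.com",
        "AZURE_OPENAI_API_VERSION": "2024-02-01",
        "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME": "gpt-4o-prod",
    },
    # ...all other settings fields unchanged...
}
```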
3 changes: 2 additions & 1 deletion requirements.txt
@@ -46,4 +46,5 @@ imapclient>=3.0.1
html2text>=2024.2.26
beautifulsoup4>=4.12.3
exchangelib>=5.4.3
pywinpty==3.0.2; sys_platform == "win32"
pywinpty==3.0.2; sys_platform == "win32"
boto3>=1.34.0
5 changes: 5 additions & 0 deletions webui/css/settings.css
@@ -13,6 +13,11 @@
grid-template-columns: 1fr;
}

.field.field-disabled {
opacity: 0.5;
pointer-events: none;
}

/* Field Labels */
.field-label {
display: flex;
4 changes: 2 additions & 2 deletions webui/index.html
@@ -238,9 +238,9 @@ <h2 x-text="settings.title"></h2>
<div class="section-title" x-text="section.title"></div>
<div class="section-description" x-html="section.description"></div>

<template x-for="(field, fieldIndex) in section.fields.filter(f => !f.hidden)"
<template x-for="(field, fieldIndex) in section.fields.filter(f => !f.hidden && (!f.depends_on || section.fields.find(df => df.id === f.depends_on.field)?.value === f.depends_on.value))"
:key="fieldIndex">
<div :class="{'field': true, 'field-full': field.type === 'textarea'}">
<div :class="{'field': true, 'field-full': field.type === 'textarea', 'field-disabled': field.disabled_by && settings.sections.flatMap(s => s.fields).find(df => df.id === field.disabled_by.field)?.value !== field.disabled_by.not_value}">
<div class="field-label" x-show="field.title || field.description">
<div class="field-title" x-text="field.title" x-show="field.title">
</div>