Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ openai==2.36.0
google-auth==2.0.0
requests==2.32.3
python-multipart==0.0.9
PyYAML==6.0.1
PyYAML==6.0.1
19 changes: 19 additions & 0 deletions backend/src/dna/llm_providers/gemini_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
Gemini implementation of the LLM provider interface.
"""

import logging
import os
from typing import Any

from openai import AsyncOpenAI

from dna.llm_providers.llm_provider_base import LLMProviderBase

logger = logging.getLogger(__name__)


class GeminiProvider(LLMProviderBase):
"""Gemini implementation of the LLM provider."""
Expand All @@ -25,3 +29,18 @@ def _get_provider_client(self):
base_url=os.getenv(f"{self.LLM_PROVIDER_NAME }_URL", self.DEFAULT_URL),
timeout=self.timeout,
)

async def get_available_models(self) -> dict[str, Any]:
"""Fetch available models from Gemini API."""
try:
response = await self.client.models.list()
model_ids = sorted(m.id for m in response.data)
except Exception:
logger.warning("Failed to fetch models from Gemini API, using default")
model_ids = [self.model]

return {
"provider": "gemini",
"models": model_ids,
"default": self.model,
}
18 changes: 17 additions & 1 deletion backend/src/dna/llm_providers/llm_provider_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,26 @@ async def close(self) -> None:
await self._client.close()
self._client = None

async def get_available_models(self) -> dict[str, Any]:
"""Return available models for this provider.

Returns a dict with keys: provider, models, default.
Subclasses should override to provide dynamic discovery with caching.
"""
return {
"provider": (self.LLM_PROVIDER_NAME or "").lower(),
"models": [self.model],
"default": self.model,
}

async def generate_note(
self,
prompt: str,
transcript: str,
context: str,
existing_notes: str,
additional_instructions: Optional[str] = None,
model: Optional[str] = None,
) -> str:
"""Generate a note suggestion from the given inputs.

Expand All @@ -177,10 +190,13 @@ async def generate_note(
context: Version context (entity name, task, status, etc.).
existing_notes: Any notes the user has already written.
additional_instructions: Optional additional instructions to append.
model: Optional model override; falls back to self.model.

Returns:
The generated note suggestion.
"""
use_model = model or self.model

user_message = self._substitute_template(
prompt, transcript, context, existing_notes
)
Expand All @@ -189,7 +205,7 @@ async def generate_note(
user_message += f"\n\nAdditional Instructions: {additional_instructions}"

response = await self.client.chat.completions.create(
model=self.model,
model=use_model,
messages=[
{"role": "system", "content": GENERATE_NOTE_PROMPT},
{"role": "user", "content": user_message},
Expand Down
26 changes: 26 additions & 0 deletions backend/src/dna/llm_providers/openai_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,17 @@
OpenAI implementation of the LLM provider interface.
"""

import logging
from typing import Any

from openai import AsyncOpenAI

from dna.llm_providers.llm_provider_base import LLMProviderBase

logger = logging.getLogger(__name__)

OPENAI_CHAT_PREFIXES = ("gpt-", "o1", "o3", "o4", "chatgpt-")


class OpenAIProvider(LLMProviderBase):
"""OpenAI implementation of the LLM provider."""
Expand All @@ -18,3 +25,22 @@ class OpenAIProvider(LLMProviderBase):
def _get_provider_client(self):
"""Construct an instance of the LLM provider's client."""
return AsyncOpenAI(api_key=self.api_key, timeout=self.timeout)

async def get_available_models(self) -> dict[str, Any]:
"""Fetch available chat-completion models from OpenAI API."""
try:
response = await self.client.models.list()
model_ids = sorted(
m.id
for m in response.data
if m.id.startswith(OPENAI_CHAT_PREFIXES)
)
except Exception:
logger.warning("Failed to fetch models from OpenAI API, using default")
model_ids = [self.model]

return {
"provider": "openai",
"models": model_ids,
"default": self.model,
}
4 changes: 4 additions & 0 deletions backend/src/dna/models/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ class GenerateNoteRequest(BaseModel):
default=None,
description="Optional additional instructions to append to the prompt",
)
model: Optional[str] = Field(
default=None,
description="Optional LLM model override; omit to use server default",
)


class GenerateNoteResponse(BaseModel):
Expand Down
5 changes: 5 additions & 0 deletions backend/src/dna/models/user_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ class UserSettingsUpdate(BaseModel):
note_prompt: Optional[str] = Field(
default=None, description="Custom prompt for generating notes"
)
preferred_model: Optional[str] = Field(
default=None,
description="Preferred LLM model for note generation; empty means use server default",
)
regenerate_on_version_change: Optional[bool] = Field(
default=None,
description="Regenerate AI note when switching review versions",
Expand Down Expand Up @@ -42,6 +46,7 @@ class UserSettings(BaseModel):
id: str = Field(alias="_id")
user_email: str
note_prompt: str = ""
preferred_model: str = ""
regenerate_on_version_change: bool = False
regenerate_on_transcript_update: bool = False
sync_prodtrack_tab_on_version_change: bool = True
Expand Down
1 change: 1 addition & 0 deletions backend/src/dna/models/user_settings_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class UserSettingsResponse(BaseModel):
id: str = Field(alias="_id")
user_email: str
note_prompt: str = ""
preferred_model: str = ""
default_note_prompt: str = ""
regenerate_on_version_change: bool = False
regenerate_on_transcript_update: bool = False
Expand Down
21 changes: 21 additions & 0 deletions backend/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,7 @@ def _user_settings_to_response(settings: UserSettings) -> UserSettingsResponse:
_id=settings.id,
user_email=settings.user_email,
note_prompt=settings.note_prompt,
preferred_model=settings.preferred_model,
default_note_prompt=get_default_note_prompt(),
regenerate_on_version_change=settings.regenerate_on_version_change,
regenerate_on_transcript_update=settings.regenerate_on_transcript_update,
Expand All @@ -1439,6 +1440,7 @@ def _empty_user_settings_response(user_email: str) -> UserSettingsResponse:
_id="",
user_email=user_email,
note_prompt="",
preferred_model="",
default_note_prompt=default,
regenerate_on_version_change=False,
regenerate_on_transcript_update=False,
Expand Down Expand Up @@ -1788,6 +1790,20 @@ async def get_segments_for_version(
# -----------------------------------------------------------------------------


@app.get(
"/models",
tags=["LLM"],
summary="Get available LLM models",
description="Returns the list of models available from the active LLM provider.",
)
async def get_available_models(
llm_provider: LLMProviderDep,
_: CurrentUserDep,
) -> dict:
"""Get available models from the active LLM provider."""
return await llm_provider.get_available_models()


def _build_full_prompt(
prompt: str,
transcript: str,
Expand Down Expand Up @@ -1853,12 +1869,17 @@ async def generate_note(
prompt, transcript, context, existing_notes, request.additional_instructions
)

model_override = request.model
if not model_override and user_settings:
model_override = user_settings.preferred_model or None

suggestion = await llm_provider.generate_note(
prompt=prompt,
transcript=transcript,
context=context,
existing_notes=existing_notes,
additional_instructions=request.additional_instructions,
model=model_override,
)

return GenerateNoteResponse(
Expand Down
149 changes: 149 additions & 0 deletions backend/tests/llm_providers/test_model_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""Tests for model selection feature: get_available_models and model param in generate_note."""

from unittest.mock import AsyncMock, MagicMock

import pytest

from dna.llm_providers.gemini_provider import GeminiProvider
from dna.llm_providers.openai_provider import OpenAIProvider


class TestOpenAIGetAvailableModels:
"""Tests for OpenAIProvider.get_available_models."""

@pytest.mark.asyncio
async def test_returns_models_from_api(self):
"""Should return filtered model list from OpenAI API."""
provider = OpenAIProvider(api_key="test-key")

mock_model_gpt = MagicMock()
mock_model_gpt.id = "gpt-4o"
mock_model_other = MagicMock()
mock_model_other.id = "dall-e-3"
mock_model_o1 = MagicMock()
mock_model_o1.id = "o3-mini"

mock_response = MagicMock()
mock_response.data = [mock_model_gpt, mock_model_other, mock_model_o1]

mock_client = AsyncMock()
mock_client.models.list = AsyncMock(return_value=mock_response)
provider._client = mock_client

result = await provider.get_available_models()

assert result["provider"] == "openai"
assert "gpt-4o" in result["models"]
assert "o3-mini" in result["models"]
assert "dall-e-3" not in result["models"]
assert result["default"] == "gpt-4o-mini"

@pytest.mark.asyncio
async def test_falls_back_on_api_error(self):
"""Should return default model when API call fails."""
provider = OpenAIProvider(api_key="test-key")

mock_client = AsyncMock()
mock_client.models.list = AsyncMock(side_effect=Exception("API error"))
provider._client = mock_client

result = await provider.get_available_models()

assert result["provider"] == "openai"
assert result["models"] == ["gpt-4o-mini"]
assert result["default"] == "gpt-4o-mini"


class TestGeminiGetAvailableModels:
"""Tests for GeminiProvider.get_available_models."""

@pytest.mark.asyncio
async def test_returns_models_from_api(self):
"""Should return model list from Gemini API."""
provider = GeminiProvider(api_key="test-key")

mock_model_1 = MagicMock()
mock_model_1.id = "gemini-2.5-flash"
mock_model_2 = MagicMock()
mock_model_2.id = "gemini-2.5-pro"

mock_response = MagicMock()
mock_response.data = [mock_model_2, mock_model_1]

mock_client = AsyncMock()
mock_client.models.list = AsyncMock(return_value=mock_response)
provider._client = mock_client

result = await provider.get_available_models()

assert result["provider"] == "gemini"
assert "gemini-2.5-flash" in result["models"]
assert "gemini-2.5-pro" in result["models"]
assert result["default"] == "gemini-2.5-flash"

@pytest.mark.asyncio
async def test_falls_back_on_api_error(self):
"""Should return default model when API call fails."""
provider = GeminiProvider(api_key="test-key")

mock_client = AsyncMock()
mock_client.models.list = AsyncMock(side_effect=Exception("API error"))
provider._client = mock_client

result = await provider.get_available_models()

assert result["provider"] == "gemini"
assert result["models"] == ["gemini-2.5-flash"]
assert result["default"] == "gemini-2.5-flash"


class TestGenerateNoteModelParam:
"""Tests for the model parameter in generate_note."""

@pytest.mark.asyncio
async def test_uses_override_model_when_provided(self):
"""generate_note should use the provided model override."""
provider = OpenAIProvider(api_key="test-key", model="gpt-4o-mini")

mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Note"

mock_client = AsyncMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
provider._client = mock_client

await provider.generate_note(
prompt="{{ transcript }}",
transcript="Test",
context="",
existing_notes="",
model="gpt-4o",
)

call_kwargs = mock_client.chat.completions.create.call_args[1]
assert call_kwargs["model"] == "gpt-4o"

@pytest.mark.asyncio
async def test_uses_default_model_when_none(self):
"""generate_note should use self.model when model param is None."""
provider = OpenAIProvider(api_key="test-key", model="gpt-4o-mini")

mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Note"

mock_client = AsyncMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
provider._client = mock_client

await provider.generate_note(
prompt="{{ transcript }}",
transcript="Test",
context="",
existing_notes="",
model=None,
)

call_kwargs = mock_client.chat.completions.create.call_args[1]
assert call_kwargs["model"] == "gpt-4o-mini"
Loading