Proposal: AdaptiveModel #3335
```python
from __future__ import annotations as _annotations

import inspect
import time
from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
from contextlib import AsyncExitStack, asynccontextmanager, suppress
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, TypeVar

from opentelemetry.trace import get_current_span

from pydantic_ai._run_context import RunContext
from pydantic_ai.models.instrumented import InstrumentedModel

from ..exceptions import FallbackExceptionGroup
from ..settings import merge_model_settings
from . import Model, ModelRequestParameters, StreamedResponse

if TYPE_CHECKING:
    from ..messages import ModelMessage, ModelResponse
    from ..settings import ModelSettings

AgentDepsT = TypeVar('AgentDepsT')


@dataclass
class AttemptResult:
```
**Collaborator:** I think just […]
| """Record of a single attempt to use a model.""" | ||
|
|
||
| model: Model | ||
| """The model that was attempted.""" | ||
|
|
||
| exception: Exception | None | ||
| """The exception raised by the model, if any.""" | ||
|
|
||
| timestamp: float | ||
**Collaborator:** Let's store a full […]
| """Unix timestamp when the attempt was made.""" | ||
|
|
||
| duration: float | ||
**Collaborator:** Maybe store a […]
| """Duration of the attempt in seconds.""" | ||
|
|
||
|
|
||
| @dataclass | ||
| class AdaptiveContext(Generic[AgentDepsT]): | ||
| """Context provided to the selector function.""" | ||
|
|
||
| run_context: RunContext[AgentDepsT] | None | ||
| """Access to agent dependencies. May be None for non-streaming requests.""" | ||
|
|
||
| models: Sequence[Model] | ||
| """Available models to choose from.""" | ||
|
|
||
| attempts: list[AttemptResult] | ||
| """History of attempts in this request.""" | ||
|
|
||
| attempt_number: int | ||
| """Current attempt number (1-indexed).""" | ||
**Collaborator:** Wouldn't this always be […]
```python
    messages: list[ModelMessage]
    """The original request messages."""

    model_settings: ModelSettings | None
    """Model settings for this request."""

    model_request_parameters: ModelRequestParameters
    """Model request parameters."""


@dataclass(init=False)
class AdaptiveModel(Model, Generic[AgentDepsT]):
    """A model that uses custom logic to select which model to try next.

    Unlike FallbackModel, which tries models sequentially, AdaptiveModel gives
```
**Collaborator:** Can you try refactoring […]
```python
    full control over model selection based on rich context including attempts,
    exceptions, and agent dependencies.

    The selector function is called before each attempt and can:

    - Return a Model to try next (can be the same model for retry)
    - Return None to stop trying
    - Use async/await for delays (exponential backoff, etc.)
    - Access agent dependencies via ctx.run_context.deps
    - Inspect previous attempts via ctx.attempts
    """

    models: Sequence[Model]
```
**Reviewer:** Why does the AdaptiveModel class have a `models` member? It seems to imply that the […]

**Contributor (author):** Yeah, I've been debating this internally. It was mainly here because the design evolved organically from […] The biggest issue is having to know the models at initialization. Maybe there are instances where you actually want to init a model on the fly? (Although I'm struggling to actually come up with one in practice.)

The main things we lose by removing it: […]

If it's not critical, I'd lean more toward just removing it for overall consistency. Basically this becomes a […] I do think there is a fine line between being "general purpose" and having "no purpose" 😅

**Reviewer:** Yeah, when I first started exploring our use case, I was initially hoping that agent might accept either a model factory or a model instance. So to me the idea of this being a model factory sounds just fine. A model factory in model's clothing.
```python
    _selector: (
        Callable[[AdaptiveContext[AgentDepsT]], Model | None]
        | Callable[[AdaptiveContext[AgentDepsT]], Awaitable[Model | None]]
    )
    _max_attempts: int | None

    def __init__(
        self,
        models: Sequence[Model],
        selector: Callable[[AdaptiveContext[AgentDepsT]], Model | None]
        | Callable[[AdaptiveContext[AgentDepsT]], Awaitable[Model | None]],
        *,
        max_attempts: int | None = None,
    ):
        """Initialize an adaptive model instance.

        Args:
            models: Pool of models to choose from.
            selector: Sync or async function that selects the next model to try.
                Called before each attempt with context including previous attempts.
                Return a Model to try, or None to stop.
            max_attempts: Maximum total attempts across all models (None = unlimited).
        """
        super().__init__()
        if not models:
            raise ValueError('At least one model must be provided')

        self.models = list(models)
        self._selector = selector
        self._max_attempts = max_attempts

    @property
    def model_name(self) -> str:
        """The model name."""
        return f'adaptive:{",".join(model.model_name for model in self.models)}'

    @property
    def system(self) -> str:
        return f'adaptive:{",".join(model.system for model in self.models)}'

    @property
    def base_url(self) -> str | None:
        return self.models[0].base_url if self.models else None

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
```
**Contributor (author):** @DouweM is there a reason that […]

**Collaborator:** Hmm, I added it to […] Of course if we only support […]

**Contributor (author):** Does Python not support optional params or signature overloads in a way that's backwards compatible? That's painful 😬

I'm trying to wrap my brain around the different use-cases where context matters and what we give up by not having the […] Honestly, there are probably more use cases for global […] Looking at the […]
```python
    ) -> ModelResponse:
        """Try models based on selector logic until one succeeds or the selector returns None."""
        attempts: list[AttemptResult] = []
        attempt_number = 0

        while True:
            attempt_number += 1

            # Check max attempts
            if self._max_attempts is not None and attempt_number > self._max_attempts:
                exceptions = [a.exception for a in attempts if a.exception is not None]
                if exceptions:
                    raise FallbackExceptionGroup(
                        f'AdaptiveModel exceeded max_attempts of {self._max_attempts}', exceptions
                    )
                else:
                    raise FallbackExceptionGroup(
                        f'AdaptiveModel exceeded max_attempts of {self._max_attempts}',
                        [RuntimeError('No models were attempted')],
                    )

            # Create context for selector
            context = AdaptiveContext(
                run_context=None,  # run_context not available in non-streaming request
                models=self.models,
                attempts=attempts,
                attempt_number=attempt_number,
                messages=messages,
                model_settings=model_settings,
                model_request_parameters=model_request_parameters,
            )

            # Call selector to get next model
            model = await self._call_selector(context)

            if model is None:
```
**Collaborator:** Why do we allow the selector to return […]
```python
                # Selector says stop trying
                exceptions = [a.exception for a in attempts if a.exception is not None]
                if exceptions:
                    raise FallbackExceptionGroup('AdaptiveModel selector returned None', exceptions)
                else:
                    raise FallbackExceptionGroup(
                        'AdaptiveModel selector returned None', [RuntimeError('No models were attempted')]
                    )

            # Try the selected model
            start_time = time.time()
            customized_params = model.customize_request_parameters(model_request_parameters)
```
**Collaborator:** Note that we need to use […]
```python
            merged_settings = merge_model_settings(model.settings, model_settings)

            try:
                response = await model.request(messages, merged_settings, customized_params)
                # Success! Set span attributes and return
                self._set_span_attributes(model)
                return response
            except Exception as exc:
                # Record the attempt
                duration = time.time() - start_time
                attempts.append(
                    AttemptResult(
                        model=model,
                        exception=exc,
                        timestamp=start_time,
                        duration=duration,
                    )
                )
                # Continue loop to try again

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        run_context: RunContext[AgentDepsT] | None = None,
    ) -> AsyncIterator[StreamedResponse]:
        """Try models based on selector logic until one succeeds or the selector returns None."""
        attempts: list[AttemptResult] = []
        attempt_number = 0

        while True:
            attempt_number += 1

            # Check max attempts
            if self._max_attempts is not None and attempt_number > self._max_attempts:
```
**Collaborator:** We should reduce the duplication between […]
```python
                exceptions = [a.exception for a in attempts if a.exception is not None]
                if exceptions:
                    raise FallbackExceptionGroup(
                        f'AdaptiveModel exceeded max_attempts of {self._max_attempts}', exceptions
                    )
                else:
                    raise FallbackExceptionGroup(
                        f'AdaptiveModel exceeded max_attempts of {self._max_attempts}',
                        [RuntimeError('No models were attempted')],
                    )

            # Create context for selector
            context = AdaptiveContext(
                run_context=run_context,
                models=self.models,
                attempts=attempts,
                attempt_number=attempt_number,
                messages=messages,
                model_settings=model_settings,
                model_request_parameters=model_request_parameters,
            )

            # Call selector to get next model
            model = await self._call_selector(context)

            if model is None:
                # Selector says stop trying
                exceptions = [a.exception for a in attempts if a.exception is not None]
                if exceptions:
                    raise FallbackExceptionGroup('AdaptiveModel selector returned None', exceptions)
                else:
                    raise FallbackExceptionGroup(
                        'AdaptiveModel selector returned None', [RuntimeError('No models were attempted')]
                    )

            # Try the selected model
            start_time = time.time()
            customized_params = model.customize_request_parameters(model_request_parameters)
            merged_settings = merge_model_settings(model.settings, model_settings)

            async with AsyncExitStack() as stack:
                try:
                    response = await stack.enter_async_context(
                        model.request_stream(messages, merged_settings, customized_params, run_context)
                    )
                except Exception as exc:
                    # Record the attempt and continue
                    duration = time.time() - start_time
                    attempts.append(
                        AttemptResult(
                            model=model,
                            exception=exc,
                            timestamp=start_time,
                            duration=duration,
                        )
                    )
                    continue

                # Success! Set span attributes and yield
                self._set_span_attributes(model)
                yield response
                return

    async def _call_selector(self, context: AdaptiveContext[AgentDepsT]) -> Model | None:
        """Call the selector function, handling both sync and async."""
        if inspect.iscoroutinefunction(self._selector):
            return await self._selector(context)
        else:
            return self._selector(context)  # type: ignore

    def _set_span_attributes(self, model: Model):
        """Set OpenTelemetry span attributes for the successful model."""
        with suppress(Exception):
            span = get_current_span()
            if span.is_recording():
                attributes = getattr(span, 'attributes', {})
                if attributes.get('gen_ai.request.model') == self.model_name:  # pragma: no branch
                    span.set_attributes(InstrumentedModel.model_attributes(model))
```
**Reviewer:** Note that the issues linked in #3303 will also apply to this model, so I will fix it in `FallbackModel` before we merge this, and then we should make sure this also works properly with output modes.
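For readers skimming the diff, the selector contract (selector called before every attempt; return a model to try, `None` to stop) can be condensed into a self-contained sketch. The classes and function names below are illustrative stand-ins, not pydantic_ai's real `Model`/`AdaptiveContext` types, and `ExceptionGroup` requires Python 3.11+:

```python
# Self-contained sketch of the proposed selection loop, using stand-in
# classes instead of pydantic_ai's Model/AdaptiveContext (all names
# below are illustrative).
import asyncio
from dataclasses import dataclass


@dataclass
class FakeModel:
    """Stand-in model: fails a fixed number of times, then succeeds."""

    name: str
    failures_left: int = 0

    async def request(self, messages: list[str]) -> str:
        if self.failures_left > 0:
            self.failures_left -= 1
            raise RuntimeError(f'{self.name} unavailable')
        return f'{self.name}: ok'


async def adaptive_request(models, selector, messages, max_attempts=10):
    """Mirror of AdaptiveModel.request: ask the selector before every attempt."""
    attempts: list[Exception] = []
    for attempt_number in range(1, max_attempts + 1):
        model = await selector(models, attempts, attempt_number)
        if model is None:  # selector says stop trying
            raise ExceptionGroup('selector returned None', attempts or [RuntimeError('no attempts')])
        try:
            return await model.request(messages)
        except Exception as exc:
            attempts.append(exc)  # record and let the selector decide what is next
    raise ExceptionGroup('exceeded max_attempts', attempts)


async def backoff_then_fallback(models, attempts, attempt_number):
    """Retry the primary twice with a small delay, then fall back."""
    if attempt_number > 1:
        await asyncio.sleep(0.001 * 2**attempt_number)  # toy exponential backoff
    return models[0] if attempt_number <= 2 else models[1]


primary = FakeModel('primary', failures_left=2)
backup = FakeModel('backup')
result = asyncio.run(adaptive_request([primary, backup], backoff_then_fallback, ['hi']))
assert result == 'backup: ok'
```

Because the selector sees the attempt history and can await, retry-same-model, backoff, and dependency-driven routing all fall out of this one hook, which is the core difference from `FallbackModel`'s fixed sequential order.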