297 changes: 297 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/adaptive.py
Collaborator: Note that the issues linked in #3303 will also apply to this model, so I will fix it in FallbackModel before we merge this, and then we should make sure this also works properly with output modes.

@@ -0,0 +1,297 @@
from __future__ import annotations as _annotations

import inspect
import time
from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
from contextlib import AsyncExitStack, asynccontextmanager, suppress
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, TypeVar

from opentelemetry.trace import get_current_span

from pydantic_ai._run_context import RunContext
from pydantic_ai.models.instrumented import InstrumentedModel

from ..exceptions import FallbackExceptionGroup
from ..settings import merge_model_settings
from . import Model, ModelRequestParameters, StreamedResponse

if TYPE_CHECKING:
    from ..messages import ModelMessage, ModelResponse
    from ..settings import ModelSettings

AgentDepsT = TypeVar('AgentDepsT')


@dataclass
class AttemptResult:
Collaborator: I think just Attempt will be clear enough.

    """Record of a single attempt to use a model."""

    model: Model
    """The model that was attempted."""

    exception: Exception | None
    """The exception raised by the model, if any."""

    timestamp: float
Collaborator: Let's store a full datetime.

    """Unix timestamp when the attempt was made."""

    duration: float
Collaborator: Maybe store a timedelta? I don't feel strongly about this one.

    """Duration of the attempt in seconds."""


@dataclass
class AdaptiveContext(Generic[AgentDepsT]):
"""Context provided to the selector function."""

run_context: RunContext[AgentDepsT] | None
"""Access to agent dependencies. May be None for non-streaming requests."""

models: Sequence[Model]
"""Available models to choose from."""

attempts: list[AttemptResult]
"""History of attempts in this request."""

attempt_number: int
"""Current attempt number (1-indexed)."""
Collaborator: Wouldn't this always be len(attempts) + 1? In that case it could be a property, or be omitted entirely.
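A minimal sketch of the suggestion: since attempt_number is always len(attempts) + 1, it could be a derived property rather than a stored field. `Ctx` here is a hypothetical stand-in for AdaptiveContext, not the real class.

```python
from dataclasses import dataclass, field


# Hypothetical stand-in for AdaptiveContext, illustrating attempt_number
# as a property derived from the attempts list instead of a stored field.
@dataclass
class Ctx:
    attempts: list = field(default_factory=list)

    @property
    def attempt_number(self) -> int:
        # 1-indexed: the number of the attempt about to be made
        return len(self.attempts) + 1


ctx = Ctx()
print(ctx.attempt_number)  # 1
ctx.attempts.append('first failure')
print(ctx.attempt_number)  # 2
```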


    messages: list[ModelMessage]
    """The original request messages."""

    model_settings: ModelSettings | None
    """Model settings for this request."""

    model_request_parameters: ModelRequestParameters
    """Model request parameters."""


@dataclass(init=False)
class AdaptiveModel(Model, Generic[AgentDepsT]):
"""A model that uses custom logic to select which model to try next.
Unlike FallbackModel which tries models sequentially, AdaptiveModel gives
Collaborator: Can you try refactoring FallbackModel as a subclass of AdaptiveModel, to prove this is flexible enough (and simplify and deduplicate FallbackModel)?
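A rough sketch of what that refactor's core could look like, not the actual implementation: FallbackModel's sequential behavior expressed as a selector. `models` and `attempts` stand in for ctx.models and ctx.attempts, and the model names are placeholders.

```python
# Sequential-fallback logic as a selector function: try each model once,
# in order, and stop once the pool is exhausted.
def fallback_selector(models, attempts):
    next_index = len(attempts)  # one attempt per model, in order
    if next_index < len(models):
        return models[next_index]
    return None  # pool exhausted: stop trying


models = ['primary', 'secondary', 'tertiary']
print(fallback_selector(models, []))                  # primary
print(fallback_selector(models, ['e1']))              # secondary
print(fallback_selector(models, ['e1', 'e2', 'e3']))  # None
```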

    full control over model selection based on rich context including attempts,
    exceptions, and agent dependencies.

    The selector function is called before each attempt and can:

    - Return a Model to try next (can be the same model for retry)
    - Return None to stop trying
    - Use async/await for delays (exponential backoff, etc.)
    - Access agent dependencies via ctx.run_context.deps
    - Inspect previous attempts via ctx.attempts
    """

    models: Sequence[Model]
Reviewer: Why does the AdaptiveModel class have a models member? It seems to imply that the _selector function is choosing one from this sequence. What advantage does that give us? Could we make it optional? Why not just let the _selector function provide a model however it wants?

Contributor Author: Ya, I've been debating this internally. It was mainly here because the design evolved organically from FallbackModel.

The biggest issue is having to know the models at initialization. Maybe there are instances where you actually want to init a model on the fly? (Although I'm struggling to come up with one in practice.)

> What advantage does that give us?

The main things we lose by removing it:

  • it's potentially less "discoverable" (being able to introspect/reflect the underlying models). Not sure where this would actually be needed in practice.
  • some of the model properties would be less specific (i.e. model_name). This probably only really impacts observability.

> Could we make it optional?

If it's not critical, I'd lean more toward just removing it for overall consistency.

Basically this becomes a ModelProvider / Factory with some internal state tracking (arguably still too opinionated) and lifecycle hooks (I've since added on_failure/success hooks).

I do think there is a fine line between being "general purpose" and having "no purpose" 😅

Reviewer: Yea, when I first started exploring our use case, I was initially hoping that agent might accept either a model factory or a model instance. So to me the idea of this being a model factory sounds just fine. A model factory in model's clothing.

    _selector: (
        Callable[[AdaptiveContext[AgentDepsT]], Model | None]
        | Callable[[AdaptiveContext[AgentDepsT]], Awaitable[Model | None]]
    )
    _max_attempts: int | None

    def __init__(
        self,
        models: Sequence[Model],
        selector: Callable[[AdaptiveContext[AgentDepsT]], Model | None]
        | Callable[[AdaptiveContext[AgentDepsT]], Awaitable[Model | None]],
        *,
        max_attempts: int | None = None,
    ):
        """Initialize an adaptive model instance.

        Args:
            models: Pool of models to choose from.
            selector: Sync or async function that selects the next model to try.
                Called before each attempt with context including previous attempts.
                Return a Model to try, or None to stop.
            max_attempts: Maximum total attempts across all models (None = unlimited).
        """
        super().__init__()
        if not models:
            raise ValueError('At least one model must be provided')

        self.models = list(models)
        self._selector = selector
        self._max_attempts = max_attempts

    @property
    def model_name(self) -> str:
        """The model name."""
        return f'adaptive:{",".join(model.model_name for model in self.models)}'

    @property
    def system(self) -> str:
        return f'adaptive:{",".join(model.system for model in self.models)}'

    @property
    def base_url(self) -> str | None:
        return self.models[0].base_url if self.models else None

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
Contributor Author: @DouweM, is there a reason that run_context is available in request_stream but not request?

Collaborator: Hmm, I added it to request_stream pre-v1 because we needed it for the Temporal integration, and I really should've added it to request as well at the time. We can't do so anymore, as it'd be a breaking change for users that have custom Model implementations 😬

Of course, if we only support deps for request_stream, this AdaptiveModel will feel half-baked. Could we make it not depend on RunContext, and perhaps require the deps (or another context object) to be passed into the model explicitly?

Contributor Author (@ChuckJonas, Nov 6, 2025):

> we can't do so anymore now as it'd be a breaking change for users that have custom Model implementations

Does Python not support optional params or signature overloads in a way that's backwards compatible? That's painful 😬

> and perhaps require the deps (or another context object) to be passed into the model explicitly?

I'm trying to wrap my brain around the different use cases where context matters and what we give up by not having the RunContext. I guess if the Model and Agent need to share dependencies, then those objects can just be set up on each request and passed into both.

Honestly, there are probably more use cases for global AdaptiveModel state vs. run-isolated state.

Looking at the RunContext, the only other property we're potentially giving up is the usage, which could maybe be useful for upgrading/downgrading a model based on input/output tokens. But maybe there's a simple workaround.
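A hypothetical sketch of that upgrade/downgrade idea: pick a cheaper model once cumulative input tokens cross a budget. The token count would have to come from usage tracked outside the model (since RunContext is unavailable in the non-streaming path), and the model names, budget, and function are all illustrative.

```python
# Illustrative token-budget selector: stay on the large model until the
# cumulative input-token count exceeds the budget, then downgrade.
def budget_selector(models, total_input_tokens, budget=100_000):
    big, small = models
    return small if total_input_tokens > budget else big


print(budget_selector(['gpt-large', 'gpt-small'], 50_000))   # gpt-large
print(budget_selector(['gpt-large', 'gpt-small'], 150_000))  # gpt-small
```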

    ) -> ModelResponse:
        """Try models based on selector logic until one succeeds or selector returns None."""
        attempts: list[AttemptResult] = []
        attempt_number = 0

        while True:
            attempt_number += 1

            # Check max attempts
            if self._max_attempts is not None and attempt_number > self._max_attempts:
                exceptions = [a.exception for a in attempts if a.exception is not None]
                if exceptions:
                    raise FallbackExceptionGroup(
                        f'AdaptiveModel exceeded max_attempts of {self._max_attempts}', exceptions
                    )
                else:
                    raise FallbackExceptionGroup(
                        f'AdaptiveModel exceeded max_attempts of {self._max_attempts}',
                        [RuntimeError('No models were attempted')],
                    )

            # Create context for selector
            context = AdaptiveContext(
                run_context=None,  # run_context not available in non-streaming request
                models=self.models,
                attempts=attempts,
                attempt_number=attempt_number,
                messages=messages,
                model_settings=model_settings,
                model_request_parameters=model_request_parameters,
            )

            # Call selector to get next model
            model = await self._call_selector(context)

            if model is None:
Collaborator: Why do we allow the selector to return None if that will always raise an error? Couldn't the selector itself be required to return a Model, or itself raise an error if it can't?
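A sketch of that alternative, with illustrative names: require the selector to raise when it cannot pick a model, so "no model" becomes an explicit, typed error instead of a None sentinel.

```python
# Hypothetical exception type a strict selector could raise instead of
# returning None when it has no model left to offer.
class NoModelAvailable(Exception):
    pass


def strict_selector(models, attempts):
    if len(attempts) >= len(models):
        raise NoModelAvailable(f'all {len(models)} models failed')
    return models[len(attempts)]


print(strict_selector(['a', 'b'], []))  # a
```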

                # Selector says stop trying
                exceptions = [a.exception for a in attempts if a.exception is not None]
                if exceptions:
                    raise FallbackExceptionGroup('AdaptiveModel selector returned None', exceptions)
                else:
                    raise FallbackExceptionGroup(
                        'AdaptiveModel selector returned None', [RuntimeError('No models were attempted')]
                    )

            # Try the selected model
            start_time = time.time()
            customized_params = model.customize_request_parameters(model_request_parameters)
Collaborator: Note that we need to use prepare_request as we do in FallbackModel.

            merged_settings = merge_model_settings(model.settings, model_settings)

            try:
                response = await model.request(messages, merged_settings, customized_params)
                # Success! Set span attributes and return
                self._set_span_attributes(model)
                return response
            except Exception as exc:
                # Record the attempt
                duration = time.time() - start_time
                attempts.append(
                    AttemptResult(
                        model=model,
                        exception=exc,
                        timestamp=start_time,
                        duration=duration,
                    )
                )
                # Continue loop to try again
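The request loop above can be simulated without any pydantic_ai types, with models faked as zero-argument callables: the selector is consulted before every attempt, failures are recorded, and max_attempts caps the total number of tries. Everything here is a stand-in, not the real API.

```python
# Minimal stand-in for the selector/retry loop in AdaptiveModel.request.
def run(selector, max_attempts=None):
    attempts = []
    while True:
        if max_attempts is not None and len(attempts) + 1 > max_attempts:
            raise RuntimeError(f'exceeded max_attempts of {max_attempts}')
        model = selector(attempts)  # selector sees the failure history
        if model is None:
            raise RuntimeError('selector returned None')
        try:
            return model()  # "request" the fake model
        except Exception as exc:
            attempts.append(exc)  # record the failure and loop again


# A fake model that fails twice, then succeeds.
outcomes = iter([RuntimeError('boom'), RuntimeError('boom'), 'ok'])


def flaky_model():
    item = next(outcomes)
    if isinstance(item, Exception):
        raise item
    return item


print(run(lambda attempts: flaky_model, max_attempts=5))  # ok
```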

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        run_context: RunContext[AgentDepsT] | None = None,
    ) -> AsyncIterator[StreamedResponse]:
        """Try models based on selector logic until one succeeds or selector returns None."""
        attempts: list[AttemptResult] = []
        attempt_number = 0

        while True:
            attempt_number += 1

            # Check max attempts
            if self._max_attempts is not None and attempt_number > self._max_attempts:
Collaborator: We should reduce the duplication between request and request_stream.

                exceptions = [a.exception for a in attempts if a.exception is not None]
                if exceptions:
                    raise FallbackExceptionGroup(
                        f'AdaptiveModel exceeded max_attempts of {self._max_attempts}', exceptions
                    )
                else:
                    raise FallbackExceptionGroup(
                        f'AdaptiveModel exceeded max_attempts of {self._max_attempts}',
                        [RuntimeError('No models were attempted')],
                    )

            # Create context for selector
            context = AdaptiveContext(
                run_context=run_context,
                models=self.models,
                attempts=attempts,
                attempt_number=attempt_number,
                messages=messages,
                model_settings=model_settings,
                model_request_parameters=model_request_parameters,
            )

            # Call selector to get next model
            model = await self._call_selector(context)

            if model is None:
                # Selector says stop trying
                exceptions = [a.exception for a in attempts if a.exception is not None]
                if exceptions:
                    raise FallbackExceptionGroup('AdaptiveModel selector returned None', exceptions)
                else:
                    raise FallbackExceptionGroup(
                        'AdaptiveModel selector returned None', [RuntimeError('No models were attempted')]
                    )

            # Try the selected model
            start_time = time.time()
            customized_params = model.customize_request_parameters(model_request_parameters)
            merged_settings = merge_model_settings(model.settings, model_settings)

            async with AsyncExitStack() as stack:
                try:
                    response = await stack.enter_async_context(
                        model.request_stream(messages, merged_settings, customized_params, run_context)
                    )
                except Exception as exc:
                    # Record the attempt and continue
                    duration = time.time() - start_time
                    attempts.append(
                        AttemptResult(
                            model=model,
                            exception=exc,
                            timestamp=start_time,
                            duration=duration,
                        )
                    )
                    continue

                # Success! Set span attributes and yield
                self._set_span_attributes(model)
                yield response
                return

    async def _call_selector(self, context: AdaptiveContext[AgentDepsT]) -> Model | None:
        """Call the selector function, handling both sync and async."""
        if inspect.iscoroutinefunction(self._selector):
            return await self._selector(context)
        else:
            return self._selector(context)  # type: ignore

    def _set_span_attributes(self, model: Model):
        """Set OpenTelemetry span attributes for the successful model."""
        with suppress(Exception):
            span = get_current_span()
            if span.is_recording():
                attributes = getattr(span, 'attributes', {})
                if attributes.get('gen_ai.request.model') == self.model_name:  # pragma: no branch
                    span.set_attributes(InstrumentedModel.model_attributes(model))
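The sync/async dispatch in _call_selector can be illustrated in isolation: coroutine functions are detected with inspect.iscoroutinefunction and awaited, while plain functions are called directly. The names below are illustrative stand-ins, not the real class.

```python
import asyncio
import inspect


# Standalone version of the dispatch logic: await coroutine-function
# selectors, call plain-function selectors directly.
async def call_selector(selector, ctx):
    if inspect.iscoroutinefunction(selector):
        return await selector(ctx)
    return selector(ctx)


def sync_sel(ctx):
    return 'sync-model'


async def async_sel(ctx):
    return 'async-model'


print(asyncio.run(call_selector(sync_sel, None)))   # sync-model
print(asyncio.run(call_selector(async_sel, None)))  # async-model
```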