diff --git a/paperqa/llms.py b/paperqa/llms.py
index 444aa46f..34122f12 100644
--- a/paperqa/llms.py
+++ b/paperqa/llms.py
@@ -27,7 +27,14 @@
     get_model_cost_map,
     token_counter,
 )
-from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, model_validator
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    TypeAdapter,
+    field_validator,
+    model_validator,
+)
 
 from paperqa.prompts import default_system_prompt
 from paperqa.rate_limiter import GLOBAL_LIMITER
@@ -79,7 +86,7 @@ class EmbeddingModes(StrEnum):
 
 class EmbeddingModel(ABC, BaseModel):
     name: str
-    config: dict = Field(
+    config: dict[str, Any] = Field(
         default_factory=dict,
         description="Optional `rate_limit` key, value must be a RateLimitItem or RateLimitItem string for parsing",
     )
@@ -102,7 +109,21 @@ async def embed_documents(self, texts: list[str]) -> list[list[float]]:
 
 
 class LiteLLMEmbeddingModel(EmbeddingModel):
+    name: str = Field(default="text-embedding-3-small")
+    config: dict[str, Any] = Field(
+        default_factory=dict,  # See below field_validator for injection of kwargs
+        description="Optional `rate_limit` key, value must be a RateLimitItem or RateLimitItem string for parsing",
+    )
+
+    @field_validator("config")
+    @classmethod
+    def set_up_default_config(cls, value: dict[str, Any]) -> dict[str, Any]:
+        if "kwargs" not in value:
+            value["kwargs"] = get_litellm_retrying_config(
+                timeout=120,  # 2-min timeout seemed reasonable
+            )
+        return value
 
     def _truncate_if_large(self, texts: list[str]) -> list[str]:
         """Truncate texts if they are too large by using litellm cost map."""
@@ -489,6 +510,11 @@ async def rate_limited_generator() -> AsyncGenerator[LLMModelOrChild, None]:
     )
 
 
+def get_litellm_retrying_config(timeout: float = 60.0) -> dict[str, Any]:
+    """Get retrying configuration for litellm.acompletion and litellm.aembedding."""
+    return {"num_retries": 3, "timeout": timeout}
+
+
 class LiteLLMModel(LLMModel):
     """A wrapper around the litellm library.
@@ -527,10 +553,8 @@ def maybe_set_config_attribute(cls, data: dict[str, Any]) -> dict[str, Any]:
         } | data.get("config", {})
 
         if "router_kwargs" not in data.get("config", {}):
-            data["config"]["router_kwargs"] = {
-                "num_retries": 3,
-                "retry_after": 5,
-                "timeout": 60,
+            data["config"]["router_kwargs"] = get_litellm_retrying_config() | {
+                "retry_after": 5
             }
 
         # we only support one "model name" for now, here we validate
@@ -769,7 +793,6 @@ async def similarity_search(
 
 
 def embedding_model_factory(embedding: str, **kwargs) -> EmbeddingModel:
-
     if embedding.startswith("hybrid"):
         embedding_model_name = "-".join(embedding.split("-")[1:])
         return HybridEmbeddingModel(
@@ -780,5 +803,6 @@
        )
    if embedding == "sparse":
        return SparseEmbeddingModel(**kwargs)
-
-    return LiteLLMEmbeddingModel(name=embedding, config=kwargs)
+    if kwargs:  # Only override the default config if there are actually kwargs
+        kwargs = {"config": kwargs}
+    return LiteLLMEmbeddingModel(name=embedding, **kwargs)
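
Note on the `router_kwargs` change: Python's dict `|` operator merges left to right with right-hand keys taking precedence, so `get_litellm_retrying_config() | {"retry_after": 5}` rebuilds the same mapping as the old literal. A minimal sketch of the merge semantics; the helper is copied verbatim from this diff, the rest is illustrative:

    from typing import Any


    def get_litellm_retrying_config(timeout: float = 60.0) -> dict[str, Any]:
        """Get retrying configuration for litellm.acompletion and litellm.aembedding."""
        return {"num_retries": 3, "timeout": timeout}


    # Right-hand keys extend (and on collision would override) the shared defaults.
    router_kwargs = get_litellm_retrying_config() | {"retry_after": 5}
    assert router_kwargs == {"num_retries": 3, "timeout": 60.0, "retry_after": 5}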
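
Note on the new `LiteLLMEmbeddingModel` defaults: `embedding_model_factory` now nests any caller kwargs under `config`, and the `field_validator` injects retry/timeout kwargs into a passed config that lacks them. Below is a runnable sketch of just that injection, using a hypothetical `ConfigDemo` stand-in rather than the real model (in pydantic v2, field validators run on explicitly supplied values; field defaults are only validated when `validate_default=True` is set):

    from typing import Any

    from pydantic import BaseModel, Field, field_validator


    def get_litellm_retrying_config(timeout: float = 60.0) -> dict[str, Any]:
        return {"num_retries": 3, "timeout": timeout}


    class ConfigDemo(BaseModel):
        # Hypothetical stand-in for LiteLLMEmbeddingModel's config field.
        config: dict[str, Any] = Field(default_factory=dict)

        @field_validator("config")
        @classmethod
        def set_up_default_config(cls, value: dict[str, Any]) -> dict[str, Any]:
            # Same injection as the diff: add retry/timeout kwargs unless
            # the caller already supplied their own.
            if "kwargs" not in value:
                value["kwargs"] = get_litellm_retrying_config(timeout=120)
            return value


    print(ConfigDemo(config={}).config)
    # {'kwargs': {'num_retries': 3, 'timeout': 120}}
    print(ConfigDemo(config={"kwargs": {"timeout": 5}}).config)
    # {'kwargs': {'timeout': 5}}  (caller-supplied kwargs are left untouched)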