Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding retrying of aembedding if it fails #535

Merged
merged 3 commits
Oct 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions paperqa/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@
get_model_cost_map,
token_counter,
)
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, model_validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
TypeAdapter,
field_validator,
model_validator,
)

from paperqa.prompts import default_system_prompt
from paperqa.rate_limiter import GLOBAL_LIMITER
Expand Down Expand Up @@ -79,7 +86,7 @@ class EmbeddingModes(StrEnum):

class EmbeddingModel(ABC, BaseModel):
name: str
config: dict = Field(
config: dict[str, Any] = Field(
default_factory=dict,
description="Optional `rate_limit` key, value must be a RateLimitItem or RateLimitItem string for parsing",
)
Expand All @@ -102,7 +109,21 @@ async def embed_documents(self, texts: list[str]) -> list[list[float]]:


class LiteLLMEmbeddingModel(EmbeddingModel):

name: str = Field(default="text-embedding-3-small")
config: dict[str, Any] = Field(
default_factory=dict, # See below field_validator for injection of kwargs
description="Optional `rate_limit` key, value must be a RateLimitItem or RateLimitItem string for parsing",
)

@field_validator("config")
@classmethod
def set_up_default_config(cls, value: dict[str, Any]) -> dict[str, Any]:
    """Inject default litellm retry/timeout kwargs when the caller supplied none.

    Args:
        value: The validated `config` dict; may already contain a "kwargs" entry.

    Returns:
        The same dict, with a "kwargs" entry guaranteed to be present.
    """
    # setdefault leaves a caller-supplied "kwargs" entry untouched;
    # get_litellm_retrying_config is pure, so eagerly building the
    # default is side-effect free.
    value.setdefault(
        "kwargs",
        get_litellm_retrying_config(
            timeout=120,  # 2-min timeout seemed reasonable
        ),
    )
    return value

def _truncate_if_large(self, texts: list[str]) -> list[str]:
"""Truncate texts if they are too large by using litellm cost map."""
Expand Down Expand Up @@ -489,6 +510,11 @@ async def rate_limited_generator() -> AsyncGenerator[LLMModelOrChild, None]:
)


def get_litellm_retrying_config(timeout: float = 60.0) -> dict[str, Any]:
    """Build the shared retry/timeout kwargs for litellm.acompletion and litellm.aembedding.

    Args:
        timeout: Per-request timeout in seconds (defaults to 60).

    Returns:
        Keyword arguments enabling three retries with the given timeout.
    """
    retry_settings: dict[str, Any] = {"num_retries": 3}
    retry_settings["timeout"] = timeout
    return retry_settings


class LiteLLMModel(LLMModel):
"""A wrapper around the litellm library.

Expand Down Expand Up @@ -527,10 +553,8 @@ def maybe_set_config_attribute(cls, data: dict[str, Any]) -> dict[str, Any]:
} | data.get("config", {})

if "router_kwargs" not in data.get("config", {}):
data["config"]["router_kwargs"] = {
"num_retries": 3,
"retry_after": 5,
"timeout": 60,
data["config"]["router_kwargs"] = get_litellm_retrying_config() | {
"retry_after": 5
}

# we only support one "model name" for now, here we validate
Expand Down Expand Up @@ -769,7 +793,6 @@ async def similarity_search(


def embedding_model_factory(embedding: str, **kwargs) -> EmbeddingModel:

if embedding.startswith("hybrid"):
embedding_model_name = "-".join(embedding.split("-")[1:])
return HybridEmbeddingModel(
Expand All @@ -780,5 +803,6 @@ def embedding_model_factory(embedding: str, **kwargs) -> EmbeddingModel:
)
if embedding == "sparse":
return SparseEmbeddingModel(**kwargs)

return LiteLLMEmbeddingModel(name=embedding, config=kwargs)
if kwargs: # Only override the default config if there are actually kwargs
kwargs = {"config": kwargs}
return LiteLLMEmbeddingModel(name=embedding, **kwargs)