Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update Cohere embedding model to use langchain_cohere, added support to dynamically load latest embedding models, improved error handling #6034

Merged
merged 6 commits into from
Feb 6, 2025
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ dependencies = [
"uv>=0.5.7",
"ag2>=0.1.0",
"scrapegraph-py>=1.10.2",
"pydantic-ai>=0.0.19"
"pydantic-ai>=0.0.19",
]

[tool.uv.sources]
Expand Down
58 changes: 47 additions & 11 deletions src/backend/base/langflow/components/embeddings/cohere.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from langchain_community.embeddings.cohere import CohereEmbeddings
from typing import Any

import cohere
from langchain_cohere import CohereEmbeddings

from langflow.base.models.model import LCModelComponent
from langflow.field_typing import Embeddings
from langflow.io import DropdownInput, FloatInput, IntInput, MessageTextInput, Output, SecretStrInput

HTTP_STATUS_OK = 200


class CohereEmbeddingsComponent(LCModelComponent):
display_name = "Cohere Embeddings"
Expand All @@ -12,9 +17,9 @@ class CohereEmbeddingsComponent(LCModelComponent):
name = "CohereEmbeddings"

inputs = [
SecretStrInput(name="cohere_api_key", display_name="Cohere API Key", required=True),
SecretStrInput(name="api_key", display_name="Cohere API Key", required=True, real_time_refresh=True),
DropdownInput(
name="model",
name="model_name",
display_name="Model",
advanced=False,
options=[
Expand All @@ -24,6 +29,8 @@ class CohereEmbeddingsComponent(LCModelComponent):
"embed-multilingual-light-v2.0",
],
value="embed-english-v2.0",
refresh_button=True,
combobox=True,
),
MessageTextInput(name="truncate", display_name="Truncate", advanced=True),
IntInput(name="max_retries", display_name="Max Retries", value=3, advanced=True),
Expand All @@ -36,11 +43,40 @@ class CohereEmbeddingsComponent(LCModelComponent):
]

def build_embeddings(self) -> Embeddings:
return CohereEmbeddings(
cohere_api_key=self.cohere_api_key,
model=self.model,
truncate=self.truncate,
max_retries=self.max_retries,
user_agent=self.user_agent,
request_timeout=self.request_timeout or None,
)
data = None
try:
data = CohereEmbeddings(
cohere_api_key=self.api_key,
model=self.model_name,
truncate=self.truncate,
max_retries=self.max_retries,
user_agent=self.user_agent,
request_timeout=self.request_timeout or None,
)
except Exception as e:
msg = (
"Unable to create Cohere Embeddings. ",
"Please verify the API key and model parameters, and try again.",
)
raise ValueError(msg) from e
# added status if not the return data would be serialised to create the status
self.status = "Success Cohere Embeddings Model created"
return data
edwinjosechittilappilly marked this conversation as resolved.
Show resolved Hide resolved

def get_model(self):
try:
co = cohere.ClientV2(self.api_key)
response = co.models.list(endpoint="embed")
models = response.models
return [model.name for model in models]
except Exception as e:
msg = f"Failed to fetch Cohere models. Error: {e}"
raise ValueError(msg) from e

async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
if field_name in {"model_name", "api_key"}:
if build_config.get("api_key", {}).get("value", None):
build_config["model_name"]["options"] = self.get_model()
else:
build_config["model_name"]["options"] = field_value
return build_config
Loading