Skip to content

Commit

Permalink
feat: update Cohere embedding model to use langchain_cohere, added su…
Browse files Browse the repository at this point in the history
…pport to dynamically load latest embedding models, improved error handling (#6034)

* update cohere model

* Update src/backend/base/langflow/components/embeddings/cohere.py

Co-authored-by: Gabriel Luiz Freitas Almeida <[email protected]>

---------

Co-authored-by: Gabriel Luiz Freitas Almeida <[email protected]>
  • Loading branch information
edwinjosechittilappilly and ogabrielluiz authored Feb 6, 2025
1 parent 5f63ca0 commit e89edc3
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 12 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ dependencies = [
"uv>=0.5.7",
"ag2>=0.1.0",
"scrapegraph-py>=1.10.2",
"pydantic-ai>=0.0.19"
"pydantic-ai>=0.0.19",
]

[tool.uv.sources]
Expand Down
57 changes: 46 additions & 11 deletions src/backend/base/langflow/components/embeddings/cohere.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from langchain_community.embeddings.cohere import CohereEmbeddings
from typing import Any

import cohere
from langchain_cohere import CohereEmbeddings

from langflow.base.models.model import LCModelComponent
from langflow.field_typing import Embeddings
from langflow.io import DropdownInput, FloatInput, IntInput, MessageTextInput, Output, SecretStrInput

HTTP_STATUS_OK = 200


class CohereEmbeddingsComponent(LCModelComponent):
display_name = "Cohere Embeddings"
Expand All @@ -12,9 +17,9 @@ class CohereEmbeddingsComponent(LCModelComponent):
name = "CohereEmbeddings"

inputs = [
SecretStrInput(name="cohere_api_key", display_name="Cohere API Key", required=True),
SecretStrInput(name="api_key", display_name="Cohere API Key", required=True, real_time_refresh=True),
DropdownInput(
name="model",
name="model_name",
display_name="Model",
advanced=False,
options=[
Expand All @@ -24,6 +29,8 @@ class CohereEmbeddingsComponent(LCModelComponent):
"embed-multilingual-light-v2.0",
],
value="embed-english-v2.0",
refresh_button=True,
combobox=True,
),
MessageTextInput(name="truncate", display_name="Truncate", advanced=True),
IntInput(name="max_retries", display_name="Max Retries", value=3, advanced=True),
Expand All @@ -36,11 +43,39 @@ class CohereEmbeddingsComponent(LCModelComponent):
]

def build_embeddings(self) -> Embeddings:
return CohereEmbeddings(
cohere_api_key=self.cohere_api_key,
model=self.model,
truncate=self.truncate,
max_retries=self.max_retries,
user_agent=self.user_agent,
request_timeout=self.request_timeout or None,
)
data = None
try:
data = CohereEmbeddings(
cohere_api_key=self.api_key,
model=self.model_name,
truncate=self.truncate,
max_retries=self.max_retries,
user_agent=self.user_agent,
request_timeout=self.request_timeout or None,
)
except Exception as e:
msg = (
"Unable to create Cohere Embeddings. ",
"Please verify the API key and model parameters, and try again.",
)
raise ValueError(msg) from e
# added status if not the return data would be serialised to create the status
return data

def get_model(self):
try:
co = cohere.ClientV2(self.api_key)
response = co.models.list(endpoint="embed")
models = response.models
return [model.name for model in models]
except Exception as e:
msg = f"Failed to fetch Cohere models. Error: {e}"
raise ValueError(msg) from e

async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
if field_name in {"model_name", "api_key"}:
if build_config.get("api_key", {}).get("value", None):
build_config["model_name"]["options"] = self.get_model()
else:
build_config["model_name"]["options"] = field_value
return build_config

0 comments on commit e89edc3

Please sign in to comment.