Skip to content

Commit 2bc29a2

Browse files
Haystack LLM and embedding wrapper (#1901)
Co-authored-by: Jithin James <[email protected]>
1 parent 76e14b0 commit 2bc29a2

File tree

9 files changed

+356
-13
lines changed

9 files changed

+356
-13
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ dev = [
6060
"rapidfuzz",
6161
"pandas",
6262
"datacompy",
63+
"haystack-ai",
64+
"sacrebleu",
6365
"r2r",
6466
]
6567
test = [

requirements/dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ nltk
1717
rapidfuzz
1818
pandas
1919
datacompy
20+
haystack-ai

src/ragas/embeddings/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
LlamaIndexEmbeddingsWrapper,
66
embedding_factory,
77
)
8+
from ragas.embeddings.haystack_wrapper import HaystackEmbeddingsWrapper
89

910
__all__ = [
1011
"BaseRagasEmbeddings",
12+
"HaystackEmbeddingsWrapper",
13+
"HuggingfaceEmbeddings",
1114
"LangchainEmbeddingsWrapper",
1215
"LlamaIndexEmbeddingsWrapper",
13-
"HuggingfaceEmbeddings",
1416
"embedding_factory",
1517
]

src/ragas/embeddings/base.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import typing as t
55
from abc import ABC, abstractmethod
66
from dataclasses import field
7-
from typing import List
87

98
import numpy as np
109
from langchain_core.embeddings import Embeddings
@@ -51,15 +50,15 @@ def __init__(self, cache: t.Optional[CacheInterface] = None):
5150
self.aembed_documents
5251
)
5352

54-
async def embed_text(self, text: str, is_async=True) -> List[float]:
53+
async def embed_text(self, text: str, is_async=True) -> t.List[float]:
5554
"""
5655
Embed a single text string.
5756
"""
5857
embs = await self.embed_texts([text], is_async=is_async)
5958
return embs[0]
6059

6160
async def embed_texts(
62-
self, texts: List[str], is_async: bool = True
61+
self, texts: t.List[str], is_async: bool = True
6362
) -> t.List[t.List[float]]:
6463
"""
6564
Embed multiple texts.
@@ -77,10 +76,10 @@ async def embed_texts(
7776
return await loop.run_in_executor(None, embed_documents_with_retry, texts)
7877

7978
@abstractmethod
80-
async def aembed_query(self, text: str) -> List[float]: ...
79+
async def aembed_query(self, text: str) -> t.List[float]: ...
8180

8281
@abstractmethod
83-
async def aembed_documents(self, texts: List[str]) -> t.List[t.List[float]]: ...
82+
async def aembed_documents(self, texts: t.List[str]) -> t.List[t.List[float]]: ...
8483

8584
def set_run_config(self, run_config: RunConfig):
8685
"""
@@ -117,25 +116,25 @@ def __init__(
117116
run_config = RunConfig()
118117
self.set_run_config(run_config)
119118

120-
def embed_query(self, text: str) -> List[float]:
119+
def embed_query(self, text: str) -> t.List[float]:
121120
"""
122121
Embed a single query text.
123122
"""
124123
return self.embeddings.embed_query(text)
125124

126-
def embed_documents(self, texts: List[str]) -> List[List[float]]:
125+
def embed_documents(self, texts: t.List[str]) -> t.List[t.List[float]]:
127126
"""
128127
Embed multiple documents.
129128
"""
130129
return self.embeddings.embed_documents(texts)
131130

132-
async def aembed_query(self, text: str) -> List[float]:
131+
async def aembed_query(self, text: str) -> t.List[float]:
133132
"""
134133
Asynchronously embed a single query text.
135134
"""
136135
return await self.embeddings.aembed_query(text)
137136

138-
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
137+
async def aembed_documents(self, texts: t.List[str]) -> t.List[t.List[float]]:
139138
"""
140139
Asynchronously embed multiple documents.
141140
"""
@@ -256,13 +255,13 @@ def __post_init__(self):
256255
if self.cache is not None:
257256
self.predict = cacher(cache_backend=self.cache)(self.predict)
258257

259-
def embed_query(self, text: str) -> List[float]:
258+
def embed_query(self, text: str) -> t.List[float]:
260259
"""
261260
Embed a single query text.
262261
"""
263262
return self.embed_documents([text])[0]
264263

265-
def embed_documents(self, texts: List[str]) -> List[List[float]]:
264+
def embed_documents(self, texts: t.List[str]) -> t.List[t.List[float]]:
266265
"""
267266
Embed multiple documents.
268267
"""
@@ -279,7 +278,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
279278
assert isinstance(embeddings, Tensor)
280279
return embeddings.tolist()
281280

282-
def predict(self, texts: List[List[str]]) -> List[List[float]]:
281+
def predict(self, texts: t.List[t.List[str]]) -> t.List[t.List[float]]:
283282
"""
284283
Make predictions using a cross-encoder model.
285284
"""
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import asyncio
2+
import typing as t
3+
4+
from ragas.cache import CacheInterface
5+
from ragas.embeddings.base import BaseRagasEmbeddings
6+
from ragas.run_config import RunConfig
7+
8+
9+
class HaystackEmbeddingsWrapper(BaseRagasEmbeddings):
    """
    A wrapper for using Haystack embedders within the Ragas framework.

    Supports both the synchronous (`embed_query` / `embed_documents`) and
    asynchronous (`aembed_query` / `aembed_documents`) embedding entry points
    by delegating to a Haystack text-embedder component.

    Parameters
    ----------
    embedder : AzureOpenAITextEmbedder | HuggingFaceAPITextEmbedder | OpenAITextEmbedder | SentenceTransformersTextEmbedder
        An instance of a supported Haystack embedder class.
    run_config : RunConfig, optional
        A configuration object to manage embedding execution settings, by default None.
    cache : CacheInterface, optional
        A cache instance for storing and retrieving embedding results, by default None.
    """

    def __init__(
        self,
        embedder: t.Any,
        run_config: t.Optional[RunConfig] = None,
        cache: t.Optional[CacheInterface] = None,
    ):
        super().__init__(cache=cache)

        # Haystack is an optional dependency, so import it lazily here.
        try:
            from haystack import AsyncPipeline
            from haystack.components.embedders import (
                AzureOpenAITextEmbedder,
                HuggingFaceAPITextEmbedder,
                OpenAITextEmbedder,
                SentenceTransformersTextEmbedder,
            )
        except ImportError as exc:
            raise ImportError(
                "Haystack is not installed. Please install it with `pip install haystack-ai`."
            ) from exc

        # Fail fast on unsupported embedder types.
        supported_types = (
            AzureOpenAITextEmbedder,
            HuggingFaceAPITextEmbedder,
            OpenAITextEmbedder,
            SentenceTransformersTextEmbedder,
        )
        if not isinstance(embedder, supported_types):
            raise TypeError(
                "Expected 'embedder' to be one of: AzureOpenAITextEmbedder, "
                "HuggingFaceAPITextEmbedder, OpenAITextEmbedder, or "
                f"SentenceTransformersTextEmbedder, but got {type(embedder).__name__}."
            )

        self.embedder = embedder

        # Async calls are routed through a one-component pipeline so that
        # `run_async` can be awaited.
        self.async_pipeline = AsyncPipeline()
        self.async_pipeline.add_component("embedder", self.embedder)

        # Fall back to a default run configuration when none is supplied.
        self.set_run_config(run_config if run_config is not None else RunConfig())

    def embed_query(self, text: str) -> t.List[float]:
        """Synchronously embed a single text via the wrapped embedder."""
        return self.embedder.run(text=text)["embedding"]

    def embed_documents(self, texts: t.List[str]) -> t.List[t.List[float]]:
        """Synchronously embed each text, one call per item."""
        vectors: t.List[t.List[float]] = []
        for single_text in texts:
            vectors.append(self.embed_query(single_text))
        return vectors

    async def aembed_query(self, text: str) -> t.List[float]:
        """Asynchronously embed a single text through the async pipeline."""
        pipeline_output = await self.async_pipeline.run_async({"embedder": {"text": text}})
        # Missing keys degrade to an empty embedding rather than raising.
        embedder_output = pipeline_output.get("embedder", {})
        return embedder_output.get("embedding", [])

    async def aembed_documents(self, texts: t.List[str]) -> t.List[t.List[float]]:
        """Asynchronously embed multiple texts concurrently."""
        return list(await asyncio.gather(*(self.aembed_query(item) for item in texts)))

    def __repr__(self) -> str:
        try:
            from haystack.components.embedders import (
                AzureOpenAITextEmbedder,
                HuggingFaceAPITextEmbedder,
                OpenAITextEmbedder,
                SentenceTransformersTextEmbedder,
            )
        except ImportError:
            return f"{self.__class__.__name__}(embeddings=Unknown(...))"

        embedder = self.embedder
        # Pick the most descriptive identifier each embedder type exposes.
        model_info: t.Any = "Unknown"
        if isinstance(embedder, (OpenAITextEmbedder, SentenceTransformersTextEmbedder)):  # type: ignore
            model_info = embedder.model
        elif isinstance(embedder, AzureOpenAITextEmbedder):  # type: ignore
            model_info = embedder.azure_deployment
        elif isinstance(embedder, HuggingFaceAPITextEmbedder):  # type: ignore
            model_info = embedder.api_params

        return f"{self.__class__.__name__}(embeddings={model_info}(...))"

src/ragas/llms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
LlamaIndexLLMWrapper,
55
llm_factory,
66
)
7+
from ragas.llms.haystack_wrapper import HaystackLLMWrapper
78

89
__all__ = [
910
"BaseRagasLLM",
11+
"HaystackLLMWrapper",
1012
"LangchainLLMWrapper",
1113
"LlamaIndexLLMWrapper",
1214
"llm_factory",

src/ragas/llms/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from langchain_core.prompt_values import PromptValue
2525
from llama_index.core.base.llms.base import BaseLLM
2626

27+
2728
logger = logging.getLogger(__name__)
2829

2930
MULTIPLE_COMPLETION_SUPPORTED = [

src/ragas/llms/haystack_wrapper.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import typing as t
2+
3+
from langchain_core.callbacks import Callbacks
4+
from langchain_core.outputs import Generation, LLMResult
5+
from langchain_core.prompt_values import PromptValue
6+
7+
from ragas.cache import CacheInterface
8+
from ragas.llms import BaseRagasLLM
9+
from ragas.run_config import RunConfig
10+
11+
12+
class HaystackLLMWrapper(BaseRagasLLM):
    """
    A wrapper class for using Haystack LLM generators within the Ragas framework.

    This class integrates Haystack's LLM components (e.g., `OpenAIGenerator`,
    `HuggingFaceAPIGenerator`, etc.) into Ragas, enabling both synchronous and
    asynchronous text generation.

    Parameters
    ----------
    haystack_generator : AzureOpenAIGenerator | HuggingFaceAPIGenerator | HuggingFaceLocalGenerator | OpenAIGenerator
        An instance of a Haystack generator.
    run_config : RunConfig, optional
        Configuration object to manage LLM execution settings, by default None.
    cache : CacheInterface, optional
        A cache instance for storing results, by default None.
    """

    def __init__(
        self,
        haystack_generator: t.Any,
        run_config: t.Optional[RunConfig] = None,
        cache: t.Optional[CacheInterface] = None,
    ):
        super().__init__(cache=cache)

        # Haystack is an optional dependency, so import it lazily here.
        try:
            from haystack import AsyncPipeline
            from haystack.components.generators import (
                AzureOpenAIGenerator,
                HuggingFaceAPIGenerator,
                HuggingFaceLocalGenerator,
                OpenAIGenerator,
            )
        except ImportError as exc:
            raise ImportError(
                "Haystack is not installed. Please install it using `pip install haystack-ai`."
            ) from exc

        # Fail fast on unsupported generator types.
        if not isinstance(
            haystack_generator,
            (
                AzureOpenAIGenerator,
                HuggingFaceAPIGenerator,
                HuggingFaceLocalGenerator,
                OpenAIGenerator,
            ),
        ):
            raise TypeError(
                "Expected 'haystack_generator' to be one of: "
                "AzureOpenAIGenerator, HuggingFaceAPIGenerator, "
                "HuggingFaceLocalGenerator, or OpenAIGenerator, but received "
                f"{type(haystack_generator).__name__}."
            )

        # Async calls are routed through a one-component pipeline so that
        # `run_async` can be awaited.
        self.generator = haystack_generator
        self.async_pipeline = AsyncPipeline()
        self.async_pipeline.add_component("llm", self.generator)

        if run_config is None:
            run_config = RunConfig()
        self.set_run_config(run_config)

    def is_finished(self, response: LLMResult) -> bool:
        """Haystack generators return complete replies, so always report finished."""
        return True

    def generate_text(
        self,
        prompt: PromptValue,
        n: int = 1,
        temperature: float = 1e-8,
        stop: t.Optional[t.List[str]] = None,
        callbacks: t.Optional[Callbacks] = None,
    ) -> LLMResult:
        """
        Synchronously generate text for `prompt` via the wrapped generator.

        Only the first reply is used; `n`, `temperature`, `stop`, and
        `callbacks` are accepted for interface compatibility but not forwarded.
        """
        # BUG FIX: a direct component `run()` call returns its outputs at the
        # top level ({"replies": [...], "meta": ...}); the previous code read
        # `component_output["llm"]["replies"]` (pipeline-style nesting), which
        # never exists here, so the sync path always produced "".
        component_output: t.Dict[str, t.Any] = self.generator.run(prompt.to_string())
        replies = component_output.get("replies", [])
        output_text = replies[0] if replies else ""

        return LLMResult(generations=[[Generation(text=output_text)]])

    async def agenerate_text(
        self,
        prompt: PromptValue,
        n: int = 1,
        temperature: t.Optional[float] = None,
        stop: t.Optional[t.List[str]] = None,
        callbacks: t.Optional[Callbacks] = None,
    ) -> LLMResult:
        """
        Asynchronously generate text for `prompt` via the async pipeline.

        Only the first reply is used; `n`, `stop`, and `callbacks` are accepted
        for interface compatibility but not forwarded.
        """
        # Prepare input parameters for the LLM component. Only forward
        # `temperature` when the caller set it — sending {"temperature": None}
        # downstream can break generators that pass kwargs straight through.
        llm_input: t.Dict[str, t.Any] = {"prompt": prompt.to_string()}
        if temperature is not None:
            llm_input["generation_kwargs"] = {"temperature": temperature}

        # Run the async pipeline with the LLM input; pipeline outputs are
        # nested under the component name ("llm").
        pipeline_output = await self.async_pipeline.run_async(data={"llm": llm_input})
        replies = pipeline_output.get("llm", {}).get("replies", [])
        output_text = replies[0] if replies else ""

        return LLMResult(generations=[[Generation(text=output_text)]])

    def __repr__(self) -> str:
        try:
            from haystack.components.generators import (
                AzureOpenAIGenerator,
                HuggingFaceAPIGenerator,
                HuggingFaceLocalGenerator,
                OpenAIGenerator,
            )
        except ImportError:
            return f"{self.__class__.__name__}(llm=Unknown(...))"

        generator = self.generator

        # Pick the most descriptive identifier each generator type exposes.
        if isinstance(generator, OpenAIGenerator):
            model_info = generator.model
        elif isinstance(generator, HuggingFaceLocalGenerator):
            model_info = generator.huggingface_pipeline_kwargs.get("model")
        elif isinstance(generator, HuggingFaceAPIGenerator):
            model_info = generator.api_params.get("model")
        elif isinstance(generator, AzureOpenAIGenerator):
            model_info = generator.azure_deployment
        else:
            model_info = "Unknown"

        return f"{self.__class__.__name__}(llm={model_info}(...))"

0 commit comments

Comments
 (0)