refactor: changing LLMMetadataExtractor to use chat generators (#188)
* changing to ChatGenerators

* linting

* updating tests
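
In short: the extractor now drives chat generators instead of completion-style generators. Prompts are wrapped in ChatMessage.from_user(...), sent to the generator as messages=[...], and replies come back as ChatMessage objects whose payload is read via .text. A minimal before/after sketch of the call pattern (the prompt string is illustrative; "gpt-4o-mini" is the model used in the tests below):

from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage

# Before this commit: completion-style generators took a plain string.
#   generator = OpenAIGenerator(model="gpt-4o-mini")
#   reply = generator.run(prompt="Extract metadata from: ...")["replies"][0]  # a str

# After this commit: chat generators take a list of ChatMessage objects.
generator = OpenAIChatGenerator(model="gpt-4o-mini")
message = ChatMessage.from_user("Extract metadata from: ...")  # illustrative prompt
result = generator.run(messages=[message])
reply = result["replies"][0].text  # replies are ChatMessage objects, so read .text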
davidsbatista authored Feb 10, 2025
1 parent 4aa3bf7 commit c689c05
Showing 2 changed files with 55 additions and 47 deletions.
@@ -10,8 +10,9 @@
 from haystack import Document, component, default_from_dict, default_to_dict, logging
 from haystack.components.builders import PromptBuilder
-from haystack.components.generators import AzureOpenAIGenerator, OpenAIGenerator
+from haystack.components.generators.chat import AzureOpenAIChatGenerator, OpenAIChatGenerator
 from haystack.components.preprocessors import DocumentSplitter
+from haystack.dataclasses import ChatMessage
 from haystack.lazy_imports import LazyImport
 from haystack.utils import deserialize_callable, deserialize_secrets_inplace
 from jinja2 import meta
@@ -20,10 +20,10 @@
 from haystack_experimental.util.utils import expand_page_range

 with LazyImport(message="Run 'pip install \"amazon-bedrock-haystack>=1.0.2\"'") as amazon_bedrock_generator:
-    from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator
+    from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

 with LazyImport(message="Run 'pip install \"google-vertex-haystack>=2.0.0\"'") as vertex_ai_gemini_generator:
-    from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator
+    from haystack_integrations.components.generators.google_vertex.chat.gemini import VertexAIGeminiChatGenerator
     from vertexai.generative_models import GenerationConfig


@@ -192,7 +193,6 @@ def __init__(  # pylint: disable=R0917
                 f"Prompt must have exactly one variable called 'document'. Found {','.join(variables)} in the prompt."
             )
         self.builder = PromptBuilder(prompt, required_variables=variables)
-
         self.raise_on_failure = raise_on_failure
         self.expected_keys = expected_keys or []
         self.generator_api = generator_api if isinstance(generator_api, LLMProvider) \
@@ -207,20 +207,22 @@ def __init__(  # pylint: disable=R0917
     def _init_generator(
         generator_api: LLMProvider,
         generator_api_params: Optional[Dict[str, Any]]
-    ) -> Union[OpenAIGenerator, AzureOpenAIGenerator, "AmazonBedrockGenerator", "VertexAIGeminiGenerator"]:
+    ) -> Union[
+        OpenAIChatGenerator, AzureOpenAIChatGenerator, "AmazonBedrockChatGenerator", "VertexAIGeminiChatGenerator"
+    ]:
         """
         Initialize the chat generator based on the specified API provider and parameters.
         """
         if generator_api == LLMProvider.OPENAI:
-            return OpenAIGenerator(**generator_api_params)
+            return OpenAIChatGenerator(**generator_api_params)
         elif generator_api == LLMProvider.OPENAI_AZURE:
-            return AzureOpenAIGenerator(**generator_api_params)
+            return AzureOpenAIChatGenerator(**generator_api_params)
         elif generator_api == LLMProvider.AWS_BEDROCK:
             amazon_bedrock_generator.check()
-            return AmazonBedrockGenerator(**generator_api_params)
+            return AmazonBedrockChatGenerator(**generator_api_params)
         elif generator_api == LLMProvider.GOOGLE_VERTEX:
             vertex_ai_gemini_generator.check()
-            return VertexAIGeminiGenerator(**generator_api_params)
+            return VertexAIGeminiChatGenerator(**generator_api_params)
         else:
             raise ValueError(f"Unsupported generator API: {generator_api}")
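
For reference, the dispatch above keeps the same kwargs contract for every provider; only the returned class changes. A sketch of exercising the OpenAI branch (the params are illustrative, and _init_generator is assumed to remain a static helper, as its self-less signature above suggests; OpenAIChatGenerator reads OPENAI_API_KEY from the environment):

from haystack_experimental.components.extractors import LLMMetadataExtractor, LLMProvider

generator = LLMMetadataExtractor._init_generator(
    generator_api=LLMProvider.OPENAI,
    generator_api_params={"model": "gpt-4o-mini"},  # illustrative params
)
print(type(generator).__name__)  # OpenAIChatGenerator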

@@ -318,8 +320,8 @@ def _prepare_prompts(
         self,
         documents: List[Document],
         expanded_range: Optional[List[int]] = None
-    ) -> List[Union[str, None]]:
-        all_prompts: List[Union[str, None]] = []
+    ) -> List[Union[ChatMessage, None]]:
+        all_prompts: List[Union[ChatMessage, None]] = []
         for document in documents:
             if not document.content:
                 logger.warning(
@@ -341,19 +343,23 @@ def _prepare_prompts(
                 doc_copy = document

             prompt_with_doc = self.builder.run(
-                template=self.prompt,
-                template_variables={"document": doc_copy}
-            )
-            all_prompts.append(prompt_with_doc["prompt"])
+                template=self.prompt,
+                template_variables={"document": doc_copy}
+            )
+
+            # build a ChatMessage with the prompt
+            message = ChatMessage.from_user(prompt_with_doc["prompt"])
+            all_prompts.append(message)

         return all_prompts

-    def _run_on_thread(self, prompt: Optional[str]) -> Dict[str, Any]:
+    def _run_on_thread(self, prompt: Optional[ChatMessage]) -> Dict[str, Any]:
         # If prompt is None, return an empty dictionary
         if prompt is None:
             return {"replies": ["{}"]}

         try:
-            result = self.llm_provider.run(prompt=prompt)
+            result = self.llm_provider.run(messages=[prompt])
         except Exception as e:
             logger.error(
                 "LLM {class_name} execution failed. Skipping metadata extraction. Failed with exception '{error}'.",
@@ -398,7 +404,7 @@ def run(self, documents: List[Document], page_range: Optional[List[Union[str, in
         if page_range:
             expanded_range = expand_page_range(page_range)

-        # Create prompts for each document
+        # Create ChatMessage prompts for each document
         all_prompts = self._prepare_prompts(documents=documents, expanded_range=expanded_range)

         # Run the LLM on each prompt
@@ -414,7 +420,7 @@ def run(self, documents: List[Document], page_range: Optional[List[Union[str, in
                 failed_documents.append(document)
                 continue

-            parsed_metadata = self._extract_metadata(result["replies"][0])
+            parsed_metadata = self._extract_metadata(result["replies"][0].text)
             if "error" in parsed_metadata:
                 document.meta["metadata_extraction_error"] = parsed_metadata["error"]
                 document.meta["metadata_extraction_response"] = result["replies"][0]
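
Taken together, a document now flows through the component roughly like this (a simplified sketch of the changed methods above, with a hypothetical template; not the component's literal control flow):

from haystack import Document
from haystack.components.builders import PromptBuilder
from haystack.dataclasses import ChatMessage

template = "Extract the company name from: {{document.content}}"  # hypothetical
builder = PromptBuilder(template, required_variables=["document"])
doc = Document(content="deepset was founded in 2018 in Berlin.")

# _prepare_prompts: render the template, then wrap the result in a user ChatMessage
rendered = builder.run(template=template, template_variables={"document": doc})
message = ChatMessage.from_user(rendered["prompt"])

# _run_on_thread: chat generators expect a list of messages and return ChatMessage
# replies, so the JSON payload is read via .text:
#   result = self.llm_provider.run(messages=[message])
#   parsed_metadata = json.loads(result["replies"][0].text)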
test/components/extractors/test_llm_metadata_extractor.py — 58 changes: 30 additions & 28 deletions
@@ -1,14 +1,15 @@
-import boto3
 import os
-import pytest
 from unittest.mock import MagicMock

-from haystack import Pipeline, Document
+import boto3
+import pytest
+from haystack import Document, Pipeline
 from haystack.components.builders import PromptBuilder
 from haystack.components.writers import DocumentWriter
+from haystack.dataclasses import ChatMessage
 from haystack.document_stores.in_memory import InMemoryDocumentStore
-from haystack_experimental.components.extractors import LLMMetadataExtractor
-from haystack_experimental.components.extractors import LLMProvider
+from haystack_experimental.components.extractors import LLMMetadataExtractor, LLMProvider


 class TestLLMMetadataExtractor:
@@ -95,7 +96,10 @@ def test_to_dict_openai(self, monkeypatch):
                     "model": "gpt-4o-mini",
                     "organization": None,
                     "streaming_callback": None,
-                    "system_prompt": None,
+                    "max_retries": None,
+                    "timeout": None,
+                    "tools": None,
+                    "tools_strict": False,
                 },
                 "max_workers": 3,
             },
@@ -106,11 +110,7 @@ def test_to_dict_aws_bedrock(self, boto3_session_mock):
             prompt="some prompt that was used with the LLM {{document.content}}",
             expected_keys=["key1", "key2"],
             generator_api=LLMProvider.AWS_BEDROCK,
-            generator_api_params={
-                "model": "meta.llama.test",
-                "max_length": 100,
-                "truncate": False,
-            },
+            generator_api_params={"model": "meta.llama.test"},
             raise_on_failure=True,
         )
         extractor_dict = extractor.to_dict()
@@ -146,11 +146,11 @@ def test_to_dict_aws_bedrock(self, boto3_session_mock):
                     "strict": False,
                 },
                 "model": "meta.llama.test",
-                "model_family": None,
-                "max_length": 100,
-                "truncate": False,
+                "stop_words": [],
+                "generation_kwargs": {},
                 "streaming_callback": None,
                 "boto3_config": None,
+                "tools": None,
             },
             "expected_keys": ["key1", "key2"],
             "page_range": None,
@@ -179,7 +179,6 @@ def test_from_dict_openai(self, monkeypatch):
                     "model": "gpt-4o-mini",
                     "organization": None,
                     "streaming_callback": None,
-                    "system_prompt": None,
                 },
             },
         }
@@ -225,10 +224,11 @@ def test_from_dict_aws_bedrock(self, boto3_session_mock):
                     "strict": False,
                 },
                 "model": "meta.llama.test",
-                "max_length": 200,
-                "truncate": False,
+                "stop_words": [],
                 "generation_kwargs": {},
                 "streaming_callback": None,
+                "boto3_config": None,
+                "tools": None,
             },
             "expected_keys": ["key1", "key2"],
             "page_range": None,
@@ -244,8 +244,6 @@ def test_from_dict_aws_bedrock(self, boto3_session_mock):
             == "some prompt that was used with the LLM {{document.content}}"
         )
         assert extractor.generator_api == LLMProvider.AWS_BEDROCK
-        assert extractor.llm_provider.max_length == 200
-        assert extractor.llm_provider.truncate is False
         assert extractor.llm_provider.model == "meta.llama.test"

     def test_warm_up(self, monkeypatch):
@@ -288,7 +286,7 @@ def test_extract_metadata_missing_key(self, monkeypatch, caplog):
     def test_prepare_prompts(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
         extractor = LLMMetadataExtractor(
-            prompt="prompt {{document.content}}",
+            prompt="some_user_definer_prompt {{document.content}}",
             generator_api=LLMProvider.OPENAI,
         )
         docs = [
@@ -300,15 +298,16 @@ def test_prepare_prompts(self, monkeypatch):
             ),
         ]
         prompts = extractor._prepare_prompts(docs)
+
         assert prompts == [
-            "prompt deepset was founded in 2018 in Berlin, and is known for its Haystack framework",
-            "prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library",
+            ChatMessage.from_dict({"_role": "user", "_meta": {}, "_name": None, "_content": [{"text": "some_user_definer_prompt deepset was founded in 2018 in Berlin, and is known for its Haystack framework"}]}),
+            ChatMessage.from_dict({"_role": "user", "_meta": {}, "_name": None, "_content": [{"text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library"}]})
         ]

     def test_prepare_prompts_empty_document(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
         extractor = LLMMetadataExtractor(
-            prompt="prompt {{document.content}}",
+            prompt="some_user_definer_prompt {{document.content}}",
             generator_api=LLMProvider.OPENAI,
         )
         docs = [
@@ -320,13 +319,14 @@ def test_prepare_prompts_empty_document(self, monkeypatch):
         prompts = extractor._prepare_prompts(docs)
         assert prompts == [
             None,
-            "prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library",
+            ChatMessage.from_dict(
+                {"_role": "user", "_meta": {}, "_name": None, "_content": [{"text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library"}]})
         ]

     def test_prepare_prompts_expanded_range(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
         extractor = LLMMetadataExtractor(
-            prompt="prompt {{document.content}}",
+            prompt="some_user_definer_prompt {{document.content}}",
             generator_api=LLMProvider.OPENAI,
             page_range=["1-2"],
)
Expand All @@ -336,9 +336,11 @@ def test_prepare_prompts_expanded_range(self, monkeypatch):
)
]
prompts = extractor._prepare_prompts(docs, expanded_range=[1, 2])
assert prompts == [
"prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library\fPage 2\f",
]

assert prompts == [ChatMessage.from_dict({"_role": "user",
"_meta": {},
"_name": None,
"_content": [{"text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library\x0cPage 2\x0c"}]})]

     def test_run_no_documents(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
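
The updated assertions compare ChatMessage objects directly, rebuilding the expected message from its serialized form with ChatMessage.from_dict. Assuming the serialized shape shown in the tests above, an equivalent expectation can be written with ChatMessage.from_user (a sketch, not part of this commit):

from haystack.dataclasses import ChatMessage

text = "some_user_definer_prompt deepset was founded in 2018 in Berlin, and is known for its Haystack framework"

# ChatMessage equality compares role, content, name, and meta, so the from_dict
# literal used in the tests builds the same object as from_user:
assert ChatMessage.from_user(text) == ChatMessage.from_dict(
    {"_role": "user", "_meta": {}, "_name": None, "_content": [{"text": text}]}
)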
