diff --git a/langrade/__init__.py b/langrade/__init__.py
index 6712bb6..15de499 100644
--- a/langrade/__init__.py
+++ b/langrade/__init__.py
@@ -1,10 +1,14 @@
 from .grader import document_grader, DocumentGrader
 from .retriever import create_retriever
+from .criteria import RelevanceCriterion, ReadabilityCriterion, CoherenceCriterion
 from .providers import LLMProviderFactory
 
 __all__ = [
     "document_grader",
     "DocumentGrader",
     "create_retriever",
+    "RelevanceCriterion",
+    "ReadabilityCriterion",
+    "CoherenceCriterion",
     "LLMProviderFactory",
-]  # noqa: E501
+]
diff --git a/langrade/constants.py b/langrade/constants.py
index fb9a465..522dbd1 100644
--- a/langrade/constants.py
+++ b/langrade/constants.py
@@ -10,3 +10,42 @@
 REASONING_DESCRIPTION = "Thinking process to give a correct binary score shortly."
 
 BINARY_SCORE_DESCRIPTION = "Documents are relevant to the question, 'yes' or 'no'"
+
+RELEVANCE_PROMPT = """
+Assess the relevance of the retrieved document to the user question.
+Your goal is to filter out erroneous retrievals without being overly strict.
+If the document contains keywords or semantic meanings related to the user question, grade it as relevant.
+
+Retrieved document:
+{document}
+
+User question:
+{question}
+
+Binary score (yes/no):
+Reasoning:
+"""
+
+READABILITY_PROMPT = """
+Assess the readability of the given document.
+Consider factors such as sentence structure, vocabulary, and overall clarity.
+If the document is easy to understand for a general audience, grade it as readable.
+
+Document to assess:
+{document}
+
+Binary score (yes/no):
+Reasoning:
+"""
+
+COHERENCE_PROMPT = """
+Assess the coherence of the given document.
+Consider factors such as logical flow, organization of ideas, and overall structure.
+If the document presents ideas in a clear and logical manner, grade it as coherent.
+
+Document to assess:
+{document}
+
+Binary score (yes/no):
+Reasoning:
+"""
diff --git a/langrade/criteria.py b/langrade/criteria.py
new file mode 100644
index 0000000..8eae163
--- /dev/null
+++ b/langrade/criteria.py
@@ -0,0 +1,96 @@
+# criteria.py
+
+from abc import ABC, abstractmethod
+from pydantic import BaseModel, Field
+
+from langrade.constants import BINARY_SCORE_DESCRIPTION, REASONING_DESCRIPTION
+from .providers import LLMProvider
+
+
+class CriterionResult(BaseModel):
+    binary_score: str = Field(description=BINARY_SCORE_DESCRIPTION)
+    reasoning: str = Field(description=REASONING_DESCRIPTION)
+
+
+class EvaluationCriterion(ABC):
+    def __init__(self, llm_provider: LLMProvider):
+        self.llm_provider = llm_provider
+
+    @abstractmethod
+    async def evaluate(self, document: str, question: str = None) -> CriterionResult:
+        pass
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        pass
+
+    @property
+    @abstractmethod
+    def prompt(self) -> str:
+        pass
+
+
+class RelevanceCriterion(EvaluationCriterion):
+    @property
+    def name(self) -> str:
+        return "Relevance"
+
+    @property
+    def prompt(self) -> str:
+        from .constants import RELEVANCE_PROMPT
+
+        return RELEVANCE_PROMPT
+
+    async def evaluate(self, document: str, question: str) -> CriterionResult:
+        response = await self.llm_provider.agenerate(
+            self.prompt.format(document=document, question=question)
+        )
+        # Parse the response to extract score and explanation
+        # This is a simplified example; you might need more robust parsing
+        lines = response.split("\n")
+        binary_score = lines[0].split(":")[1].strip().lower()
+        reasoning = lines[1].split(":")[1].strip()
+        return CriterionResult(binary_score=binary_score, reasoning=reasoning)
+
+
+class ReadabilityCriterion(EvaluationCriterion):
+    @property
+    def name(self) -> str:
+        return "Readability"
+
+    @property
+    def prompt(self) -> str:
+        from .constants import READABILITY_PROMPT
+
+        return READABILITY_PROMPT
+
+    async def evaluate(self, document: str, question: str = None) -> CriterionResult:
+        response = await self.llm_provider.agenerate(
+            self.prompt.format(document=document)
+        )
+        lines = response.split("\n")
+        binary_score = lines[0].split(":")[1].strip().lower()
+        reasoning = lines[1].split(":")[1].strip()
+        return CriterionResult(binary_score=binary_score, reasoning=reasoning)
+
+
+class CoherenceCriterion(EvaluationCriterion):
+    @property
+    def name(self) -> str:
+        return "Coherence"
+
+    @property
+    def prompt(self) -> str:
+        from .constants import COHERENCE_PROMPT
+
+        return COHERENCE_PROMPT
+
+    async def evaluate(self, document: str, question: str = None) -> CriterionResult:
+        response = await self.llm_provider.agenerate(
+            self.prompt.format(document=document)
+        )
+        lines = response.split("\n")
+        binary_score = lines[0].split(":")[1].strip().lower()
+        reasoning = lines[1].split(":")[1].strip()
+        return CriterionResult(binary_score=binary_score, reasoning=reasoning)
diff --git a/langrade/grader.py b/langrade/grader.py
index 46e58f4..5af6a6d 100644
--- a/langrade/grader.py
+++ b/langrade/grader.py
@@ -1,20 +1,46 @@
-from .providers import LLMProviderFactory, GradeDocuments
+from typing import List, Dict
+from .providers import LLMProviderFactory
+from .criteria import (
+    EvaluationCriterion,
+    RelevanceCriterion,
+    ReadabilityCriterion,
+    CoherenceCriterion,
+)
 from .constants import DEFAULT_MODEL
 
 
 class DocumentGrader:
     def __init__(
-        self, provider: str, api_key: str, model: str = DEFAULT_MODEL
-    ):  # noqa: E501
-        self.provider = LLMProviderFactory.create_provider(
-            provider, api_key, model
-        )  # noqa: E501
+        self,
+        provider: str,
+        api_key: str,
+        model: str = DEFAULT_MODEL,
+        criteria: List[EvaluationCriterion] = None,
+    ):
+        self.llm_provider = LLMProviderFactory.create_provider(provider, api_key, model)
+        self.criteria = criteria or [
+            RelevanceCriterion(self.llm_provider),
+            ReadabilityCriterion(self.llm_provider),
+            CoherenceCriterion(self.llm_provider),
+        ]
 
-    def grade_document(self, document: str, question: str) -> GradeDocuments:
-        return self.provider.grade_document(document, question)
+    async def grade_document(
+        self, document: str, question: str = None
+    ) -> Dict[str, Dict[str, str]]:
+        results = {}
+        for criterion in self.criteria:
+            result = await criterion.evaluate(document, question)
+            results[criterion.name] = {
+                "binary_score": result.binary_score,
+                "reasoning": result.reasoning,
+            }
+        return results
 
 
 def document_grader(
-    provider: str, api_key: str, model: str = DEFAULT_MODEL
+    provider: str,
+    api_key: str,
+    model: str = DEFAULT_MODEL,
+    criteria: List[EvaluationCriterion] = None,
 ) -> DocumentGrader:
-    return DocumentGrader(provider, api_key, model)
+    return DocumentGrader(provider, api_key, model, criteria)
diff --git a/langrade/providers.py b/langrade/providers.py
index 6d9974a..a0b92f3 100644
--- a/langrade/providers.py
+++ b/langrade/providers.py
@@ -1,47 +1,28 @@
 from abc import ABC, abstractmethod
-from pydantic import BaseModel, Field
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_openai import ChatOpenAI
-from langchain_groq import ChatGroq
 from langchain_anthropic import ChatAnthropic
 from langchain_google_genai import ChatGoogleGenerativeAI
-from .constants import (
-    SYSTEM_PROMPT,
-    DEFAULT_MODEL,
-    REASONING_DESCRIPTION,
-    BINARY_SCORE_DESCRIPTION,
-)
-
-
-class GradeDocuments(BaseModel):
-    reasoning: str = Field(description=REASONING_DESCRIPTION)
-    binary_score: str = Field(description=BINARY_SCORE_DESCRIPTION)
+from langchain_groq import ChatGroq
+from .constants import SYSTEM_PROMPT, DEFAULT_MODEL
 
 
 class LLMProvider(ABC):
     def __init__(self, api_key: str, model: str):
         self.llm = self._create_llm(api_key, model)
-        self.prompt = self._create_prompt()
 
     @abstractmethod
    def _create_llm(self, api_key: str, model: str):
         pass
 
-    def _create_prompt(self):
-        return ChatPromptTemplate.from_messages(
-            [
-                ("system", SYSTEM_PROMPT),
-                (
-                    "human",
-                    "Retrieved document: \n\n {document} \n\n User question: \n\n {question}",  # noqa: E501
-                ),
-            ]
-        )
-
-    def grade_document(self, document: str, question: str) -> GradeDocuments:
-        structured_llm = self.llm.with_structured_output(GradeDocuments)
-        chain = self.prompt | structured_llm
-        return chain.invoke({"document": document, "question": question})
+    async def agenerate(self, prompt: str) -> str:
+        messages = [
+            ("human", prompt),
+        ]
+        chat_prompt = ChatPromptTemplate.from_messages(messages)
+        chain = chat_prompt | self.llm
+        result = await chain.ainvoke({})
+        return result.content
 
 
 class OpenAIProvider(LLMProvider):
@@ -56,12 +37,12 @@ def _create_llm(self, api_key: str, model: str):
 
 class GoogleProvider(LLMProvider):
     def _create_llm(self, api_key: str, model: str):
-        return ChatGoogleGenerativeAI(google_api_key=api_key, model=model)  # noqa: E501
+        return ChatGoogleGenerativeAI(google_api_key=api_key, model=model)
 
 
 class GroqProvider(LLMProvider):
     def _create_llm(self, api_key: str, model: str):
-        return ChatGroq(groq_api_key=api_key, model_name=model)  # noqa: E501
+        return ChatGroq(groq_api_key=api_key, model_name=model)
 
 
 class LLMProviderFactory:
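
For reviewers, a minimal usage sketch of the multi-criteria API this change introduces. It assumes "openai" is one of the provider strings accepted by LLMProviderFactory and uses a placeholder API key; everything else (document_grader, the three default criteria, the now-awaitable grade_document, and the result layout) comes from the diff above.

import asyncio

from langrade import document_grader


async def main():
    # Builds a DocumentGrader with the default criteria:
    # Relevance, Readability, Coherence.
    grader = document_grader("openai", api_key="sk-...")  # placeholder key

    # grade_document is now a coroutine; each criterion is evaluated in turn.
    results = await grader.grade_document(
        document="LangChain is a framework for building LLM-powered applications.",
        question="What is LangChain?",
    )

    # results maps criterion name -> {"binary_score": ..., "reasoning": ...}
    for name, outcome in results.items():
        print(f"{name}: {outcome['binary_score']} ({outcome['reasoning']})")


asyncio.run(main())

A custom subset of criteria can be supplied through the new criteria argument; each criterion wraps an LLMProvider, which can be built with LLMProviderFactory.create_provider.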