
Commit

add: additional evaluation criteria
nisaji committed Oct 18, 2024
1 parent 737ac23 commit f58b80b
Showing 5 changed files with 188 additions and 42 deletions.
6 changes: 5 additions & 1 deletion langrade/__init__.py
@@ -1,10 +1,14 @@
 from .grader import document_grader, DocumentGrader
 from .retriever import create_retriever
+from .criteria import RelevanceCriterion, ReadabilityCriterion, CoherenceCriterion
 from .providers import LLMProviderFactory

 __all__ = [
     "document_grader",
     "DocumentGrader",
     "create_retriever",
+    "RelevanceCriterion",
+    "ReadabilityCriterion",
+    "CoherenceCriterion",
     "LLMProviderFactory",
-]  # noqa: E501
+]
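
For reference, the widened public surface can now be imported straight from the package root. A minimal import sketch (not part of the diff; it assumes the package and its langchain dependencies are installed):

from langrade import (
    DocumentGrader,
    RelevanceCriterion,
    ReadabilityCriterion,
    CoherenceCriterion,
)
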
39 changes: 39 additions & 0 deletions langrade/constants.py
@@ -10,3 +10,42 @@

REASONING_DESCRIPTION = "Thinking process to give a correct binary score shortly."
BINARY_SCORE_DESCRIPTION = "Documents are relevant to the question, 'yes' or 'no'"

RELEVANCE_PROMPT = """
Assess the relevance of the retrieved document to the user question.
Your goal is to filter out erroneous retrievals without being overly strict.
If the document contains keywords or semantic meanings related to the user question, grade it as relevant.
Retrieved document:
{document}
User question:
{question}
Binary score (yes/no):
Reasoning:
"""

READABILITY_PROMPT = """
Assess the readability of the given document.
Consider factors such as sentence structure, vocabulary, and overall clarity.
If the document is easy to understand for a general audience, grade it as readable.
Document to assess:
{document}
Binary score (yes/no):
Reasoning:
"""

COHERENCE_PROMPT = """
Assess the coherence of the given document.
Consider factors such as logical flow, organization of ideas, and overall structure.
If the document presents ideas in a clear and logical manner, grade it as coherent.
Document to assess:
{document}
Binary score (yes/no):
Reasoning:
"""
96 changes: 96 additions & 0 deletions langrade/criteria.py
@@ -0,0 +1,96 @@
# criteria.py

from abc import ABC, abstractmethod
from pydantic import BaseModel, Field

from langrade.constants import BINARY_SCORE_DESCRIPTION, REASONING_DESCRIPTION
from .providers import LLMProvider


class CriterionResult(BaseModel):
    binary_score: str = Field(description=BINARY_SCORE_DESCRIPTION)
    reasoning: str = Field(description=REASONING_DESCRIPTION)


class EvaluationCriterion(ABC):
    def __init__(self, llm_provider: LLMProvider):
        self.llm_provider = llm_provider

    @abstractmethod
    async def evaluate(self, document: str, question: str = None) -> CriterionResult:
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        pass

    @property
    @abstractmethod
    def prompt(self) -> str:
        pass


class RelevanceCriterion(EvaluationCriterion):
    @property
    def name(self) -> str:
        return "Relevance"

    @property
    def prompt(self) -> str:
        from .constants import RELEVANCE_PROMPT

        return RELEVANCE_PROMPT

    async def evaluate(self, document: str, question: str) -> CriterionResult:
        response = await self.llm_provider.agenerate(
            self.prompt.format(document=document, question=question)
        )
        # Parse the response to extract score and explanation
        # This is a simplified example; you might need more robust parsing
        lines = response.split("\n")
        binary_score = lines[0].split(":")[1].strip().lower()
        reasoning = lines[1].split(":")[1].strip()
        return CriterionResult(binary_score=binary_score, reasoning=reasoning)


class ReadabilityCriterion(EvaluationCriterion):
    @property
    def name(self) -> str:
        return "Readability"

    @property
    def prompt(self) -> str:
        from .constants import READABILITY_PROMPT

        return READABILITY_PROMPT

    async def evaluate(self, document: str, question: str = None) -> CriterionResult:
        response = await self.llm_provider.agenerate(
            self.prompt.format(document=document)
        )
        lines = response.split("\n")
        binary_score = lines[0].split(":")[1].strip().lower()
        reasoning = lines[1].split(":")[1].strip()
        return CriterionResult(binary_score=binary_score, reasoning=reasoning)


class CoherenceCriterion(EvaluationCriterion):
    @property
    def name(self) -> str:
        return "Coherence"

    @property
    def prompt(self) -> str:
        from .constants import COHERENCE_PROMPT

        return COHERENCE_PROMPT

    async def evaluate(self, document: str, question: str = None) -> CriterionResult:
        response = await self.llm_provider.agenerate(
            self.prompt.format(document=document)
        )
        lines = response.split("\n")
        binary_score = lines[0].split(":")[1].strip().lower()
        reasoning = lines[1].split(":")[1].strip()
        return CriterionResult(binary_score=binary_score, reasoning=reasoning)
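
As a rough usage sketch (not part of the commit): the criteria only need an object exposing an async agenerate(prompt) coroutine, so a stub provider is enough to exercise the parsing locally, assuming the langrade package from this commit and its langchain dependencies are importable. The stub fakes a reply in the "Binary score: ... / Reasoning: ..." shape the simple line-based parser expects; with a real LLMProvider the reply format, and therefore the parsing, is not guaranteed.

import asyncio

from langrade.criteria import RelevanceCriterion


class FakeProvider:
    # Stands in for LLMProvider; the criteria only call agenerate().
    async def agenerate(self, prompt: str) -> str:
        return "Binary score: yes\nReasoning: The document names the capital directly."


async def main():
    criterion = RelevanceCriterion(FakeProvider())
    result = await criterion.evaluate(
        document="Paris is the capital of France.",
        question="What is the capital of France?",
    )
    print(result.binary_score, "-", result.reasoning)


asyncio.run(main())
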
46 changes: 36 additions & 10 deletions langrade/grader.py
@@ -1,20 +1,46 @@
-from .providers import LLMProviderFactory, GradeDocuments
+from typing import List, Dict
+from .providers import LLMProviderFactory
+from .criteria import (
+    EvaluationCriterion,
+    RelevanceCriterion,
+    ReadabilityCriterion,
+    CoherenceCriterion,
+)
 from .constants import DEFAULT_MODEL


 class DocumentGrader:
     def __init__(
-        self, provider: str, api_key: str, model: str = DEFAULT_MODEL
-    ):  # noqa: E501
-        self.provider = LLMProviderFactory.create_provider(
-            provider, api_key, model
-        )  # noqa: E501
+        self,
+        provider: str,
+        api_key: str,
+        model: str = DEFAULT_MODEL,
+        criteria: List[EvaluationCriterion] = None,
+    ):
+        self.llm_provider = LLMProviderFactory.create_provider(provider, api_key, model)
+        self.criteria = criteria or [
+            RelevanceCriterion(self.llm_provider),
+            ReadabilityCriterion(self.llm_provider),
+            CoherenceCriterion(self.llm_provider),
+        ]

-    def grade_document(self, document: str, question: str) -> GradeDocuments:
-        return self.provider.grade_document(document, question)
+    async def grade_document(
+        self, document: str, question: str = None
+    ) -> Dict[str, Dict[str, float]]:
+        results = {}
+        for criterion in self.criteria:
+            result = await criterion.evaluate(document, question)
+            results[criterion.name] = {
+                "binary_score": result.binary_score,
+                "reasoning": result.reasoning,
+            }
+        return results


 def document_grader(
-    provider: str, api_key: str, model: str = DEFAULT_MODEL
+    provider: str,
+    api_key: str,
+    model: str = DEFAULT_MODEL,
+    criteria: List[EvaluationCriterion] = None,
 ) -> DocumentGrader:
-    return DocumentGrader(provider, api_key, model)
+    return DocumentGrader(provider, api_key, model, criteria)
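
A hedged sketch of calling the reworked grader (not part of the diff): grade_document is now a coroutine and returns one entry per criterion. The provider key "openai", the API key, and the model name are placeholders rather than values confirmed by this diff, and the resulting scores depend on the LLM reply matching the simple parser in criteria.py.

import asyncio

from langrade import document_grader


async def main():
    grader = document_grader("openai", "YOUR_API_KEY", "gpt-4o-mini")
    results = await grader.grade_document(
        document="Paris is the capital of France.",
        question="What is the capital of France?",
    )
    for name, outcome in results.items():
        print(f"{name}: {outcome['binary_score']} ({outcome['reasoning']})")


asyncio.run(main())
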
43 changes: 12 additions & 31 deletions langrade/providers.py
@@ -1,47 +1,28 @@
 from abc import ABC, abstractmethod
-from pydantic import BaseModel, Field
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_openai import ChatOpenAI
-from langchain_groq import ChatGroq
 from langchain_anthropic import ChatAnthropic
 from langchain_google_genai import ChatGoogleGenerativeAI
-from .constants import (
-    SYSTEM_PROMPT,
-    DEFAULT_MODEL,
-    REASONING_DESCRIPTION,
-    BINARY_SCORE_DESCRIPTION,
-)
-
-
-class GradeDocuments(BaseModel):
-    reasoning: str = Field(description=REASONING_DESCRIPTION)
-    binary_score: str = Field(description=BINARY_SCORE_DESCRIPTION)
+from langchain_groq import ChatGroq
+from .constants import SYSTEM_PROMPT, DEFAULT_MODEL


 class LLMProvider(ABC):
     def __init__(self, api_key: str, model: str):
         self.llm = self._create_llm(api_key, model)
-        self.prompt = self._create_prompt()

     @abstractmethod
     def _create_llm(self, api_key: str, model: str):
         pass

-    def _create_prompt(self):
-        return ChatPromptTemplate.from_messages(
-            [
-                ("system", SYSTEM_PROMPT),
-                (
-                    "human",
-                    "Retrieved document: \n\n {document} \n\n User question: \n\n {question}",  # noqa: E501
-                ),
-            ]
-        )
-
-    def grade_document(self, document: str, question: str) -> GradeDocuments:
-        structured_llm = self.llm.with_structured_output(GradeDocuments)
-        chain = self.prompt | structured_llm
-        return chain.invoke({"document": document, "question": question})
+    async def agenerate(self, prompt: str) -> str:
+        messages = [
+            ("human", prompt),
+        ]
+        chat_prompt = ChatPromptTemplate.from_messages(messages)
+        chain = chat_prompt | self.llm
+        result = await chain.ainvoke({})
+        return result.content


 class OpenAIProvider(LLMProvider):
@@ -56,12 +37,12 @@ def _create_llm(self, api_key: str, model: str):

 class GoogleProvider(LLMProvider):
     def _create_llm(self, api_key: str, model: str):
-        return ChatGoogleGenerativeAI(google_api_key=api_key, model=model)  # noqa: E501
+        return ChatGoogleGenerativeAI(google_api_key=api_key, model=model)


 class GroqProvider(LLMProvider):
     def _create_llm(self, api_key: str, model: str):
-        return ChatGroq(groq_api_key=api_key, model_name=model)  # noqa: E501
+        return ChatGroq(groq_api_key=api_key, model_name=model)


 class LLMProviderFactory:
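
Because the provider classes now expose a generic agenerate() instead of grade_document(), any subclass can be driven directly. A minimal sketch (not part of the commit; the model name is an assumption, and the LLMProviderFactory body is collapsed above, so the provider keys it accepts are not shown here):

import asyncio

from langrade.providers import OpenAIProvider


async def main():
    # The base class builds the underlying chat model via _create_llm().
    provider = OpenAIProvider(api_key="YOUR_API_KEY", model="gpt-4o-mini")
    print(await provider.agenerate("Reply with exactly one word: ping or pong?"))


asyncio.run(main())
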
