-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
188 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,14 @@ | ||
"""Public package surface: re-export the grader entry points, retriever
factory, evaluation criteria, and provider factory."""

from .grader import document_grader, DocumentGrader
from .retriever import create_retriever
from .criteria import RelevanceCriterion, ReadabilityCriterion, CoherenceCriterion
from .providers import LLMProviderFactory

# Explicit public API of the package.
__all__ = [
    "document_grader",
    "DocumentGrader",
    "create_retriever",
    "RelevanceCriterion",
    "ReadabilityCriterion",
    "CoherenceCriterion",
    "LLMProviderFactory",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# criteria.py | ||
|
||
from abc import ABC, abstractmethod
from typing import Optional

from pydantic import BaseModel, Field

from langrade.constants import BINARY_SCORE_DESCRIPTION, REASONING_DESCRIPTION

from .providers import LLMProvider
class CriterionResult(BaseModel):
    """Structured outcome of evaluating one criterion against a document."""

    # Yes/no style verdict; field description text comes from shared constants.
    binary_score: str = Field(description=BINARY_SCORE_DESCRIPTION)
    # Free-text justification the LLM gave for the verdict.
    reasoning: str = Field(description=REASONING_DESCRIPTION)
class EvaluationCriterion(ABC):
    """Abstract base for a single document-evaluation criterion.

    Subclasses provide a display ``name``, a ``prompt`` template, and an
    async ``evaluate`` implementation that calls the configured LLM provider.
    """

    def __init__(self, llm_provider: LLMProvider):
        # Provider used for the async LLM completion calls in `evaluate`.
        self.llm_provider = llm_provider

    @abstractmethod
    async def evaluate(
        self, document: str, question: Optional[str] = None
    ) -> CriterionResult:
        """Score *document* (optionally against *question*).

        The original annotation was ``question: str = None``; ``Optional[str]``
        states the actual contract without changing any caller behavior.
        """

    @property
    @abstractmethod
    def name(self) -> str:
        """Short human-readable name of the criterion (e.g. "Relevance")."""

    @property
    @abstractmethod
    def prompt(self) -> str:
        """Prompt template sent to the LLM for this criterion."""
class RelevanceCriterion(EvaluationCriterion):
    """Judges whether a document is relevant to the user's question."""

    @property
    def name(self) -> str:
        return "Relevance"

    @property
    def prompt(self) -> str:
        # Imported lazily, presumably to avoid a circular import at load time.
        # NOTE(review): module-level imports use `langrade.constants` while this
        # uses the relative `.constants` — confirm both resolve to the same module.
        from .constants import RELEVANCE_PROMPT

        return RELEVANCE_PROMPT

    async def evaluate(self, document: str, question: str) -> CriterionResult:
        """Ask the LLM whether *document* answers *question* and parse the reply."""
        response = await self.llm_provider.agenerate(
            self.prompt.format(document=document, question=question)
        )
        return self._parse_response(response)

    def _parse_response(self, response: str) -> CriterionResult:
        """Parse a 'label: value' two-line reply into a CriterionResult.

        Expected shape: first line carries the score, second the reasoning,
        each as ``label: text``.
        """
        lines = response.split("\n")
        # maxsplit=1 keeps colons inside the text intact; the previous
        # split(":")[1] silently dropped everything after a second colon.
        binary_score = lines[0].split(":", 1)[1].strip().lower()
        reasoning = lines[1].split(":", 1)[1].strip()
        return CriterionResult(binary_score=binary_score, reasoning=reasoning)
class ReadabilityCriterion(EvaluationCriterion):
    """Judges how readable a document is; independent of any question."""

    @property
    def name(self) -> str:
        return "Readability"

    @property
    def prompt(self) -> str:
        # Imported lazily, presumably to avoid a circular import at load time.
        from .constants import READABILITY_PROMPT

        return READABILITY_PROMPT

    async def evaluate(
        self, document: str, question: Optional[str] = None
    ) -> CriterionResult:
        """Score *document* for readability; *question* is accepted but unused."""
        response = await self.llm_provider.agenerate(
            self.prompt.format(document=document)
        )
        return self._parse_response(response)

    def _parse_response(self, response: str) -> CriterionResult:
        """Parse a 'label: value' two-line reply into a CriterionResult."""
        lines = response.split("\n")
        # maxsplit=1 keeps colons inside the text intact; the previous
        # split(":")[1] silently dropped everything after a second colon.
        binary_score = lines[0].split(":", 1)[1].strip().lower()
        reasoning = lines[1].split(":", 1)[1].strip()
        return CriterionResult(binary_score=binary_score, reasoning=reasoning)
class CoherenceCriterion(EvaluationCriterion):
    """Judges whether a document is internally coherent; question-independent."""

    @property
    def name(self) -> str:
        return "Coherence"

    @property
    def prompt(self) -> str:
        # Imported lazily, presumably to avoid a circular import at load time.
        from .constants import COHERENCE_PROMPT

        return COHERENCE_PROMPT

    async def evaluate(
        self, document: str, question: Optional[str] = None
    ) -> CriterionResult:
        """Score *document* for coherence; *question* is accepted but unused."""
        response = await self.llm_provider.agenerate(
            self.prompt.format(document=document)
        )
        return self._parse_response(response)

    def _parse_response(self, response: str) -> CriterionResult:
        """Parse a 'label: value' two-line reply into a CriterionResult."""
        lines = response.split("\n")
        # maxsplit=1 keeps colons inside the text intact; the previous
        # split(":")[1] silently dropped everything after a second colon.
        binary_score = lines[0].split(":", 1)[1].strip().lower()
        reasoning = lines[1].split(":", 1)[1].strip()
        return CriterionResult(binary_score=binary_score, reasoning=reasoning)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,46 @@ | ||
from .providers import LLMProviderFactory, GradeDocuments | ||
from typing import List, Dict | ||
from .providers import LLMProviderFactory | ||
from .criteria import ( | ||
EvaluationCriterion, | ||
RelevanceCriterion, | ||
ReadabilityCriterion, | ||
CoherenceCriterion, | ||
) | ||
from .constants import DEFAULT_MODEL | ||
|
||
|
||
class DocumentGrader:
    """Grades a document against a configurable set of evaluation criteria.

    When no criteria are supplied, a default trio (relevance, readability,
    coherence) is built against the configured LLM provider.
    """

    def __init__(
        self,
        provider: str,
        api_key: str,
        model: str = DEFAULT_MODEL,
        criteria: List[EvaluationCriterion] = None,  # None -> use default criteria
    ):
        self.llm_provider = LLMProviderFactory.create_provider(provider, api_key, model)
        # `is None` rather than truthiness: an explicitly-passed empty list is
        # honored instead of being silently replaced by the defaults
        # (the previous `criteria or [...]` conflated [] with None).
        if criteria is None:
            criteria = [
                RelevanceCriterion(self.llm_provider),
                ReadabilityCriterion(self.llm_provider),
                CoherenceCriterion(self.llm_provider),
            ]
        self.criteria = criteria

    async def grade_document(
        self, document: str, question: str = None
    ) -> Dict[str, Dict[str, str]]:
        """Evaluate *document* under every configured criterion.

        Returns a mapping of criterion name to
        ``{"binary_score": ..., "reasoning": ...}``. Both values are strings
        (``CriterionResult`` declares them as ``str``), so the previous
        ``Dict[str, Dict[str, float]]`` annotation was incorrect.
        """
        results = {}
        for criterion in self.criteria:
            result = await criterion.evaluate(document, question)
            results[criterion.name] = {
                "binary_score": result.binary_score,
                "reasoning": result.reasoning,
            }
        return results
def document_grader(
    provider: str,
    api_key: str,
    model: str = DEFAULT_MODEL,
    criteria: List[EvaluationCriterion] = None,
) -> DocumentGrader:
    """Convenience factory: construct a DocumentGrader for *provider*.

    Passing ``criteria=None`` lets the grader fall back to its defaults.
    """
    return DocumentGrader(provider, api_key, model=model, criteria=criteria)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters