Showing 5 changed files with 291 additions and 123 deletions; the fifth diff did not load.
File 1 of 5: a new 25-line test module for the prompt constants.

```python
from langrade import constants


def test_prompt_structure():
    prompts = [
        constants.RELEVANCE_PROMPT,
        constants.READABILITY_PROMPT,
        constants.COHERENCE_PROMPT,
    ]

    for prompt in prompts:
        assert "{document}" in prompt
        assert "Binary score (yes/no):" in prompt
        assert "Reasoning:" in prompt


def test_system_prompt():
    assert "You are a grader" in constants.SYSTEM_PROMPT
    assert "binary score" in constants.SYSTEM_PROMPT.lower()


def test_descriptions():
    assert len(constants.REASONING_DESCRIPTION) > 0
    assert "yes" in constants.BINARY_SCORE_DESCRIPTION.lower()
    assert "no" in constants.BINARY_SCORE_DESCRIPTION.lower()
```
File 2 of 5: a new 48-line test module for the individual grading criteria, driven by a mocked LLM provider.

```python
import pytest
from unittest.mock import AsyncMock
from langrade.criteria import (
    RelevanceCriterion,
    ReadabilityCriterion,
    CoherenceCriterion,
)
from langrade.providers import LLMProvider


@pytest.fixture
def mock_llm_provider():
    provider = AsyncMock(spec=LLMProvider)
    provider.agenerate.return_value = "Binary score (yes/no): yes\nReasoning: The document is relevant to the question."
    return provider


@pytest.mark.asyncio
async def test_relevance_criterion(mock_llm_provider):
    criterion = RelevanceCriterion(mock_llm_provider)
    result = await criterion.evaluate("Test document", "Test question")
    assert result.binary_score == "yes"
    assert "relevant" in result.reasoning.lower()


@pytest.mark.asyncio
async def test_readability_criterion(mock_llm_provider):
    criterion = ReadabilityCriterion(mock_llm_provider)
    result = await criterion.evaluate("Test document")
    assert result.binary_score == "yes"
    assert len(result.reasoning) > 0


@pytest.mark.asyncio
async def test_coherence_criterion(mock_llm_provider):
    criterion = CoherenceCriterion(mock_llm_provider)
    result = await criterion.evaluate("Test document")
    assert result.binary_score == "yes"
    assert len(result.reasoning) > 0


@pytest.mark.asyncio
async def test_criterion_with_negative_response(mock_llm_provider):
    mock_llm_provider.agenerate.return_value = "Binary score (yes/no): no\nReasoning: The document is not relevant to the question."
    criterion = RelevanceCriterion(mock_llm_provider)
    result = await criterion.evaluate("Irrelevant document", "Test question")
    assert result.binary_score == "no"
    assert "not relevant" in result.reasoning.lower()
```
File 3 of 5: the grader tests, rewritten from synchronous unittest to async pytest with a patched provider factory.

```diff
@@ -1,18 +1,55 @@
-import unittest
+import pytest
+from unittest.mock import AsyncMock, patch
 from langrade import document_grader
-from conftest import get_api_key
+from langrade.criteria import (
+    RelevanceCriterion,
+    ReadabilityCriterion,
+    CoherenceCriterion,
+)
 
 
-class TestDocumentGrader(unittest.TestCase):
-    def setUp(self):
-        self.api_key = get_api_key()
-        self.provider = "openai"  # or an appropriate provider
-        self.grader = document_grader(self.provider, self.api_key)
+@pytest.fixture
+def mock_llm_provider():
+    with patch("langrade.providers.LLMProviderFactory.create_provider") as mock:
+        provider = AsyncMock()
+        provider.agenerate.return_value = "Binary score (yes/no): yes\nReasoning: The document is relevant and readable."
+        mock.return_value = provider
+        yield mock
 
-    def test_grade_document(self):
-        document = "AI is a field of computer science that focuses on creating intelligent machines."  # noqa: E501
-        question = "What is AI?"
 
-        result = self.grader.grade_document(document, question)
-        self.assertIsNotNone(result.binary_score)
-        self.assertIsNotNone(result.reasoning)
+@pytest.mark.asyncio
+async def test_grade_document_with_default_criteria(mock_llm_provider):
+    grader = document_grader("openai", "fake_api_key")
+    result = await grader.grade_document("Test document", "Test question")
+
+    assert "Relevance" in result
+    assert "Readability" in result
+    assert "Coherence" in result
+    assert result["Relevance"]["binary_score"] == "yes"
+    assert "relevant" in result["Relevance"]["reasoning"].lower()
+
+
+@pytest.mark.asyncio
+async def test_grade_document_with_custom_criteria(mock_llm_provider):
+    custom_criteria = [RelevanceCriterion(mock_llm_provider.return_value)]
+    grader = document_grader("openai", "fake_api_key", criteria=custom_criteria)
+    result = await grader.grade_document("Test document", "Test question")
+
+    assert "Relevance" in result
+    assert "Readability" not in result
+    assert "Coherence" not in result
+
+
+@pytest.mark.asyncio
+async def test_grade_document_with_all_criteria(mock_llm_provider):
+    all_criteria = [
+        RelevanceCriterion(mock_llm_provider.return_value),
+        ReadabilityCriterion(mock_llm_provider.return_value),
+        CoherenceCriterion(mock_llm_provider.return_value),
+    ]
+    grader = document_grader("openai", "fake_api_key", criteria=all_criteria)
+    result = await grader.grade_document("Test document", "Test question")
+
+    assert "Relevance" in result
+    assert "Readability" in result
+    assert "Coherence" in result
```
File 4 of 5: the provider tests, rewritten from unittest to async pytest. The live-API checks become a single test parametrized across providers (skipped when the matching key is absent), and offline per-provider mock tests are added.

```diff
@@ -1,75 +1,126 @@
 import os
-import unittest
+from langchain_anthropic import ChatAnthropic
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_groq import ChatGroq
+from langchain_openai import ChatOpenAI
+import pytest
+from unittest.mock import patch, AsyncMock
 from dotenv import load_dotenv
-from langrade.providers import LLMProviderFactory, GradeDocuments
+from langrade.providers import LLMProviderFactory, LLMProvider
+from langrade.constants import DEFAULT_MODEL
 
 # Load environment variables
 load_dotenv()
 
 
-class TestProviders(unittest.TestCase):
-    def setUp(self):
-        self.openai_api_key = os.getenv("OPENAI_API_KEY")
-        self.openai_model = os.getenv("OPENAI_MODEL")
-        self.google_api_key = os.getenv("GOOGLE_API_KEY")
-        self.gemini_model = os.getenv("GEMINI_MODEL")
-        self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
-        self.claude_model = os.getenv("CLAUDE_MODEL")
-        self.groq_api_key = os.getenv("GROQ_API_KEY")
-        self.groq_model = os.getenv("GROQ_MODEL")
-
-        self.test_document = "Artificial Intelligence (AI) is the simulation of human intelligence processes by machines, especially computer systems."  # noqa: E501
-        self.test_question = "What is AI?"
-
-    def test_openai_provider(self):
-        provider = LLMProviderFactory.create_provider(
-            "openai", self.openai_api_key, self.openai_model
-        )
-        result = provider.grade_document(
-            self.test_document, self.test_question
-        )  # noqa: E501
-        self.assertIsInstance(result, GradeDocuments)
-        self.assertIn(result.binary_score.lower(), ["yes", "no"])
-        self.assertIsNotNone(result.reasoning)
-
-    def test_groq_provider(self):
-        provider = LLMProviderFactory.create_provider(
-            "groq", self.groq_api_key, self.groq_model
-        )
-        result = provider.grade_document(self.test_document, self.test_question)
-        self.assertIsInstance(result, GradeDocuments)
-        self.assertIn(result.binary_score.lower(), ["yes", "no"])
-        self.assertIsNotNone(result.reasoning)
-
-    def test_google_provider(self):
-        provider = LLMProviderFactory.create_provider(
-            "google", self.google_api_key, self.gemini_model
-        )
-        result = provider.grade_document(
-            self.test_document, self.test_question
-        )  # noqa: E501
-        self.assertIsInstance(result, GradeDocuments)
-        self.assertIn(result.binary_score.lower(), ["yes", "no"])
-        self.assertIsNotNone(result.reasoning)
-
-    def test_anthropic_provider(self):
-        provider = LLMProviderFactory.create_provider(
-            "anthropic", self.anthropic_api_key, self.claude_model
-        )
-        result = provider.grade_document(
-            self.test_document, self.test_question
-        )  # noqa: E501
-        self.assertIsInstance(result, GradeDocuments)
-        self.assertIn(result.binary_score.lower(), ["yes", "no"])
-        self.assertIsNotNone(result.reasoning)
-
-    def test_invalid_provider(self):
-        with self.assertRaises(ValueError):
-            LLMProviderFactory.create_provider(
-                "invalid_provider", "dummy_key", "dummy_model"
-            )
-
-
-if __name__ == "__main__":
-    unittest.main()
+@pytest.fixture
+def api_keys():
+    return {
+        "openai": os.getenv("OPENAI_API_KEY"),
+        "anthropic": os.getenv("ANTHROPIC_API_KEY"),
+        "google": os.getenv("GOOGLE_API_KEY"),
+        "groq": os.getenv("GROQ_API_KEY"),
+    }
+
+
+@pytest.fixture
+def models():
+    return {
+        "openai": os.getenv("OPENAI_MODEL", DEFAULT_MODEL),
+        "anthropic": os.getenv("CLAUDE_MODEL", "claude-3-sonnet-20240229"),
+        "google": os.getenv("GEMINI_MODEL", "gemini-1.5-pro-latest"),
+        "groq": os.getenv("GROQ_MODEL", "mixtral-8x7b-32768"),
+    }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("provider", ["openai", "anthropic", "google", "groq"])
+async def test_provider_creation(api_keys, models, provider):
+    api_key = api_keys[provider]
+    model = models[provider]
+
+    if not api_key:
+        pytest.skip(f"API key for {provider} not found in environment variables")
+
+    provider_instance = LLMProviderFactory.create_provider(provider, api_key, model)
+    assert isinstance(provider_instance, LLMProvider)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("provider", ["openai", "anthropic", "google", "groq"])
+async def test_provider_generation(api_keys, models, provider):
+    api_key = api_keys[provider]
+    model = models[provider]
+
+    if not api_key:
+        pytest.skip(f"API key for {provider} not found in environment variables")
+
+    provider_instance = LLMProviderFactory.create_provider(provider, api_key, model)
+
+    test_prompt = "What is the capital of France?"
+    response = await provider_instance.agenerate(test_prompt)
+
+    assert isinstance(response, str)
+    assert len(response) > 0
+    assert "Paris" in response
+
+
+@pytest.mark.asyncio
+async def test_invalid_provider():
+    with pytest.raises(ValueError):
+        LLMProviderFactory.create_provider(
+            "invalid_provider", "dummy_key", "dummy_model"
+        )
+
+
+@pytest.mark.asyncio
+@patch("langrade.providers.ChatOpenAI")
+async def test_openai_provider_mock(mock_chat):
+    mock_llm = AsyncMock(spec=ChatOpenAI)
+    mock_llm.ainvoke.return_value = AsyncMock(content="Mocked response")
+    mock_chat.return_value = mock_llm
+
+    provider = LLMProviderFactory.create_provider("openai", "fake_api_key")
+    result = await provider.agenerate("Test prompt")
+
+    assert result == "Mocked response"
+
+
+@pytest.mark.asyncio
+@patch("langrade.providers.ChatAnthropic")
+async def test_anthropic_provider_mock(mock_chat):
+    mock_llm = AsyncMock(spec=ChatAnthropic)
+    mock_llm.ainvoke.return_value = AsyncMock(content="Mocked response")
+    mock_chat.return_value = mock_llm
+
+    provider = LLMProviderFactory.create_provider("anthropic", "fake_api_key")
+    result = await provider.agenerate("Test prompt")
+
+    assert result == "Mocked response"
+
+
+@pytest.mark.asyncio
+@patch("langrade.providers.ChatGoogleGenerativeAI")
+async def test_google_provider_mock(mock_chat):
+    mock_llm = AsyncMock(spec=ChatGoogleGenerativeAI)
+    mock_llm.ainvoke.return_value = AsyncMock(content="Mocked response")
+    mock_chat.return_value = mock_llm
+
+    provider = LLMProviderFactory.create_provider("google", "fake_api_key")
+    result = await provider.agenerate("Test prompt")
+
+    assert result == "Mocked response"
+
+
+@pytest.mark.asyncio
+@patch("langrade.providers.ChatGroq")
+async def test_groq_provider_mock(mock_chat):
+    mock_llm = AsyncMock(spec=ChatGroq)
+    mock_llm.ainvoke.return_value = AsyncMock(content="Mocked response")
+    mock_chat.return_value = mock_llm
+
+    provider = LLMProviderFactory.create_provider("groq", "fake_api_key")
+    result = await provider.agenerate("Test prompt")
+
+    assert result == "Mocked response"
+    mock_chat.assert_called_once()
```
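The mock tests also constrain the provider layer itself: `create_provider` must instantiate the chat class referenced at module level in `langrade.providers` (so it can be patched there), `agenerate` must return the `.content` of the awaited `ainvoke` result, and an unknown provider name must raise `ValueError`. A hypothetical factory consistent with those constraints is sketched below; the real module is not shown in this commit, and the constructor keyword names are assumptions.

```python
# Hypothetical sketch of langrade/providers.py, inferred from the tests
# above; not the actual implementation. The api_key/model constructor
# kwargs are assumed aliases on the langchain chat integrations.
from typing import Optional

from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI


class LLMProvider:
    def __init__(self, llm):
        self.llm = llm

    async def agenerate(self, prompt: str) -> str:
        # The mock tests assert that agenerate returns the .content of
        # the awaited ainvoke result.
        response = await self.llm.ainvoke(prompt)
        return response.content


class LLMProviderFactory:
    @staticmethod
    def create_provider(
        provider: str, api_key: str, model: Optional[str] = None
    ) -> LLMProvider:
        # Resolve the chat class at call time so that
        # @patch("langrade.providers.ChatOpenAI") and friends take effect.
        chat_classes = {
            "openai": ChatOpenAI,
            "anthropic": ChatAnthropic,
            "google": ChatGoogleGenerativeAI,
            "groq": ChatGroq,
        }
        chat_cls = chat_classes.get(provider)
        if chat_cls is None:
            # test_invalid_provider expects ValueError for unknown names.
            raise ValueError(f"Unknown provider: {provider}")
        kwargs = {"api_key": api_key}
        if model is not None:
            kwargs["model"] = model
        return LLMProvider(chat_cls(**kwargs))
```

Resolving the chat class inside the call is what lets the patches work: a module-level mapping would capture the original classes before `unittest.mock.patch` replaces the module attributes.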