add: test
nisaji committed Oct 18, 2024
1 parent f58b80b commit f822f18
Showing 5 changed files with 291 additions and 123 deletions.
25 changes: 25 additions & 0 deletions tests/test_constants.py
@@ -0,0 +1,25 @@
+from langrade import constants
+
+
+def test_prompt_structure():
+    prompts = [
+        constants.RELEVANCE_PROMPT,
+        constants.READABILITY_PROMPT,
+        constants.COHERENCE_PROMPT,
+    ]
+
+    for prompt in prompts:
+        assert "{document}" in prompt
+        assert "Binary score (yes/no):" in prompt
+        assert "Reasoning:" in prompt
+
+
+def test_system_prompt():
+    assert "You are a grader" in constants.SYSTEM_PROMPT
+    assert "binary score" in constants.SYSTEM_PROMPT.lower()
+
+
+def test_descriptions():
+    assert len(constants.REASONING_DESCRIPTION) > 0
+    assert "yes" in constants.BINARY_SCORE_DESCRIPTION.lower()
+    assert "no" in constants.BINARY_SCORE_DESCRIPTION.lower()
48 changes: 48 additions & 0 deletions tests/test_criteria.py
@@ -0,0 +1,48 @@
+import pytest
+from unittest.mock import AsyncMock
+from langrade.criteria import (
+    RelevanceCriterion,
+    ReadabilityCriterion,
+    CoherenceCriterion,
+)
+from langrade.providers import LLMProvider
+
+
+@pytest.fixture
+def mock_llm_provider():
+    provider = AsyncMock(spec=LLMProvider)
+    provider.agenerate.return_value = "Binary score (yes/no): yes\nReasoning: The document is relevant to the question."
+    return provider
+
+
+@pytest.mark.asyncio
+async def test_relevance_criterion(mock_llm_provider):
+    criterion = RelevanceCriterion(mock_llm_provider)
+    result = await criterion.evaluate("Test document", "Test question")
+    assert result.binary_score == "yes"
+    assert "relevant" in result.reasoning.lower()
+
+
+@pytest.mark.asyncio
+async def test_readability_criterion(mock_llm_provider):
+    criterion = ReadabilityCriterion(mock_llm_provider)
+    result = await criterion.evaluate("Test document")
+    assert result.binary_score == "yes"
+    assert len(result.reasoning) > 0
+
+
+@pytest.mark.asyncio
+async def test_coherence_criterion(mock_llm_provider):
+    criterion = CoherenceCriterion(mock_llm_provider)
+    result = await criterion.evaluate("Test document")
+    assert result.binary_score == "yes"
+    assert len(result.reasoning) > 0
+
+
+@pytest.mark.asyncio
+async def test_criterion_with_negative_response(mock_llm_provider):
+    mock_llm_provider.agenerate.return_value = "Binary score (yes/no): no\nReasoning: The document is not relevant to the question."
+    criterion = RelevanceCriterion(mock_llm_provider)
+    result = await criterion.evaluate("Irrelevant document", "Test question")
+    assert result.binary_score == "no"
+    assert "not relevant" in result.reasoning.lower()
63 changes: 50 additions & 13 deletions tests/test_grader.py
@@ -1,18 +1,55 @@
-import unittest
+import pytest
+from unittest.mock import AsyncMock, patch
 from langrade import document_grader
-from conftest import get_api_key
+from langrade.criteria import (
+    RelevanceCriterion,
+    ReadabilityCriterion,
+    CoherenceCriterion,
+)
 
 
-class TestDocumentGrader(unittest.TestCase):
-    def setUp(self):
-        self.api_key = get_api_key()
-        self.provider = "openai"  # or an appropriate provider
-        self.grader = document_grader(self.provider, self.api_key)
-
-    def test_grade_document(self):
-        document = "AI is a field of computer science that focuses on creating intelligent machines."  # noqa: E501
-        question = "What is AI?"
-
-        result = self.grader.grade_document(document, question)
-        self.assertIsNotNone(result.binary_score)
-        self.assertIsNotNone(result.reasoning)
+@pytest.fixture
+def mock_llm_provider():
+    with patch("langrade.providers.LLMProviderFactory.create_provider") as mock:
+        provider = AsyncMock()
+        provider.agenerate.return_value = "Binary score (yes/no): yes\nReasoning: The document is relevant and readable."
+        mock.return_value = provider
+        yield mock
+
+
+@pytest.mark.asyncio
+async def test_grade_document_with_default_criteria(mock_llm_provider):
+    grader = document_grader("openai", "fake_api_key")
+    result = await grader.grade_document("Test document", "Test question")
+
+    assert "Relevance" in result
+    assert "Readability" in result
+    assert "Coherence" in result
+    assert result["Relevance"]["binary_score"] == "yes"
+    assert "relevant" in result["Relevance"]["reasoning"].lower()
+
+
+@pytest.mark.asyncio
+async def test_grade_document_with_custom_criteria(mock_llm_provider):
+    custom_criteria = [RelevanceCriterion(mock_llm_provider.return_value)]
+    grader = document_grader("openai", "fake_api_key", criteria=custom_criteria)
+    result = await grader.grade_document("Test document", "Test question")
+
+    assert "Relevance" in result
+    assert "Readability" not in result
+    assert "Coherence" not in result
+
+
+@pytest.mark.asyncio
+async def test_grade_document_with_all_criteria(mock_llm_provider):
+    all_criteria = [
+        RelevanceCriterion(mock_llm_provider.return_value),
+        ReadabilityCriterion(mock_llm_provider.return_value),
+        CoherenceCriterion(mock_llm_provider.return_value),
+    ]
+    grader = document_grader("openai", "fake_api_key", criteria=all_criteria)
+    result = await grader.grade_document("Test document", "Test question")
+
+    assert "Relevance" in result
+    assert "Readability" in result
+    assert "Coherence" in result
179 changes: 115 additions & 64 deletions tests/test_providers.py
@@ -1,75 +1,126 @@
 import os
-import unittest
 from langchain_anthropic import ChatAnthropic
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_groq import ChatGroq
 from langchain_openai import ChatOpenAI
+import pytest
+from unittest.mock import patch, AsyncMock
 from dotenv import load_dotenv
-from langrade.providers import LLMProviderFactory, GradeDocuments
+from langrade.providers import LLMProviderFactory, LLMProvider
+from langrade.constants import DEFAULT_MODEL
 
 # Load environment variables
 load_dotenv()
 
 
-class TestProviders(unittest.TestCase):
-    def setUp(self):
-        self.openai_api_key = os.getenv("OPENAI_API_KEY")
-        self.openai_model = os.getenv("OPENAI_MODEL")
-        self.google_api_key = os.getenv("GOOGLE_API_KEY")
-        self.gemini_model = os.getenv("GEMINI_MODEL")
-        self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
-        self.claude_model = os.getenv("CLAUDE_MODEL")
-        self.groq_api_key = os.getenv("GROQ_API_KEY")
-        self.groq_model = os.getenv("GROQ_MODEL")
-
-        self.test_document = "Artificial Intelligence (AI) is the simulation of human intelligence processes by machines, especially computer systems."  # noqa: E501
-        self.test_question = "What is AI?"
-
-    def test_openai_provider(self):
-        provider = LLMProviderFactory.create_provider(
-            "openai", self.openai_api_key, self.openai_model
-        )
-        result = provider.grade_document(
-            self.test_document, self.test_question
-        )  # noqa: E501
-        self.assertIsInstance(result, GradeDocuments)
-        self.assertIn(result.binary_score.lower(), ["yes", "no"])
-        self.assertIsNotNone(result.reasoning)
-
-    def test_groq_provider(self):
-        provider = LLMProviderFactory.create_provider(
-            "groq", self.groq_api_key, self.groq_model
-        )
-        result = provider.grade_document(self.test_document, self.test_question)
-        self.assertIsInstance(result, GradeDocuments)
-        self.assertIn(result.binary_score.lower(), ["yes", "no"])
-        self.assertIsNotNone(result.reasoning)
-
-    def test_google_provider(self):
-        provider = LLMProviderFactory.create_provider(
-            "google", self.google_api_key, self.gemini_model
-        )
-        result = provider.grade_document(
-            self.test_document, self.test_question
-        )  # noqa: E501
-        self.assertIsInstance(result, GradeDocuments)
-        self.assertIn(result.binary_score.lower(), ["yes", "no"])
-        self.assertIsNotNone(result.reasoning)
-
-    def test_anthropic_provider(self):
-        provider = LLMProviderFactory.create_provider(
-            "anthropic", self.anthropic_api_key, self.claude_model
-        )
-        result = provider.grade_document(
-            self.test_document, self.test_question
-        )  # noqa: E501
-        self.assertIsInstance(result, GradeDocuments)
-        self.assertIn(result.binary_score.lower(), ["yes", "no"])
-        self.assertIsNotNone(result.reasoning)
-
-    def test_invalid_provider(self):
-        with self.assertRaises(ValueError):
-            LLMProviderFactory.create_provider(
-                "invalid_provider", "dummy_key", "dummy_model"
-            )
-
-
-if __name__ == "__main__":
-    unittest.main()
+@pytest.fixture
+def api_keys():
+    return {
+        "openai": os.getenv("OPENAI_API_KEY"),
+        "anthropic": os.getenv("ANTHROPIC_API_KEY"),
+        "google": os.getenv("GOOGLE_API_KEY"),
+        "groq": os.getenv("GROQ_API_KEY"),
+    }
+
+
+@pytest.fixture
+def models():
+    return {
+        "openai": os.getenv("OPENAI_MODEL", DEFAULT_MODEL),
+        "anthropic": os.getenv("CLAUDE_MODEL", "claude-3-sonnet-20240229"),
+        "google": os.getenv("GEMINI_MODEL", "gemini-1.5-pro-latest"),
+        "groq": os.getenv("GROQ_MODEL", "mixtral-8x7b-32768"),
+    }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("provider", ["openai", "anthropic", "google", "groq"])
+async def test_provider_creation(api_keys, models, provider):
+    api_key = api_keys[provider]
+    model = models[provider]
+
+    if not api_key:
+        pytest.skip(f"API key for {provider} not found in environment variables")
+
+    provider_instance = LLMProviderFactory.create_provider(provider, api_key, model)
+    assert isinstance(provider_instance, LLMProvider)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("provider", ["openai", "anthropic", "google", "groq"])
+async def test_provider_generation(api_keys, models, provider):
+    api_key = api_keys[provider]
+    model = models[provider]
+
+    if not api_key:
+        pytest.skip(f"API key for {provider} not found in environment variables")
+
+    provider_instance = LLMProviderFactory.create_provider(provider, api_key, model)
+
+    test_prompt = "What is the capital of France?"
+    response = await provider_instance.agenerate(test_prompt)
+
+    assert isinstance(response, str)
+    assert len(response) > 0
+    assert "Paris" in response
+
+
+@pytest.mark.asyncio
+async def test_invalid_provider():
+    with pytest.raises(ValueError):
+        LLMProviderFactory.create_provider(
+            "invalid_provider", "dummy_key", "dummy_model"
+        )
+
+
+@pytest.mark.asyncio
+@patch("langrade.providers.ChatOpenAI")
+async def test_openai_provider_mock(mock_chat):
+    mock_llm = AsyncMock(spec=ChatOpenAI)
+    mock_llm.ainvoke.return_value = AsyncMock(content="Mocked response")
+    mock_chat.return_value = mock_llm
+
+    provider = LLMProviderFactory.create_provider("openai", "fake_api_key")
+    result = await provider.agenerate("Test prompt")
+
+    assert result == "Mocked response"
+
+
+@pytest.mark.asyncio
+@patch("langrade.providers.ChatAnthropic")
+async def test_anthropic_provider_mock(mock_chat):
+    mock_llm = AsyncMock(spec=ChatAnthropic)
+    mock_llm.ainvoke.return_value = AsyncMock(content="Mocked response")
+    mock_chat.return_value = mock_llm
+
+    provider = LLMProviderFactory.create_provider("anthropic", "fake_api_key")
+    result = await provider.agenerate("Test prompt")
+
+    assert result == "Mocked response"
+
+
+@pytest.mark.asyncio
+@patch("langrade.providers.ChatGoogleGenerativeAI")
+async def test_google_provider_mock(mock_chat):
+    mock_llm = AsyncMock(spec=ChatGoogleGenerativeAI)
+    mock_llm.ainvoke.return_value = AsyncMock(content="Mocked response")
+    mock_chat.return_value = mock_llm
+
+    provider = LLMProviderFactory.create_provider("google", "fake_api_key")
+    result = await provider.agenerate("Test prompt")
+
+    assert result == "Mocked response"
+
+
+@pytest.mark.asyncio
+@patch("langrade.providers.ChatGroq")
+async def test_groq_provider_mock(mock_chat):
+    mock_llm = AsyncMock(spec=ChatGroq)
+    mock_llm.ainvoke.return_value = AsyncMock(content="Mocked response")
+    mock_chat.return_value = mock_llm
+
+    provider = LLMProviderFactory.create_provider("groq", "fake_api_key")
+    result = await provider.agenerate("Test prompt")
+
+    assert result == "Mocked response"
+    mock_chat.assert_called_once()
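
The four mock tests above imply the provider wrapper's shape: each provider constructs its chat model inside `langrade.providers` (which is why the tests patch, e.g., `langrade.providers.ChatOpenAI`), and `agenerate` awaits the model's `ainvoke` and returns the reply's `content`. A minimal sketch under those assumptions; the class name and constructor arguments are illustrative, not the library's actual code:

```python
from langchain_openai import ChatOpenAI


class OpenAIProvider:
    """Hypothetical provider wrapper inferred from the mock tests above."""

    def __init__(self, api_key: str, model: str):
        # Patching langrade.providers.ChatOpenAI works because the model
        # class is resolved in this module's namespace at call time.
        self.llm = ChatOpenAI(api_key=api_key, model=model)

    async def agenerate(self, prompt: str) -> str:
        message = await self.llm.ainvoke(prompt)
        return message.content
```

The `mock_chat.assert_called_once()` check in the Groq test additionally confirms that the factory instantiates the chat model exactly once per provider.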