diff --git a/tests/test_constants.py b/tests/test_constants.py
new file mode 100644
index 0000000..e91ae19
--- /dev/null
+++ b/tests/test_constants.py
@@ -0,0 +1,25 @@
+from langrade import constants
+
+
+def test_prompt_structure():
+    prompts = [
+        constants.RELEVANCE_PROMPT,
+        constants.READABILITY_PROMPT,
+        constants.COHERENCE_PROMPT,
+    ]
+
+    for prompt in prompts:
+        assert "{document}" in prompt
+        assert "Binary score (yes/no):" in prompt
+        assert "Reasoning:" in prompt
+
+
+def test_system_prompt():
+    assert "You are a grader" in constants.SYSTEM_PROMPT
+    assert "binary score" in constants.SYSTEM_PROMPT.lower()
+
+
+def test_descriptions():
+    assert len(constants.REASONING_DESCRIPTION) > 0
+    assert "yes" in constants.BINARY_SCORE_DESCRIPTION.lower()
+    assert "no" in constants.BINARY_SCORE_DESCRIPTION.lower()
diff --git a/tests/test_criteria.py b/tests/test_criteria.py
new file mode 100644
index 0000000..a0e804f
--- /dev/null
+++ b/tests/test_criteria.py
@@ -0,0 +1,48 @@
+import pytest
+from unittest.mock import AsyncMock
+from langrade.criteria import (
+    RelevanceCriterion,
+    ReadabilityCriterion,
+    CoherenceCriterion,
+)
+from langrade.providers import LLMProvider
+
+
+@pytest.fixture
+def mock_llm_provider():
+    provider = AsyncMock(spec=LLMProvider)
+    provider.agenerate.return_value = "Binary score (yes/no): yes\nReasoning: The document is relevant to the question."
+    return provider
+
+
+@pytest.mark.asyncio
+async def test_relevance_criterion(mock_llm_provider):
+    criterion = RelevanceCriterion(mock_llm_provider)
+    result = await criterion.evaluate("Test document", "Test question")
+    assert result.binary_score == "yes"
+    assert "relevant" in result.reasoning.lower()
+
+
+@pytest.mark.asyncio
+async def test_readability_criterion(mock_llm_provider):
+    criterion = ReadabilityCriterion(mock_llm_provider)
+    result = await criterion.evaluate("Test document")
+    assert result.binary_score == "yes"
+    assert len(result.reasoning) > 0
+
+
+@pytest.mark.asyncio
+async def test_coherence_criterion(mock_llm_provider):
+    criterion = CoherenceCriterion(mock_llm_provider)
+    result = await criterion.evaluate("Test document")
+    assert result.binary_score == "yes"
+    assert len(result.reasoning) > 0
+
+
+@pytest.mark.asyncio
+async def test_criterion_with_negative_response(mock_llm_provider):
+    mock_llm_provider.agenerate.return_value = "Binary score (yes/no): no\nReasoning: The document is not relevant to the question."
+    criterion = RelevanceCriterion(mock_llm_provider)
+    result = await criterion.evaluate("Irrelevant document", "Test question")
+    assert result.binary_score == "no"
+    assert "not relevant" in result.reasoning.lower()
diff --git a/tests/test_grader.py b/tests/test_grader.py
index aa77cdd..b232140 100644
--- a/tests/test_grader.py
+++ b/tests/test_grader.py
@@ -1,18 +1,55 @@
-import unittest
+import pytest
+from unittest.mock import AsyncMock, patch
 from langrade import document_grader
-from conftest import get_api_key
+from langrade.criteria import (
+    RelevanceCriterion,
+    ReadabilityCriterion,
+    CoherenceCriterion,
+)
 
 
-class TestDocumentGrader(unittest.TestCase):
-    def setUp(self):
-        self.api_key = get_api_key()
-        self.provider = "openai"  # or an appropriate provider
-        self.grader = document_grader(self.provider, self.api_key)
+@pytest.fixture
+def mock_llm_provider():
+    with patch("langrade.providers.LLMProviderFactory.create_provider") as mock:
+        provider = AsyncMock()
+        provider.agenerate.return_value = "Binary score (yes/no): yes\nReasoning: The document is relevant and readable."
+        mock.return_value = provider
+        yield mock
 
 
-    def test_grade_document(self):
-        document = "AI is a field of computer science that focuses on creating intelligent machines."  # noqa: E501
-        question = "What is AI?"
-        result = self.grader.grade_document(document, question)
-        self.assertIsNotNone(result.binary_score)
-        self.assertIsNotNone(result.reasoning)
+@pytest.mark.asyncio
+async def test_grade_document_with_default_criteria(mock_llm_provider):
+    grader = document_grader("openai", "fake_api_key")
+    result = await grader.grade_document("Test document", "Test question")
+
+    assert "Relevance" in result
+    assert "Readability" in result
+    assert "Coherence" in result
+    assert result["Relevance"]["binary_score"] == "yes"
+    assert "relevant" in result["Relevance"]["reasoning"].lower()
+
+
+@pytest.mark.asyncio
+async def test_grade_document_with_custom_criteria(mock_llm_provider):
+    custom_criteria = [RelevanceCriterion(mock_llm_provider.return_value)]
+    grader = document_grader("openai", "fake_api_key", criteria=custom_criteria)
+    result = await grader.grade_document("Test document", "Test question")
+
+    assert "Relevance" in result
+    assert "Readability" not in result
+    assert "Coherence" not in result
+
+
+@pytest.mark.asyncio
+async def test_grade_document_with_all_criteria(mock_llm_provider):
+    all_criteria = [
+        RelevanceCriterion(mock_llm_provider.return_value),
+        ReadabilityCriterion(mock_llm_provider.return_value),
+        CoherenceCriterion(mock_llm_provider.return_value),
+    ]
+    grader = document_grader("openai", "fake_api_key", criteria=all_criteria)
+    result = await grader.grade_document("Test document", "Test question")
+
+    assert "Relevance" in result
+    assert "Readability" in result
+    assert "Coherence" in result
diff --git a/tests/test_providers.py b/tests/test_providers.py
index 79559da..d2a2b99 100644
--- a/tests/test_providers.py
+++ b/tests/test_providers.py
@@ -1,75 +1,126 @@
 import os
-import unittest
+from langchain_anthropic import ChatAnthropic
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_groq import ChatGroq
+from langchain_openai import ChatOpenAI
+import pytest
+from unittest.mock import patch, AsyncMock
 from dotenv import load_dotenv
-from langrade.providers import LLMProviderFactory, GradeDocuments
+from langrade.providers import LLMProviderFactory, LLMProvider
+from langrade.constants import DEFAULT_MODEL
 
 # Load environment variables
 load_dotenv()
 
 
-class TestProviders(unittest.TestCase):
+@pytest.fixture
+def api_keys():
+    return {
+        "openai": os.getenv("OPENAI_API_KEY"),
+        "anthropic": os.getenv("ANTHROPIC_API_KEY"),
+        "google": os.getenv("GOOGLE_API_KEY"),
+        "groq": os.getenv("GROQ_API_KEY"),
+    }
 
-    def setUp(self):
-        self.openai_api_key = os.getenv("OPENAI_API_KEY")
-        self.openai_model = os.getenv("OPENAI_MODEL")
-        self.google_api_key = os.getenv("GOOGLE_API_KEY")
-        self.gemini_model = os.getenv("GEMINI_MODEL")
-        self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
-        self.claude_model = os.getenv("CLAUDE_MODEL")
-        self.groq_api_key = os.getenv("GROQ_API_KEY")
-        self.groq_model = os.getenv("GROQ_MODEL")
-        self.test_document = "Artificial Intelligence (AI) is the simulation of human intelligence processes by machines, especially computer systems."  # noqa: E501
-        self.test_question = "What is AI?"
+@pytest.fixture
+def models():
+    return {
+        "openai": os.getenv("OPENAI_MODEL", DEFAULT_MODEL),
+        "anthropic": os.getenv("CLAUDE_MODEL", "claude-3-sonnet-20240229"),
+        "google": os.getenv("GEMINI_MODEL", "gemini-1.5-pro-latest"),
+        "groq": os.getenv("GROQ_MODEL", "mixtral-8x7b-32768"),
+    }
 
-    def test_openai_provider(self):
-        provider = LLMProviderFactory.create_provider(
-            "openai", self.openai_api_key, self.openai_model
-        )
-        result = provider.grade_document(
-            self.test_document, self.test_question
-        )  # noqa: E501
-        self.assertIsInstance(result, GradeDocuments)
-        self.assertIn(result.binary_score.lower(), ["yes", "no"])
-        self.assertIsNotNone(result.reasoning)
-
-    def test_groq_provider(self):
-        provider = LLMProviderFactory.create_provider(
-            "groq", self.groq_api_key, self.groq_model
-        )
-        result = provider.grade_document(self.test_document, self.test_question)
-        self.assertIsInstance(result, GradeDocuments)
-        self.assertIn(result.binary_score.lower(), ["yes", "no"])
-        self.assertIsNotNone(result.reasoning)
-
-    def test_google_provider(self):
-        provider = LLMProviderFactory.create_provider(
-            "google", self.google_api_key, self.gemini_model
-        )
-        result = provider.grade_document(
-            self.test_document, self.test_question
-        )  # noqa: E501
-        self.assertIsInstance(result, GradeDocuments)
-        self.assertIn(result.binary_score.lower(), ["yes", "no"])
-        self.assertIsNotNone(result.reasoning)
-
-    def test_anthropic_provider(self):
-        provider = LLMProviderFactory.create_provider(
-            "anthropic", self.anthropic_api_key, self.claude_model
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("provider", ["openai", "anthropic", "google", "groq"])
+async def test_provider_creation(api_keys, models, provider):
+    api_key = api_keys[provider]
+    model = models[provider]
+
+    if not api_key:
+        pytest.skip(f"API key for {provider} not found in environment variables")
+
+    provider_instance = LLMProviderFactory.create_provider(provider, api_key, model)
+    assert isinstance(provider_instance, LLMProvider)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("provider", ["openai", "anthropic", "google", "groq"])
+async def test_provider_generation(api_keys, models, provider):
+    api_key = api_keys[provider]
+    model = models[provider]
+
+    if not api_key:
+        pytest.skip(f"API key for {provider} not found in environment variables")
+
+    provider_instance = LLMProviderFactory.create_provider(provider, api_key, model)
+
+    test_prompt = "What is the capital of France?"
+    response = await provider_instance.agenerate(test_prompt)
+
+    assert isinstance(response, str)
+    assert len(response) > 0
+    assert "Paris" in response
+
+
+@pytest.mark.asyncio
+async def test_invalid_provider():
+    with pytest.raises(ValueError):
+        LLMProviderFactory.create_provider(
+            "invalid_provider", "dummy_key", "dummy_model"
         )
-        result = provider.grade_document(
-            self.test_document, self.test_question
-        )  # noqa: E501
-        self.assertIsInstance(result, GradeDocuments)
-        self.assertIn(result.binary_score.lower(), ["yes", "no"])
-        self.assertIsNotNone(result.reasoning)
-
-    def test_invalid_provider(self):
-        with self.assertRaises(ValueError):
-            LLMProviderFactory.create_provider(
-                "invalid_provider", "dummy_key", "dummy_model"
-            )
-
-
-if __name__ == "__main__":
-    unittest.main()
+
+
+@pytest.mark.asyncio
+@patch("langrade.providers.ChatOpenAI")
+async def test_openai_provider_mock(mock_chat):
+    mock_llm = AsyncMock(spec=ChatOpenAI)
+    mock_llm.ainvoke.return_value = AsyncMock(content="Mocked response")
+    mock_chat.return_value = mock_llm
+
+    provider = LLMProviderFactory.create_provider("openai", "fake_api_key")
+    result = await provider.agenerate("Test prompt")
+
+    assert result == "Mocked response"
+
+
+@pytest.mark.asyncio
+@patch("langrade.providers.ChatAnthropic")
+async def test_anthropic_provider_mock(mock_chat):
+    mock_llm = AsyncMock(spec=ChatAnthropic)
+    mock_llm.ainvoke.return_value = AsyncMock(content="Mocked response")
+    mock_chat.return_value = mock_llm
+
+    provider = LLMProviderFactory.create_provider("anthropic", "fake_api_key")
+    result = await provider.agenerate("Test prompt")
+
+    assert result == "Mocked response"
+
+
+@pytest.mark.asyncio
+@patch("langrade.providers.ChatGoogleGenerativeAI")
+async def test_google_provider_mock(mock_chat):
+    mock_llm = AsyncMock(spec=ChatGoogleGenerativeAI)
+    mock_llm.ainvoke.return_value = AsyncMock(content="Mocked response")
+    mock_chat.return_value = mock_llm
+
+    provider = LLMProviderFactory.create_provider("google", "fake_api_key")
+    result = await provider.agenerate("Test prompt")
+
+    assert result == "Mocked response"
+
+
+@pytest.mark.asyncio
+@patch("langrade.providers.ChatGroq")
+async def test_groq_provider_mock(mock_chat):
+    mock_llm = AsyncMock(spec=ChatGroq)
+    mock_llm.ainvoke.return_value = AsyncMock(content="Mocked response")
+    mock_chat.return_value = mock_llm
+
+    provider = LLMProviderFactory.create_provider("groq", "fake_api_key")
+    result = await provider.agenerate("Test prompt")
+
+    assert result == "Mocked response"
+    mock_chat.assert_called_once()
diff --git a/tests/test_retriever.py b/tests/test_retriever.py
index e0ab37e..10cd38a 100644
--- a/tests/test_retriever.py
+++ b/tests/test_retriever.py
@@ -1,49 +1,56 @@
-import unittest
-from unittest.mock import patch, MagicMock
-from langrade import create_retriever
-from conftest import get_api_key
+import pytest
+from unittest.mock import patch, AsyncMock
+from langrade.retriever import create_retriever
 from langchain.schema import Document
 
 
-class TestRetriever(unittest.TestCase):
-    def setUp(self):
-        self.api_key = get_api_key()
-        self.urls = [
-            "https://lilianweng.github.io/posts/2023-06-23-agent/",
-            "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
-            "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
-        ]
-
-    @patch("langrade.retriever.WebBaseLoader")
-    @patch("langrade.retriever.Chroma")
-    def test_create_retriever(self, mock_chroma, mock_web_loader):
-        mock_doc = Document(page_content="This is a test document", metadata={})
-        mock_web_loader.return_value.load.return_value = [mock_doc]
-        mock_chroma.from_documents.return_value.as_retriever.return_value = MagicMock()
-
-        retriever = create_retriever(self.urls, self.api_key)
-
-        self.assertIsNotNone(retriever)
-        mock_web_loader.assert_called()
-        mock_chroma.from_documents.assert_called()
-
-    @patch("langrade.retriever.WebBaseLoader")
-    @patch("langrade.retriever.Chroma")
-    def test_retriever_get_relevant_documents(self, mock_chroma, mock_web_loader):
-        mock_doc = Document(page_content="This is a test document", metadata={})
-        mock_web_loader.return_value.load.return_value = [mock_doc]
-        mock_retriever = MagicMock()
-        mock_chroma.from_documents.return_value.as_retriever.return_value = (
-            mock_retriever
-        )
-
-        retriever = create_retriever(self.urls, self.api_key)
-        question = "What is AI?"
-        mock_retriever.get_relevant_documents.return_value = [
-            Document(page_content="AI is a field of computer science.", metadata={})
-        ]
-
-        docs = retriever.get_relevant_documents(question)
-
-        self.assertEqual(len(docs), 1)
-        self.assertIn("AI", docs[0].page_content)
+@pytest.fixture
+def api_key(monkeypatch):
+    monkeypatch.setenv("OPENAI_API_KEY", "fake_api_key")
+    return "fake_api_key"
+
+
+@pytest.fixture
+def urls():
+    return [
+        "https://lilianweng.github.io/posts/2023-06-23-agent/",
+        "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
+        "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
+    ]
+
+
+@patch("langrade.retriever.WebBaseLoader")
+@patch("langrade.retriever.Chroma")
+def test_create_retriever(mock_chroma, mock_web_loader, api_key, urls):
+    mock_doc = Document(page_content="This is a test document", metadata={})
+    mock_web_loader.return_value.load.return_value = [mock_doc]
+    mock_chroma.from_documents.return_value.as_retriever.return_value = AsyncMock()
+
+    retriever = create_retriever(urls, api_key)
+
+    assert retriever is not None
+    mock_web_loader.assert_called()
+    mock_chroma.from_documents.assert_called()
+
+
+@pytest.mark.asyncio
+@patch("langrade.retriever.WebBaseLoader")
+@patch("langrade.retriever.Chroma")
+async def test_retriever_get_relevant_documents(
+    mock_chroma, mock_web_loader, api_key, urls
+):
+    mock_doc = Document(page_content="This is a test document", metadata={})
+    mock_web_loader.return_value.load.return_value = [mock_doc]
+    mock_retriever = AsyncMock()
+    mock_chroma.from_documents.return_value.as_retriever.return_value = mock_retriever
+
+    retriever = create_retriever(urls, api_key)
+    question = "What is AI?"
+    mock_retriever.get_relevant_documents.return_value = [
+        Document(page_content="AI is a field of computer science.", metadata={})
+    ]
+
+    docs = await retriever.get_relevant_documents(question)
+
+    assert len(docs) == 1
+    assert "AI" in docs[0].page_content
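
For reviewers, a minimal sketch of the async grading flow these tests exercise. It assumes the public API shown in tests/test_grader.py: document_grader(provider, api_key) plus an awaitable grade_document(document, question) returning a dict keyed by criterion name. The patched factory and canned response are illustrative only, mirroring the test fixture rather than defining the library's behavior.

# Hypothetical usage sketch (assumption: the LLMProviderFactory.create_provider
# seam and the result shape match what tests/test_grader.py exercises).
import asyncio
from unittest.mock import AsyncMock, patch

from langrade import document_grader


async def main():
    # Patch the provider factory so no real LLM call is made, as the fixture does.
    with patch("langrade.providers.LLMProviderFactory.create_provider") as factory:
        provider = AsyncMock()
        provider.agenerate.return_value = (
            "Binary score (yes/no): yes\n"
            "Reasoning: The document answers the question."
        )
        factory.return_value = provider

        grader = document_grader("openai", "fake_api_key")
        result = await grader.grade_document(
            "AI is a field of computer science.", "What is AI?"
        )
        # Expected shape per the tests:
        # {"Relevance": {"binary_score": ..., "reasoning": ...}, ...}
        print(result["Relevance"]["binary_score"])


if __name__ == "__main__":
    asyncio.run(main())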