diff --git a/poetry.lock b/poetry.lock
index dc596d0..a093539 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1925,6 +1925,23 @@ tomli = ">=1.0.0"
 [package.extras]
 testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"]
 
+[[package]]
+name = "pytest-mock"
+version = "3.14.0"
+description = "Thin-wrapper around the mock package for easier use with pytest"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"},
+    {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"},
+]
+
+[package.dependencies]
+pytest = ">=6.2.5"
+
+[package.extras]
+dev = ["pre-commit", "pytest-asyncio", "tox"]
+
 [[package]]
 name = "python-dateutil"
 version = "2.8.2"
@@ -3322,4 +3339,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "a6c6e6c7efd4b4c0e39ce9d29327b7fc51fa4d621f7a4154ce2a6dae9f089dcf"
+content-hash = "0a438df085b5f588ae2e896ccf7ae833860690af20528afac6bddc963abf564c"
diff --git a/pyproject.toml b/pyproject.toml
index ba0dfab..49e05b4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,6 +28,7 @@ pydantic = "^2.6.3"
 eval-type-backport = "^0.1.3"
 smart-open = "^7.0.1"
 openai = "^1.16.2"
+pytest-mock = "^3.14.0"
 
 [tool.poetry.group.dev.dependencies]
 black = "^23.11.0"
diff --git a/tests/embedders/test_openai.py b/tests/embedders/test_openai.py
index f33eb4d..70b1d96 100644
--- a/tests/embedders/test_openai.py
+++ b/tests/embedders/test_openai.py
@@ -1,13 +1,45 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
 import numpy as np
+import pytest
+from pytest_mock import MockerFixture
 
-from jmteb.embedders.openai_embedder import OpenAIEmbedder
+from jmteb.embedders import OpenAIEmbedder, TextEmbedder
 
 OUTPUT_DIM = 1536
 
 
-class TestSentenceBertEmbedder:
-    def setup_class(cls):
-        cls.model = OpenAIEmbedder(model="text-embedding-3-small")
+@pytest.fixture(scope="function")
+def mock_openai_embedder(mocker: MockerFixture):
+    mocker.patch("jmteb.embedders.openai_embedder.OpenAI")
+    return OpenAIEmbedder(model="text-embedding-3-small")
+
+
+@dataclass
+class MockData:
+    data: list
+
+
+@dataclass
+class MockEmbedding:
+    embedding: list
+
+
+class MockEmbedder:
+    def create(input: str | list[str], model: str, dimensions: int):
+        if isinstance(input, str):
+            input = [input]
+        return MockData([MockEmbedding(embedding=[0.1] * dimensions)] * len(input))
+
+
+@pytest.mark.usefixtures("mock_openai_embedder")
+class TestOpenAIEmbedder:
+    @pytest.fixture(autouse=True)
+    def setup_class(cls, mocker: MockerFixture, mock_openai_embedder: TextEmbedder):
+        cls.model = mock_openai_embedder
+        cls.mock_create = mocker.patch.object(cls.model.client, "embeddings", new=MockEmbedder)
 
     def test_encode(self):
         embeddings = self.model.encode("任意のテキスト")