diff --git a/poetry.lock b/poetry.lock
index 1631e34..a093539 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
 
 [[package]]
 name = "aiohttp"
@@ -136,6 +136,28 @@ files = [
 [package.dependencies]
 typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""}
 
+[[package]]
+name = "anyio"
+version = "4.3.0"
+description = "High level compatibility layer for multiple asynchronous event loop implementations"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "anyio-4.3.0-py3-none-any.whl", hash = "sha256:048e05d0f6caeed70d731f3db756d35dcc1f35747c8c403364a8332c630441b8"},
+    {file = "anyio-4.3.0.tar.gz", hash = "sha256:f75253795a87df48568485fd18cdd2a3fa5c4f7c5be8e5e36637733fce06fed6"},
+]
+
+[package.dependencies]
+exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""}
+idna = ">=2.8"
+sniffio = ">=1.1"
+typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""}
+
+[package.extras]
+doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"]
+test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
+trio = ["trio (>=0.23)"]
+
 [[package]]
 name = "async-timeout"
 version = "4.0.3"
@@ -429,6 +451,17 @@ files = [
 [package.extras]
 graph = ["objgraph (>=1.7.2)"]
 
+[[package]]
+name = "distro"
+version = "1.9.0"
+description = "Distro - an OS platform information API"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"},
+    {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"},
+]
+
 [[package]]
 name = "eval-type-backport"
 version = "0.1.3"
@@ -443,6 +476,20 @@ files = [
 [package.extras]
 tests = ["pytest"]
 
+[[package]]
+name = "exceptiongroup"
+version = "1.2.0"
+description = "Backport of PEP 654 (exception groups)"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "exceptiongroup-1.2.0-py3-none-any.whl", hash = "sha256:4bfd3996ac73b41e9b9628b04e079f193850720ea5945fc96a08633c66912f14"},
+    {file = "exceptiongroup-1.2.0.tar.gz", hash = "sha256:91f5c769735f051a4290d52edd0858999b57e5876e9f85937691bd4c9fa3ed68"},
+]
+
+[package.extras]
+test = ["pytest (>=6)"]
+
 [[package]]
 name = "filelock"
 version = "3.13.1"
@@ -630,6 +677,62 @@ files = [
 unidic = ["unidic"]
 unidic-lite = ["unidic-lite"]
 
+[[package]]
+name = "h11"
+version = "0.14.0"
+description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"},
+    {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"},
+]
+
+[[package]]
+name = "httpcore"
+version = "1.0.5"
+description = "A minimal low-level HTTP client."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"},
+    {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"},
+]
+
+[package.dependencies]
+certifi = "*"
+h11 = ">=0.13,<0.15"
+
+[package.extras]
+asyncio = ["anyio (>=4.0,<5.0)"]
+http2 = ["h2 (>=3,<5)"]
+socks = ["socksio (==1.*)"]
+trio = ["trio (>=0.22.0,<0.26.0)"]
+
+[[package]]
+name = "httpx"
+version = "0.27.0"
+description = "The next generation HTTP client."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"},
+    {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"},
+]
+
+[package.dependencies]
+anyio = "*"
+certifi = "*"
+httpcore = "==1.*"
+idna = "*"
+sniffio = "*"
+
+[package.extras]
+brotli = ["brotli", "brotlicffi"]
+cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
+http2 = ["h2 (>=3,<5)"]
+socks = ["socksio (==1.*)"]
+
 [[package]]
 name = "huggingface-hub"
 version = "0.20.3"
@@ -1357,6 +1460,29 @@ files = [
 setuptools = "*"
 wheel = "*"
 
+[[package]]
+name = "openai"
+version = "1.16.2"
+description = "The official Python library for the openai API"
+optional = false
+python-versions = ">=3.7.1"
+files = [
+    {file = "openai-1.16.2-py3-none-any.whl", hash = "sha256:46a435380921e42dae218d04d6dd0e89a30d7f3b9d8a778d5887f78003cf9354"},
+    {file = "openai-1.16.2.tar.gz", hash = "sha256:c93d5efe5b73b6cb72c4cd31823852d2e7c84a138c0af3cbe4a8eb32b1164ab2"},
+]
+
+[package.dependencies]
+anyio = ">=3.5.0,<5"
+distro = ">=1.7.0,<2"
+httpx = ">=0.23.0,<1"
+pydantic = ">=1.9.0,<3"
+sniffio = "*"
+tqdm = ">4"
+typing-extensions = ">=4.7,<5"
+
+[package.extras]
+datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
+
 [[package]]
 name = "packaging"
 version = "23.2"
@@ -1799,6 +1925,23 @@ tomli = ">=1.0.0"
 [package.extras]
 testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"]
 
+[[package]]
+name = "pytest-mock"
+version = "3.14.0"
+description = "Thin-wrapper around the mock package for easier use with pytest"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"},
+    {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"},
+]
+
+[package.dependencies]
+pytest = ">=6.2.5"
+
+[package.extras]
+dev = ["pre-commit", "pytest-asyncio", "tox"]
+
 [[package]]
 name = "python-dateutil"
 version = "2.8.2"
@@ -2360,6 +2503,17 @@ test = ["azure-common", "azure-core", "azure-storage-blob", "boto3", "google-clo
 webhdfs = ["requests"]
 zst = ["zstandard"]
 
+[[package]]
+name = "sniffio"
+version = "1.3.1"
+description = "Sniff out which async library your code is running under"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"},
+    {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"},
+]
+
 [[package]]
 name = "sudachidict-core"
 version = "20240109"
@@ -3185,4 +3339,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
python-versions = ">=3.8.1,<4.0" -content-hash = "6e7ddbb7e9d29824bd543489a80f8378e1be9b83131b5e6d0b0bf699a6a941f3" +content-hash = "0a438df085b5f588ae2e896ccf7ae833860690af20528afac6bddc963abf564c" diff --git a/pyproject.toml b/pyproject.toml index 0f98250..49e05b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,8 @@ torch = "2.0.0" # this version is needed to avoid "libcu*" related errors pydantic = "^2.6.3" eval-type-backport = "^0.1.3" smart-open = "^7.0.1" +openai = "^1.16.2" +pytest-mock = "^3.14.0" [tool.poetry.group.dev.dependencies] black = "^23.11.0" diff --git a/src/jmteb/embedders/__init__.py b/src/jmteb/embedders/__init__.py index a1149d7..07399c5 100644 --- a/src/jmteb/embedders/__init__.py +++ b/src/jmteb/embedders/__init__.py @@ -1,2 +1,3 @@ from jmteb.embedders.base import TextEmbedder +from jmteb.embedders.openai_embedder import OpenAIEmbedder from jmteb.embedders.sbert_embedder import SentenceBertEmbedder diff --git a/src/jmteb/embedders/openai_embedder.py b/src/jmteb/embedders/openai_embedder.py new file mode 100644 index 0000000..63af35e --- /dev/null +++ b/src/jmteb/embedders/openai_embedder.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import numpy as np +from loguru import logger +from openai import OpenAI + +from jmteb.embedders.base import TextEmbedder + +MODEL_DIM = { + "text-embedding-3-large": 3072, + "text-embedding-3-small": 1536, + "text-embedding-ada-002": 1536, +} + + +class OpenAIEmbedder(TextEmbedder): + """Embedder via OpenAI API.""" + + def __init__(self, model: str = "text-embedding-3-small", dim: int | None = None) -> None: + """Setup. + model and dim: see https://platform.openai.com/docs/models/embeddings + `text-embedding-3-large` model: max 3072 dim + `text-embedding-3-small` model: max 1536 dim + `text-embedding-ada-002` model: max 1536 dim + + OpenAI embeddings have been normalized to length 1. See + https://platform.openai.com/docs/guides/embeddings/which-distance-function-should-i-use + + Args: + model (str, optional): Name of an OpenAI embedding model. Defaults to "text-embedding-3-small". + dim (int, optional): Output dimension. Defaults to 1536. + """ + self.client = OpenAI() # API key written in .env + assert model in MODEL_DIM.keys(), f"`model` must be one of {list(MODEL_DIM.keys())}!" 
+        self.model = model
+        if not dim:
+            self.dim = MODEL_DIM[self.model]
+        else:
+            if dim > MODEL_DIM[self.model]:
+                self.dim = MODEL_DIM[self.model]
+                logger.warning(f"The maximum dimension of model {self.model} is {self.dim}; falling back to dim={self.dim}.")
+            else:
+                self.dim = dim
+
+    def encode(self, text: str | list[str]) -> np.ndarray:
+        result = np.asarray(
+            [
+                data.embedding
+                for data in self.client.embeddings.create(
+                    input=text,
+                    model=self.model,
+                    dimensions=self.dim,
+                ).data
+            ]
+        )
+        if result.shape[0] == 1:
+            return result.reshape(-1)
+        return result
+
+    def get_output_dim(self) -> int:
+        return self.dim
diff --git a/tests/embedders/test_openai.py b/tests/embedders/test_openai.py
new file mode 100644
index 0000000..e1c1302
--- /dev/null
+++ b/tests/embedders/test_openai.py
@@ -0,0 +1,69 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import numpy as np
+import pytest
+from pytest_mock import MockerFixture
+
+from jmteb.embedders import OpenAIEmbedder, TextEmbedder
+
+OUTPUT_DIM = 1536  # the maximum dim of the default model `text-embedding-3-small`
+
+
+@pytest.fixture(scope="function")
+def mock_openai_embedder(mocker: MockerFixture):
+    mocker.patch("jmteb.embedders.openai_embedder.OpenAI")
+    return OpenAIEmbedder(model="text-embedding-3-small")
+
+
+@dataclass
+class MockData:
+    data: list
+
+
+@dataclass
+class MockEmbedding:
+    embedding: list
+
+
+class MockOpenAIClientEmbedding:
+    def create(input: str | list[str], model: str, dimensions: int):
+        if isinstance(input, str):
+            input = [input]
+        return MockData(data=[MockEmbedding(embedding=[0.1] * dimensions)] * len(input))
+
+
+@pytest.mark.usefixtures("mock_openai_embedder")
+class TestOpenAIEmbedder:
+    @pytest.fixture(autouse=True)
+    def setup_class(cls, mocker: MockerFixture, mock_openai_embedder: TextEmbedder):
+        cls.model = mock_openai_embedder
+        cls.mock_create = mocker.patch.object(cls.model.client, "embeddings", new=MockOpenAIClientEmbedding)
+
+    def test_encode(self):
+        embeddings = self.model.encode("任意のテキスト")
+        assert isinstance(embeddings, np.ndarray)
+        assert embeddings.shape == (OUTPUT_DIM,)
+
+    def test_encode_multiple(self):
+        embeddings = self.model.encode(["任意のテキスト"] * 3)
+        assert isinstance(embeddings, np.ndarray)
+        assert embeddings.shape == (3, OUTPUT_DIM)
+
+    def test_get_output_dim(self):
+        assert self.model.get_output_dim() == OUTPUT_DIM
+
+    def test_nonexistent_model(self):
+        with pytest.raises(AssertionError):
+            _ = OpenAIEmbedder(model="model")
+
+    def test_model_dim(self):
+        assert OpenAIEmbedder(model="text-embedding-3-large").dim == 3072
+        assert OpenAIEmbedder(model="text-embedding-ada-002").dim == 1536
+
+    def test_dim_over_max(self):
+        assert OpenAIEmbedder(dim=2 * OUTPUT_DIM).dim == OUTPUT_DIM
+
+    def test_dim_smaller(self):
+        assert OpenAIEmbedder(dim=OUTPUT_DIM // 2).dim == OUTPUT_DIM // 2
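
For reference, a minimal usage sketch of the OpenAIEmbedder introduced in this diff, assuming OPENAI_API_KEY is set in the environment (e.g. via .env); the dim=256 value and the sample texts are purely illustrative and not part of the change itself:

    from jmteb.embedders import OpenAIEmbedder

    # dim is optional; omit it to use the model's maximum dimension (1536 for text-embedding-3-small).
    embedder = OpenAIEmbedder(model="text-embedding-3-small", dim=256)

    single = embedder.encode("任意のテキスト")          # single string -> ndarray of shape (256,)
    batch = embedder.encode(["文1", "文2", "文3"])      # list of strings -> ndarray of shape (3, 256)
    print(embedder.get_output_dim(), single.shape, batch.shape)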