[Feature] Add OpenAIEmbedder #12

Merged: 9 commits, Apr 9, 2024
158 changes: 156 additions & 2 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -27,6 +27,8 @@
torch = "2.0.0" # this version is needed to avoid "libcu*" related errors
pydantic = "^2.6.3"
eval-type-backport = "^0.1.3"
smart-open = "^7.0.1"
openai = "^1.16.2"
pytest-mock = "^3.14.0"

[tool.poetry.group.dev.dependencies]
black = "^23.11.0"
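
For reference, a dependency change like the one above (together with the regenerated poetry.lock shown earlier) can be reproduced with Poetry's CLI; a sketch, assuming a Poetry-managed checkout:

poetry add "openai@^1.16.2" "pytest-mock@^3.14.0"  # updates pyproject.toml and regenerates poetry.lock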
1 change: 1 addition & 0 deletions src/jmteb/embedders/__init__.py
@@ -1,2 +1,3 @@
from jmteb.embedders.base import TextEmbedder
from jmteb.embedders.openai_embedder import OpenAIEmbedder
from jmteb.embedders.sbert_embedder import SentenceBertEmbedder
61 changes: 61 additions & 0 deletions src/jmteb/embedders/openai_embedder.py
@@ -0,0 +1,61 @@
from __future__ import annotations

import numpy as np
from loguru import logger
from openai import OpenAI

from jmteb.embedders.base import TextEmbedder

MODEL_DIM = {
    "text-embedding-3-large": 3072,
    "text-embedding-3-small": 1536,
    "text-embedding-ada-002": 1536,
}


class OpenAIEmbedder(TextEmbedder):
    """Embedder via OpenAI API."""

    def __init__(self, model: str = "text-embedding-3-small", dim: int | None = None) -> None:
        """Setup.
        model and dim: see https://platform.openai.com/docs/models/embeddings
            `text-embedding-3-large` model: max 3072 dim
            `text-embedding-3-small` model: max 1536 dim
            `text-embedding-ada-002` model: max 1536 dim

        OpenAI embeddings are normalized to length 1. See
        https://platform.openai.com/docs/guides/embeddings/which-distance-function-should-i-use

        Args:
            model (str, optional): Name of an OpenAI embedding model. Defaults to "text-embedding-3-small".
            dim (int | None, optional): Output dimension. Defaults to None, which uses the model's maximum dimension.
        """
        self.client = OpenAI()  # reads OPENAI_API_KEY from the environment (e.g., via .env)
        assert model in MODEL_DIM, f"`model` must be one of {list(MODEL_DIM.keys())}!"
        self.model = model
        if not dim:
            self.dim = MODEL_DIM[self.model]
        else:
            if dim > MODEL_DIM[self.model]:
                self.dim = MODEL_DIM[self.model]
                logger.warning(f"The maximum dimension of model {self.model} is {self.dim}; using dim={self.dim}.")
            else:
                self.dim = dim

    def encode(self, text: str | list[str]) -> np.ndarray:
        result = np.asarray(
            [
                data.embedding
                for data in self.client.embeddings.create(
                    input=text,
                    model=self.model,
                    dimensions=self.dim,
                ).data
            ]
        )
        # A single input text yields a single embedding; flatten it to a 1-D vector.
        if result.shape[0] == 1:
            return result.reshape(-1)
        return result

    def get_output_dim(self) -> int:
        return self.dim
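
For context, a minimal usage sketch of the new embedder (not part of the diff; assumes a valid OPENAI_API_KEY in the environment, and dim=256 is an arbitrary value chosen to illustrate the dimension-truncation path):

from jmteb.embedders import OpenAIEmbedder

embedder = OpenAIEmbedder(model="text-embedding-3-small", dim=256)

vec = embedder.encode("a single text")          # np.ndarray with shape (256,)
mat = embedder.encode(["one", "two", "three"])  # np.ndarray with shape (3, 256)
assert embedder.get_output_dim() == 256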
69 changes: 69 additions & 0 deletions tests/embedders/test_openai.py
@@ -0,0 +1,69 @@
from __future__ import annotations

from dataclasses import dataclass

import numpy as np
import pytest
from pytest_mock import MockerFixture

from jmteb.embedders import OpenAIEmbedder, TextEmbedder

OUTPUT_DIM = 1536  # the maximum dim of the default model `text-embedding-3-small`


@pytest.fixture(scope="function")
def mock_openai_embedder(mocker: MockerFixture):
    mocker.patch("jmteb.embedders.openai_embedder.OpenAI")
    return OpenAIEmbedder(model="text-embedding-3-small")


@dataclass
class MockData:
    data: list


@dataclass
class MockEmbedding:
    embedding: list


class MockOpenAIClientEmbedding:
    # Patched in as `client.embeddings` itself (a class, not an instance),
    # so `create` is looked up as a plain function and needs no `self`.
    def create(input: str | list[str], model: str, dimensions: int):
        if isinstance(input, str):
            input = [input]
        return MockData(data=[MockEmbedding(embedding=[0.1] * dimensions)] * len(input))


@pytest.mark.usefixtures("mock_openai_embedder")
class TestOpenAIEmbedder:
    @pytest.fixture(autouse=True)
    def setup_class(cls, mocker: MockerFixture, mock_openai_embedder: TextEmbedder):
        cls.model = mock_openai_embedder
        cls.mock_create = mocker.patch.object(cls.model.client, "embeddings", new=MockOpenAIClientEmbedding)

    def test_encode(self):
        embeddings = self.model.encode("任意のテキスト")
        assert isinstance(embeddings, np.ndarray)
        assert embeddings.shape == (OUTPUT_DIM,)

    def test_encode_multiple(self):
        embeddings = self.model.encode(["任意のテキスト"] * 3)
        assert isinstance(embeddings, np.ndarray)
        assert embeddings.shape == (3, OUTPUT_DIM)

    def test_get_output_dim(self):
        assert self.model.get_output_dim() == OUTPUT_DIM

    def test_nonexistent_model(self):
        with pytest.raises(AssertionError):
            _ = OpenAIEmbedder(model="model")

    def test_model_dim(self):
        assert OpenAIEmbedder(model="text-embedding-3-large").dim == 3072
        assert OpenAIEmbedder(model="text-embedding-ada-002").dim == 1536

    def test_dim_over_max(self):
        assert OpenAIEmbedder(dim=2 * OUTPUT_DIM).dim == OUTPUT_DIM

    def test_dim_smaller(self):
        assert OpenAIEmbedder(dim=OUTPUT_DIM // 2).dim == OUTPUT_DIM // 2
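
To run just this test module locally (a sketch; assumes dependencies were installed via Poetry as above):

poetry run pytest tests/embedders/test_openai.py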