From ec640deb2e4cecd6559b8a387abab4e80dd7f936 Mon Sep 17 00:00:00 2001 From: "shengzhe.li" Date: Tue, 9 Apr 2024 01:25:02 +0900 Subject: [PATCH] Add OpenAI embedder --- src/jmteb/embedders/__init__.py | 1 + src/jmteb/embedders/openai_embedder.py | 49 ++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 src/jmteb/embedders/openai_embedder.py diff --git a/src/jmteb/embedders/__init__.py b/src/jmteb/embedders/__init__.py index a1149d7..07399c5 100644 --- a/src/jmteb/embedders/__init__.py +++ b/src/jmteb/embedders/__init__.py @@ -1,2 +1,3 @@ from jmteb.embedders.base import TextEmbedder +from jmteb.embedders.openai_embedder import OpenAIEmbedder from jmteb.embedders.sbert_embedder import SentenceBertEmbedder diff --git a/src/jmteb/embedders/openai_embedder.py b/src/jmteb/embedders/openai_embedder.py new file mode 100644 index 0000000..b6eb184 --- /dev/null +++ b/src/jmteb/embedders/openai_embedder.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import numpy as np +from openai import OpenAI + +from jmteb.embedders.base import TextEmbedder + + +class OpenAIEmbedder(TextEmbedder): + """Embedder via OpenAI API.""" + + def __init__(self, model: str = "text-embedding-3-small", dim: int | None = None) -> None: + """Setup. + model and dim: see https://platform.openai.com/docs/models/embeddings + `text-embedding-3-large` model: max 3072 dim + `text-embedding-3-small` model: max 1536 dim + `text-embedding-ada-002` model: max 1536 dim + + Args: + model (str, optional): Name of an OpenAI embedding model. Defaults to "text-embedding-3-small". + dim (int, optional): Output dimension. Defaults to 1536. + """ + self.client = OpenAI() # API key written in .env + self.model = model + if not dim: + if model == "text-embedding-3-large": + self.dim = 3072 + else: + self.dim = 1536 + else: + self.dim = dim + + def encode(self, text: str | list[str]) -> np.ndarray: + result = np.asarray( + [ + data.embedding + for data in self.client.embeddings.create( + input=text, + model=self.model, + dimensions=self.dim, + ).data + ] + ) + if result.shape[0] == 1: + return result.reshape(-1) + return result + + def get_output_dim(self) -> int: + return self.dim