-
Notifications
You must be signed in to change notification settings - Fork 92
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #91 from qdrant/feat/pgvector-client
feat: Implement pgvector client based on ann-benchmarks implementation
- Loading branch information
Showing
14 changed files
with
516 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from engine.clients.pgvector.configure import PgVectorConfigurator | ||
from engine.clients.pgvector.search import PgVectorSearcher | ||
from engine.clients.pgvector.upload import PgVectorUploader | ||
|
||
__all__ = [ | ||
"PgVectorConfigurator", | ||
"PgVectorSearcher", | ||
"PgVectorUploader", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Connection defaults for the pgvector (PostgreSQL) benchmark target.
# NOTE(review): PGVECTOR_PORT is not referenced by get_db_config below, and
# 9200 does not match the 5432 port published by the pgvector compose service
# in this change -- confirm whether anything else reads it.
PGVECTOR_PORT = 9200
PGVECTOR_DB = "postgres"
PGVECTOR_USER = "postgres"
PGVECTOR_PASSWORD = "passwd"


def get_db_config(host, connection_params):
    """Build the keyword arguments used to open a database connection.

    Args:
        host: Server host name; any falsy value falls back to ``"localhost"``.
        connection_params: Optional mapping of extra connection keywords.
            Its entries are applied last, so they override the defaults
            (including ``autocommit``). ``None`` is treated as "no overrides".

    Returns:
        A new dict with ``host``, ``dbname``, ``user``, ``password`` and
        ``autocommit`` keys, plus everything from ``connection_params``.
    """
    return {
        "host": host or "localhost",
        "dbname": PGVECTOR_DB,
        "user": PGVECTOR_USER,
        "password": PGVECTOR_PASSWORD,
        "autocommit": True,
        # Unpacking None would raise TypeError; fall back to an empty mapping.
        **(connection_params or {}),
    }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import pgvector.psycopg | ||
import psycopg | ||
|
||
from benchmark.dataset import Dataset | ||
from engine.base_client import IncompatibilityError | ||
from engine.base_client.configure import BaseConfigurator | ||
from engine.base_client.distances import Distance | ||
from engine.clients.pgvector.config import get_db_config | ||
|
||
|
||
class PgVectorConfigurator(BaseConfigurator):
    """Create and drop the ``items`` table and its HNSW index for a run."""

    # Operator class passed to ``USING hnsw(...)`` for each supported distance.
    DISTANCE_MAPPING = {
        Distance.L2: "vector_l2_ops",
        Distance.COSINE: "vector_cosine_ops",
    }

    def __init__(self, host, collection_params: dict, connection_params: dict):
        """Open a connection and make sure the ``vector`` extension exists.

        Args:
            host: Database host, forwarded to ``get_db_config``.
            collection_params: Forwarded to the base configurator.
            connection_params: Extra connection keywords, forwarded to both
                the base configurator and ``get_db_config``.
        """
        super().__init__(host, collection_params, connection_params)
        self.conn = psycopg.connect(**get_db_config(host, connection_params))
        print("configure connection created")
        self.conn.execute("CREATE EXTENSION IF NOT EXISTS vector;")
        pgvector.psycopg.register_vector(self.conn)

    def clean(self):
        """Drop the ``items`` table (and dependent objects) if it exists."""
        self.conn.execute("DROP TABLE IF EXISTS items CASCADE;")

    def recreate(self, dataset: Dataset, collection_params):
        """Create the ``items`` table and build an HNSW index on it.

        Args:
            dataset: Supplies ``config.distance`` and ``config.vector_size``.
            collection_params: Must contain ``hnsw_config`` with the keys
                ``m`` and ``ef_construct``.

        Raises:
            IncompatibilityError: If the dataset distance has no entry in
                ``DISTANCE_MAPPING`` (this includes ``Distance.DOT``).
            KeyError: If the HNSW parameters are missing.
        """
        # Validate every input before running DDL so that an unsupported
        # dataset or incomplete config does not leave a table behind.
        distance = dataset.config.distance
        hnsw_distance_type = self.DISTANCE_MAPPING.get(distance)
        if hnsw_distance_type is None:
            raise IncompatibilityError(f"Unsupported distance metric: {distance}")

        hnsw_config = collection_params["hnsw_config"]
        m = hnsw_config["m"]
        ef_construction = hnsw_config["ef_construct"]

        self.conn.execute(
            f"""CREATE TABLE items (
                id SERIAL PRIMARY KEY,
                embedding vector({dataset.config.vector_size}) NOT NULL
            );"""
        )
        self.conn.execute("ALTER TABLE items ALTER COLUMN embedding SET STORAGE PLAIN")

        self.conn.execute(
            f"CREATE INDEX on items USING hnsw(embedding {hnsw_distance_type}) "
            f"WITH (m = {m}, ef_construction = {ef_construction})"
        )

        # The connection is closed here, so no further statements can be run
        # through this instance after recreate() returns.
        self.conn.close()

    def delete_client(self):
        """Close the connection.

        NOTE(review): recreate() already closes it, so this may be a second
        close on the same connection -- confirm the driver tolerates that.
        """
        self.conn.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import json | ||
from typing import Any, List, Optional | ||
|
||
from engine.base_client import IncompatibilityError | ||
from engine.base_client.parser import BaseConditionParser, FieldValue | ||
|
||
|
||
class PgVectorConditionParser(BaseConditionParser):
    """Render filter conditions as boolean expression strings."""

    def build_condition(
        self, and_subfilters: Optional[List[Any]], or_subfilters: Optional[List[Any]]
    ) -> Optional[Any]:
        """Combine sub-filter strings into a single expression.

        Args:
            and_subfilters: Expressions that must all hold, or ``None``.
            or_subfilters: Expressions of which at least one must hold,
                or ``None``.

        Returns:
            The OR group and the AND group (each wrapped in parentheses,
            OR group first) joined with ``AND``; an empty string when both
            inputs are empty or ``None``.
        """
        clauses = []
        if or_subfilters:
            clauses.append(f"( {' OR '.join(or_subfilters)} )")
        if and_subfilters:
            # Join the AND list itself; joining or_subfilters here produced
            # the wrong expression and crashed when or_subfilters was None.
            clauses.append(f"( {' AND '.join(and_subfilters)} )")

        return " AND ".join(clauses)

    def build_exact_match_filter(self, field_name: str, value: FieldValue) -> Any:
        """Return an equality expression for ``field_name`` and ``value``.

        The value is rendered with ``json.dumps``.
        NOTE(review): the ``==`` operator and JSON-style double quoting are
        kept as written -- confirm they suit the SQL dialect consuming this.
        """
        # Must be returned: raising a plain str is itself a TypeError.
        return f"{field_name} == {json.dumps(value)}"

    def build_range_filter(
        self,
        field_name: str,
        lt: Optional[FieldValue],
        gt: Optional[FieldValue],
        lte: Optional[FieldValue],
        gte: Optional[FieldValue],
    ) -> Any:
        """Return a parenthesised conjunction of the bounds that are set.

        Bounds equal to ``None`` are skipped; the remaining comparisons are
        emitted in the order ``<``, ``>``, ``<=``, ``>=``.
        """
        bounds = (("<", lt), (">", gt), ("<=", lte), (">=", gte))
        clauses = [
            f"{field_name} {operator} {bound}"
            for operator, bound in bounds
            if bound is not None
        ]
        return f"( {' AND '.join(clauses)} )"

    def build_geo_filter(
        self, field_name: str, lat: float, lon: float, radius: float
    ) -> Any:
        """Geo filters are not supported; always raises IncompatibilityError."""
        # TODO: Implement this
        raise IncompatibilityError
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import multiprocessing as mp | ||
from typing import List, Tuple | ||
|
||
import numpy as np | ||
import psycopg | ||
from pgvector.psycopg import register_vector | ||
|
||
from engine.base_client.distances import Distance | ||
from engine.base_client.search import BaseSearcher | ||
from engine.clients.pgvector.config import get_db_config | ||
from engine.clients.pgvector.parser import PgVectorConditionParser | ||
|
||
|
||
class PgVectorSearcher(BaseSearcher):
    """Run nearest-neighbour queries against the ``items`` table."""

    # Class-level state: all methods are classmethods and share one
    # connection/cursor per process.
    conn = None
    cur = None
    distance = None
    search_params = {}
    parser = PgVectorConditionParser()

    # Distance operator used in the ORDER BY expression for each metric.
    _DISTANCE_OPERATORS = {
        Distance.COSINE: "<=>",
        Distance.L2: "<->",
    }

    @classmethod
    def init_client(cls, host, distance, connection_params: dict, search_params: dict):
        """Open the connection and apply the session-level search settings.

        Args:
            host: Database host, forwarded to ``get_db_config``.
            distance: Metric that ``search_one`` will query with.
            connection_params: Extra connection keywords.
            search_params: Must contain ``search_params`` with an
                ``hnsw_ef`` entry.
        """
        cls.conn = psycopg.connect(**get_db_config(host, connection_params))
        register_vector(cls.conn)
        cls.cur = cls.conn.cursor()
        cls.distance = distance
        cls.search_params = search_params["search_params"]
        # Set once per connection instead of once per query, so the setting
        # does not add an extra statement to every search_one() call.
        cls.cur.execute(f"SET hnsw.ef_search = {cls.search_params['hnsw_ef']}")

    @classmethod
    def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
        """Return the ``top`` rows closest to ``vector`` as ``(id, score)``.

        ``meta_conditions`` is accepted but not applied to the query.

        Raises:
            NotImplementedError: If the configured distance has no operator.
        """
        operator = cls._DISTANCE_OPERATORS.get(cls.distance)
        if operator is None:
            raise NotImplementedError(f"Unsupported distance metric {cls.distance}")

        query = (
            f"SELECT id, embedding {operator} %s AS _score "
            f"FROM items ORDER BY _score LIMIT {top};"
        )
        cls.cur.execute(query, (np.array(vector),))
        return cls.cur.fetchall()

    @classmethod
    def delete_client(cls):
        """Close the cursor and the connection if they were opened."""
        if cls.cur is not None:
            cls.cur.close()
            cls.cur = None
        # Guarded: conn is still None when init_client() never ran.
        if cls.conn is not None:
            cls.conn.close()
            cls.conn = None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from typing import List, Optional | ||
|
||
import numpy as np | ||
import psycopg | ||
from pgvector.psycopg import register_vector | ||
|
||
from engine.base_client.upload import BaseUploader | ||
from engine.clients.pgvector.config import get_db_config | ||
|
||
|
||
class PgVectorUploader(BaseUploader):
    """Bulk-load vectors into the ``items`` table."""

    # Class-level state: all methods are classmethods and share one
    # connection/cursor per process.
    conn = None
    cur = None
    upload_params = {}

    @classmethod
    def init_client(cls, host, distance, connection_params, upload_params):
        """Open the connection and cursor used by ``upload_batch``.

        ``distance`` is accepted but not used here.
        """
        cls.conn = psycopg.connect(**get_db_config(host, connection_params))
        register_vector(cls.conn)
        cls.cur = cls.conn.cursor()
        cls.upload_params = upload_params

    @classmethod
    def upload_batch(
        cls, ids: List[int], vectors: List[list], metadata: Optional[List[dict]]
    ):
        """Write one batch of ``(id, embedding)`` rows.

        ``metadata`` is accepted but not stored.
        """
        embeddings = np.array(vectors)

        # Copy is faster than insert
        with cls.cur.copy("COPY items (id, embedding) FROM STDIN") as copy:
            for row_id, embedding in zip(ids, embeddings):
                copy.write_row((row_id, embedding))

    @classmethod
    def delete_client(cls):
        """Close the cursor and the connection if they were opened."""
        if cls.cur is not None:
            cls.cur.close()
            cls.cur = None
        # Guarded: conn is still None when init_client() never ran.
        if cls.conn is not None:
            cls.conn.close()
            cls.conn = None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
version: '3.7'

# Single PostgreSQL + pgvector service used as the benchmark target.
services:
  pgvector:
    container_name: pgvector
    image: ankane/pgvector:v0.5.1
    environment:
      - POSTGRES_DB=postgres
      - POSTGRES_USER=postgres
      - POSTGRES_PASSWORD=passwd
      - POSTGRES_HOST_AUTH_METHOD=trust
      # NOTE(review): confirm the image actually reads this variable;
      # otherwise the connection limit must be set via server configuration.
      - POSTGRES_MAX_CONNECTIONS=200
    ports:
      # Quoted so YAML always treats the HOST:CONTAINER mapping as a string.
      - "5432:5432"
    logging:
      driver: "json-file"
      options:
        # Logging driver options are quoted so they are passed as strings.
        max-file: "1"
        max-size: "10m"
    deploy:
      resources:
        limits:
          memory: 25Gb
Oops, something went wrong.