Commit 4224a27

Merge pull request #91 from qdrant/feat/pgvector-client

feat: Implement pgvector client based on ann-benchmarks implementation

KShivendu authored Jan 15, 2024
2 parents 405d87c + 369f114 commit 4224a27

Showing 14 changed files with 516 additions and 3 deletions.
7 changes: 7 additions & 0 deletions engine/base_client/client.py
@@ -118,3 +118,10 @@ def run_experiment(
        )
        print("Experiment stage: Done")
        print("Results saved to: ", RESULTS_DIR)

    def delete_client(self):
        self.uploader.delete_client()
        self.configurator.delete_client()

        for s in self.searchers:
            s.delete_client()
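The new delete_client hook gives the driver a single place to release every engine connection once a run finishes. A minimal sketch of the intended call pattern (the client instance and the run_experiment arguments are assumptions, not shown in this diff):

# Hypothetical driver-side teardown (sketch):
client.run_experiment(dataset)  # entry point named in the hunk header above; signature assumed
client.delete_client()          # closes uploader, configurator, and all searcher clients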
3 changes: 3 additions & 0 deletions engine/base_client/configure.py
@@ -23,3 +23,6 @@ def configure(self, dataset: Dataset) -> Optional[dict]:

    def execution_params(self, distance, vector_size) -> dict:
        return {}

    def delete_client(self):
        pass
8 changes: 8 additions & 0 deletions engine/clients/client_factory.py
@@ -18,6 +18,11 @@
    OpenSearchSearcher,
    OpenSearchUploader,
)
from engine.clients.pgvector import (
    PgVectorConfigurator,
    PgVectorSearcher,
    PgVectorUploader,
)
from engine.clients.qdrant import QdrantConfigurator, QdrantSearcher, QdrantUploader
from engine.clients.redis import RedisConfigurator, RedisSearcher, RedisUploader
from engine.clients.weaviate import (
@@ -33,6 +38,7 @@
    "elasticsearch": ElasticConfigurator,
    "opensearch": OpenSearchConfigurator,
    "redis": RedisConfigurator,
    "pgvector": PgVectorConfigurator,
}

ENGINE_UPLOADERS = {
@@ -42,6 +48,7 @@
    "elasticsearch": ElasticUploader,
    "opensearch": OpenSearchUploader,
    "redis": RedisUploader,
    "pgvector": PgVectorUploader,
}

ENGINE_SEARCHERS = {
@@ -51,6 +58,7 @@
    "elasticsearch": ElasticSearcher,
    "opensearch": OpenSearchSearcher,
    "redis": RedisSearcher,
    "pgvector": PgVectorSearcher,
}
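These registries map the engine name from the experiment config to its client classes. A sketch of how the pgvector triple is presumably resolved (the name ENGINE_CONFIGURATORS and the build_clients helper are assumptions; only ENGINE_UPLOADERS and ENGINE_SEARCHERS are visible in this hunk):

# Hypothetical lookup helper; the repo's real factory function may differ.
def build_clients(engine: str):
    return (
        ENGINE_CONFIGURATORS[engine],  # assumed name of the first registry above
        ENGINE_UPLOADERS[engine],
        ENGINE_SEARCHERS[engine],
    )

configurator_cls, uploader_cls, searcher_cls = build_clients("pgvector")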
9 changes: 9 additions & 0 deletions engine/clients/pgvector/__init__.py
@@ -0,0 +1,9 @@
from engine.clients.pgvector.configure import PgVectorConfigurator
from engine.clients.pgvector.search import PgVectorSearcher
from engine.clients.pgvector.upload import PgVectorUploader

__all__ = [
    "PgVectorConfigurator",
    "PgVectorSearcher",
    "PgVectorUploader",
]
15 changes: 15 additions & 0 deletions engine/clients/pgvector/config.py
@@ -0,0 +1,15 @@
PGVECTOR_PORT = 5432  # Postgres default, matching the docker-compose port mapping
PGVECTOR_DB = "postgres"
PGVECTOR_USER = "postgres"
PGVECTOR_PASSWORD = "passwd"


def get_db_config(host, connection_params):
    return {
        "host": host or "localhost",
        "dbname": PGVECTOR_DB,
        "user": PGVECTOR_USER,
        "password": PGVECTOR_PASSWORD,
        "autocommit": True,
        **connection_params,
    }
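get_db_config feeds psycopg.connect directly; anything in connection_params overrides the defaults because it is spread last. A minimal sketch (host and port values are illustrative):

import psycopg

from engine.clients.pgvector.config import get_db_config

# Defaults: localhost / postgres / postgres / passwd, autocommit enabled.
conn = psycopg.connect(**get_db_config("localhost", {}))
print(conn.execute("SELECT 1").fetchone())
conn.close()

# Extra connection_params pass straight through, e.g. a non-default port:
cfg = get_db_config("db.internal", {"port": 6432})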
55 changes: 55 additions & 0 deletions engine/clients/pgvector/configure.py
@@ -0,0 +1,55 @@
import pgvector.psycopg
import psycopg

from benchmark.dataset import Dataset
from engine.base_client import IncompatibilityError
from engine.base_client.configure import BaseConfigurator
from engine.base_client.distances import Distance
from engine.clients.pgvector.config import get_db_config


class PgVectorConfigurator(BaseConfigurator):
    DISTANCE_MAPPING = {
        Distance.L2: "vector_l2_ops",
        Distance.COSINE: "vector_cosine_ops",
    }

    def __init__(self, host, collection_params: dict, connection_params: dict):
        super().__init__(host, collection_params, connection_params)
        self.conn = psycopg.connect(**get_db_config(host, connection_params))
        print("configure connection created")
        self.conn.execute("CREATE EXTENSION IF NOT EXISTS vector;")
        pgvector.psycopg.register_vector(self.conn)

    def clean(self):
        self.conn.execute(
            "DROP TABLE IF EXISTS items CASCADE;",
        )

    def recreate(self, dataset: Dataset, collection_params):
        if dataset.config.distance == Distance.DOT:
            raise IncompatibilityError

        self.conn.execute(
            f"""CREATE TABLE items (
                id SERIAL PRIMARY KEY,
                embedding vector({dataset.config.vector_size}) NOT NULL
            );"""
        )
        self.conn.execute("ALTER TABLE items ALTER COLUMN embedding SET STORAGE PLAIN")

        try:
            hnsw_distance_type = self.DISTANCE_MAPPING[dataset.config.distance]
        except KeyError:
            raise IncompatibilityError(
                f"Unsupported distance metric: {dataset.config.distance}"
            )

        self.conn.execute(
            f"CREATE INDEX ON items USING hnsw(embedding {hnsw_distance_type}) WITH (m = {collection_params['hnsw_config']['m']}, ef_construction = {collection_params['hnsw_config']['ef_construct']})"
        )

        self.conn.close()

    def delete_client(self):
        self.conn.close()
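recreate builds the table, forces PLAIN storage so vectors stay inline and uncompressed (no TOAST), and creates the HNSW index. With an illustrative 768-dimensional cosine dataset and an hnsw_config of m=16, ef_construct=128 (values hypothetical), it emits roughly:

CREATE TABLE items (
    id SERIAL PRIMARY KEY,
    embedding vector(768) NOT NULL
);
ALTER TABLE items ALTER COLUMN embedding SET STORAGE PLAIN;
CREATE INDEX ON items USING hnsw(embedding vector_cosine_ops) WITH (m = 16, ef_construction = 128);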
46 changes: 46 additions & 0 deletions engine/clients/pgvector/parser.py
@@ -0,0 +1,46 @@
import json
from typing import Any, List, Optional

from engine.base_client import IncompatibilityError
from engine.base_client.parser import BaseConditionParser, FieldValue


class PgVectorConditionParser(BaseConditionParser):
    def build_condition(
        self, and_subfilters: Optional[List[Any]], or_subfilters: Optional[List[Any]]
    ) -> Optional[Any]:
        clauses = []
        if or_subfilters is not None and len(or_subfilters) > 0:
            clauses.append(f"( {' OR '.join(or_subfilters)} )")
        if and_subfilters is not None and len(and_subfilters) > 0:
            clauses.append(f"( {' AND '.join(and_subfilters)} )")

        return " AND ".join(clauses)

    def build_exact_match_filter(self, field_name: str, value: FieldValue) -> Any:
        # NOTE: json.dumps double-quotes strings; SQL string literals need single quotes.
        return f"{field_name} = {json.dumps(value)}"

    def build_range_filter(
        self,
        field_name: str,
        lt: Optional[FieldValue],
        gt: Optional[FieldValue],
        lte: Optional[FieldValue],
        gte: Optional[FieldValue],
    ) -> Any:
        clauses = []
        if lt is not None:
            clauses.append(f"{field_name} < {lt}")
        if gt is not None:
            clauses.append(f"{field_name} > {gt}")
        if lte is not None:
            clauses.append(f"{field_name} <= {lte}")
        if gte is not None:
            clauses.append(f"{field_name} >= {gte}")
        return f"( {' AND '.join(clauses)} )"

    def build_geo_filter(
        self, field_name: str, lat: float, lon: float, radius: float
    ) -> Any:
        # TODO: Implement this
        raise IncompatibilityError
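A quick illustration of the strings this parser composes (field names and values are hypothetical):

parser = PgVectorConditionParser()

age = parser.build_range_filter("age", lt=None, gt=18, lte=65, gte=None)
city = parser.build_exact_match_filter("city", "Berlin")
where = parser.build_condition(and_subfilters=[age, city], or_subfilters=None)
print(where)  # ( ( age > 18 AND age <= 65 ) AND city = "Berlin" )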
50 changes: 50 additions & 0 deletions engine/clients/pgvector/search.py
@@ -0,0 +1,50 @@
import multiprocessing as mp
from typing import List, Tuple

import numpy as np
import psycopg
from pgvector.psycopg import register_vector

from engine.base_client.distances import Distance
from engine.base_client.search import BaseSearcher
from engine.clients.pgvector.config import get_db_config
from engine.clients.pgvector.parser import PgVectorConditionParser


class PgVectorSearcher(BaseSearcher):
    conn = None
    cur = None
    distance = None
    search_params = {}
    parser = PgVectorConditionParser()

    @classmethod
    def init_client(cls, host, distance, connection_params: dict, search_params: dict):
        cls.conn = psycopg.connect(**get_db_config(host, connection_params))
        register_vector(cls.conn)
        cls.cur = cls.conn.cursor()
        cls.distance = distance
        cls.search_params = search_params["search_params"]

    @classmethod
    def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
        cls.cur.execute(f"SET hnsw.ef_search = {cls.search_params['hnsw_ef']}")

        if cls.distance == Distance.COSINE:
            query = f"SELECT id, embedding <=> %s AS _score FROM items ORDER BY _score LIMIT {top};"
        elif cls.distance == Distance.L2:
            query = f"SELECT id, embedding <-> %s AS _score FROM items ORDER BY _score LIMIT {top};"
        else:
            raise NotImplementedError(f"Unsupported distance metric {cls.distance}")

        cls.cur.execute(
            query,
            (np.array(vector),),
        )
        return cls.cur.fetchall()

    @classmethod
    def delete_client(cls):
        if cls.cur:
            cls.cur.close()
            cls.conn.close()
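End to end, each query sets hnsw.ef_search and then orders by the distance operator (<=> for cosine, <-> for L2). A minimal usage sketch (the hnsw_ef value and the 768-dim random query vector are illustrative, and the items table must already exist):

import numpy as np

from engine.base_client.distances import Distance
from engine.clients.pgvector.search import PgVectorSearcher

PgVectorSearcher.init_client(
    host="localhost",
    distance=Distance.COSINE,
    connection_params={},
    search_params={"search_params": {"hnsw_ef": 128}},
)

hits = PgVectorSearcher.search_one(np.random.rand(768).astype(np.float32), None, top=10)
for row_id, score in hits:
    print(row_id, score)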
38 changes: 38 additions & 0 deletions engine/clients/pgvector/upload.py
@@ -0,0 +1,38 @@
from typing import List, Optional

import numpy as np
import psycopg
from pgvector.psycopg import register_vector

from engine.base_client.upload import BaseUploader
from engine.clients.pgvector.config import get_db_config


class PgVectorUploader(BaseUploader):
    conn = None
    cur = None
    upload_params = {}

    @classmethod
    def init_client(cls, host, distance, connection_params, upload_params):
        cls.conn = psycopg.connect(**get_db_config(host, connection_params))
        register_vector(cls.conn)
        cls.cur = cls.conn.cursor()
        cls.upload_params = upload_params

    @classmethod
    def upload_batch(
        cls, ids: List[int], vectors: List[list], metadata: Optional[List[dict]]
    ):
        vectors = np.array(vectors)

        # COPY is faster than INSERT
        with cls.cur.copy("COPY items (id, embedding) FROM STDIN") as copy:
            for i, embedding in zip(ids, vectors):
                copy.write_row((i, embedding))

    @classmethod
    def delete_client(cls):
        if cls.cur:
            cls.cur.close()
            cls.conn.close()
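Upload streams rows through COPY ... FROM STDIN rather than row-by-row INSERTs, which is substantially faster for bulk loads; with autocommit enabled the batch is committed when the copy block exits. A usage sketch (batch shape is illustrative):

import numpy as np

from engine.clients.pgvector.upload import PgVectorUploader

PgVectorUploader.init_client(
    host="localhost", distance=None, connection_params={}, upload_params={}
)

ids = list(range(3))
vectors = np.random.rand(3, 768).astype(np.float32)
PgVectorUploader.upload_batch(ids, vectors, metadata=None)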
23 changes: 23 additions & 0 deletions engine/servers/pgvector-single-node/docker-compose.yaml
@@ -0,0 +1,23 @@
version: '3.7'

services:
  pgvector:
    container_name: pgvector
    image: ankane/pgvector:v0.5.1
    environment:
      - POSTGRES_DB=postgres
      - POSTGRES_USER=postgres
      - POSTGRES_PASSWORD=passwd
      - POSTGRES_HOST_AUTH_METHOD=trust
      - POSTGRES_MAX_CONNECTIONS=200
    ports:
      - 5432:5432
    logging:
      driver: "json-file"
      options:
        max-file: 1
        max-size: 10m
    deploy:
      resources:
        limits:
          memory: 25Gb
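Once the server is up (docker compose up -d), a quick check that the container answers and the extension is available (sketch, using the defaults from config.py; the extension itself is created later by PgVectorConfigurator):

import psycopg

from engine.clients.pgvector.config import get_db_config

with psycopg.connect(**get_db_config("localhost", {})) as conn:
    row = conn.execute(
        "SELECT extversion FROM pg_extension WHERE extname = 'vector';"
    ).fetchone()
    print("pgvector:", row[0] if row else "extension not created yet")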