Skip to content

Commit

Permalink
support complex index params for pgvecto.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
BeautyyuYanli committed Jan 12, 2024
1 parent b33309f commit 6c35fcb
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 26 deletions.
67 changes: 45 additions & 22 deletions engine/clients/pgvector/upload.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
import toml
from typing import List, Optional
import psycopg2
from engine.base_client import BaseUploader
Expand All @@ -13,28 +14,45 @@ class PGVectorUploader(BaseUploader):
vector_count: int = None

@classmethod
def init_client(cls, host, distance, vector_count, connection_params, upload_params,
extra_columns_name: list, extra_columns_type: list):
database, host, port, user, password = process_connection_params(connection_params, host)
cls.conn = psycopg2.connect(database=database, user=user, password=password, host=host, port=port)
def init_client(
cls,
host,
distance,
vector_count,
connection_params,
upload_params,
extra_columns_name: list,
extra_columns_type: list,
):
database, host, port, user, password = process_connection_params(
connection_params, host
)
cls.conn = psycopg2.connect(
database=database, user=user, password=password, host=host, port=port
)
cls.host = host
cls.upload_params = upload_params
cls.engine_type = upload_params.get("engine_type", "c")
cls.distance = DISTANCE_MAPPING_CREATE[distance] if cls.engine_type == "c" else DISTANCE_MAPPING_CREATE_RUST[
distance]
cls.distance = (
DISTANCE_MAPPING_CREATE[distance]
if cls.engine_type == "c"
else DISTANCE_MAPPING_CREATE_RUST[distance]
)
cls.vector_count = vector_count

@classmethod
def upload_batch(cls, ids: List[int], vectors: List[list], metadata: List[Optional[dict]]):
def upload_batch(
cls, ids: List[int], vectors: List[list], metadata: List[Optional[dict]]
):
if len(ids) != len(vectors):
raise RuntimeError("PGVector batch upload unhealthy")
# Getting the names of structured data columns based on the first meta information.
col_name_tuple = ('id', 'vector')
col_type_tuple = ('%s', '%s::real[]')
col_name_tuple = ("id", "vector")
col_type_tuple = ("%s", "%s::real[]")
if metadata[0] is not None:
for col_name in list(metadata[0].keys()):
col_name_tuple += (col_name,)
col_type_tuple += ('%s',)
col_type_tuple += ("%s",)

insert_data = []
for i in range(0, len(ids)):
Expand All @@ -43,7 +61,9 @@ def upload_batch(cls, ids: List[int], vectors: List[list], metadata: List[Option
for col_name in list(metadata[i].keys()):
value = metadata[i][col_name]
# Determining if the data is a dictionary type of latitude and longitude.
if isinstance(value, dict) and ('lon' and 'lat') in list(value.keys()):
if isinstance(value, dict) and ("lon" and "lat") in list(
value.keys()
):
raise RuntimeError("Postgres doesn't support geo datasets")
else:
temp_tuple += (value,)
Expand All @@ -63,21 +83,22 @@ def upload_batch(cls, ids: List[int], vectors: List[list], metadata: List[Option

@classmethod
def post_upload(cls, distance):
index_options_c = ""
index_options_rust = ""
for key in cls.upload_params.get("index_params", {}).keys():
index_options_c += ("{}={}" if index_options_c == "" else ", {}={}").format(
key, cls.upload_params.get('index_params', {})[key])
index_options_rust += ("{}={}" if index_options_rust == "" else "\n{}={}").format(
key, cls.upload_params.get('index_params', {})[key])
create_index_command = f"CREATE INDEX ON {PGVECTOR_INDEX} USING hnsw (vector {cls.distance}) WITH ({index_options_c});"
if cls.engine_type == "rust":
if cls.engine_type == "c":
index_options_c = ""
for key in cls.upload_params.get("index_params", {}).keys():
index_options_c += (
"{}={}" if index_options_c == "" else ", {}={}"
).format(key, cls.upload_params.get("index_params", {})[key])
create_index_command = f"CREATE INDEX ON {PGVECTOR_INDEX} USING hnsw (vector {cls.distance}) WITH ({index_options_c});"
elif cls.engine_type == "rust":
index_options_rust = toml.dumps(cls.upload_params.get("index_params", {}))
create_index_command = f"""
CREATE INDEX ON {PGVECTOR_INDEX} USING vectors (vector {cls.distance}) WITH (options=$$
[indexing.hnsw]
{index_options_rust}
$$);
"""
else:
raise ValueError("PGVector engine type must be c or rust")

# create index (blocking)
with cls.conn.cursor() as cur:
Expand All @@ -86,5 +107,7 @@ def post_upload(cls, distance):
cls.conn.commit()
# wait index finished
with cls.conn.cursor() as cur:
cur.execute("SELECT phase, tuples_done, tuples_total FROM pg_stat_progress_create_index;")
cur.execute(
"SELECT phase, tuples_done, tuples_total FROM pg_stat_progress_create_index;"
)
cls.conn.commit()
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,12 @@
"parallel": 16,
"batch_size": 64,
"index_params": {
"m": 12,
"ef_construction": 100
"indexing": {
"hnsw": {
"m": 12,
"ef_construction": 100
}
}
},
"index_type": "hnsw",
"engine_type": "rust"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,12 @@
"parallel": 16,
"batch_size": 64,
"index_params": {
"m": 12,
"ef_construction": 100
"indexing": {
"hnsw": {
"m": 12,
"ef_construction": 100
}
}
},
"index_type": "hnsw",
"engine_type": "rust"
Expand Down

0 comments on commit 6c35fcb

Please sign in to comment.