From 6c35fcbfca45d814ee5b47c8115e4e479bd30505 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9B=90=E7=B2=92=20Yanli?= Date: Fri, 12 Jan 2024 17:58:41 +0800 Subject: [PATCH] support complex index params for pgvecto.rs --- engine/clients/pgvector/upload.py | 67 +++++++++++++------ ...rust_HNSW_single_node_laion-768-5m-ip.json | 8 ++- ...ngle_node_laion-768-5m-probability-ip.json | 8 ++- 3 files changed, 57 insertions(+), 26 deletions(-) diff --git a/engine/clients/pgvector/upload.py b/engine/clients/pgvector/upload.py index fc5e6f5..5510853 100644 --- a/engine/clients/pgvector/upload.py +++ b/engine/clients/pgvector/upload.py @@ -1,4 +1,5 @@ import time +import toml from typing import List, Optional import psycopg2 from engine.base_client import BaseUploader @@ -13,28 +14,45 @@ class PGVectorUploader(BaseUploader): vector_count: int = None @classmethod - def init_client(cls, host, distance, vector_count, connection_params, upload_params, - extra_columns_name: list, extra_columns_type: list): - database, host, port, user, password = process_connection_params(connection_params, host) - cls.conn = psycopg2.connect(database=database, user=user, password=password, host=host, port=port) + def init_client( + cls, + host, + distance, + vector_count, + connection_params, + upload_params, + extra_columns_name: list, + extra_columns_type: list, + ): + database, host, port, user, password = process_connection_params( + connection_params, host + ) + cls.conn = psycopg2.connect( + database=database, user=user, password=password, host=host, port=port + ) cls.host = host cls.upload_params = upload_params cls.engine_type = upload_params.get("engine_type", "c") - cls.distance = DISTANCE_MAPPING_CREATE[distance] if cls.engine_type == "c" else DISTANCE_MAPPING_CREATE_RUST[ - distance] + cls.distance = ( + DISTANCE_MAPPING_CREATE[distance] + if cls.engine_type == "c" + else DISTANCE_MAPPING_CREATE_RUST[distance] + ) cls.vector_count = vector_count @classmethod - def upload_batch(cls, ids: List[int], vectors: List[list], metadata: List[Optional[dict]]): + def upload_batch( + cls, ids: List[int], vectors: List[list], metadata: List[Optional[dict]] + ): if len(ids) != len(vectors): raise RuntimeError("PGVector batch upload unhealthy") # Getting the names of structured data columns based on the first meta information. - col_name_tuple = ('id', 'vector') - col_type_tuple = ('%s', '%s::real[]') + col_name_tuple = ("id", "vector") + col_type_tuple = ("%s", "%s::real[]") if metadata[0] is not None: for col_name in list(metadata[0].keys()): col_name_tuple += (col_name,) - col_type_tuple += ('%s',) + col_type_tuple += ("%s",) insert_data = [] for i in range(0, len(ids)): @@ -43,7 +61,9 @@ def upload_batch(cls, ids: List[int], vectors: List[list], metadata: List[Option for col_name in list(metadata[i].keys()): value = metadata[i][col_name] # Determining if the data is a dictionary type of latitude and longitude. - if isinstance(value, dict) and ('lon' and 'lat') in list(value.keys()): + if isinstance(value, dict) and ("lon" and "lat") in list( + value.keys() + ): raise RuntimeError("Postgres doesn't support geo datasets") else: temp_tuple += (value,) @@ -63,21 +83,22 @@ def upload_batch(cls, ids: List[int], vectors: List[list], metadata: List[Option @classmethod def post_upload(cls, distance): - index_options_c = "" - index_options_rust = "" - for key in cls.upload_params.get("index_params", {}).keys(): - index_options_c += ("{}={}" if index_options_c == "" else ", {}={}").format( - key, cls.upload_params.get('index_params', {})[key]) - index_options_rust += ("{}={}" if index_options_rust == "" else "\n{}={}").format( - key, cls.upload_params.get('index_params', {})[key]) - create_index_command = f"CREATE INDEX ON {PGVECTOR_INDEX} USING hnsw (vector {cls.distance}) WITH ({index_options_c});" - if cls.engine_type == "rust": + if cls.engine_type == "c": + index_options_c = "" + for key in cls.upload_params.get("index_params", {}).keys(): + index_options_c += ( + "{}={}" if index_options_c == "" else ", {}={}" + ).format(key, cls.upload_params.get("index_params", {})[key]) + create_index_command = f"CREATE INDEX ON {PGVECTOR_INDEX} USING hnsw (vector {cls.distance}) WITH ({index_options_c});" + elif cls.engine_type == "rust": + index_options_rust = toml.dumps(cls.upload_params.get("index_params", {})) create_index_command = f""" CREATE INDEX ON {PGVECTOR_INDEX} USING vectors (vector {cls.distance}) WITH (options=$$ -[indexing.hnsw] {index_options_rust} $$); """ + else: + raise ValueError("PGVector engine type must be c or rust") # create index (blocking) with cls.conn.cursor() as cur: @@ -86,5 +107,7 @@ def post_upload(cls, distance): cls.conn.commit() # wait index finished with cls.conn.cursor() as cur: - cur.execute("SELECT phase, tuples_done, tuples_total FROM pg_stat_progress_create_index;") + cur.execute( + "SELECT phase, tuples_done, tuples_total FROM pg_stat_progress_create_index;" + ) cls.conn.commit() diff --git a/experiments/needs_editing/pgvector_rust_HNSW_single_node_laion-768-5m-ip.json b/experiments/needs_editing/pgvector_rust_HNSW_single_node_laion-768-5m-ip.json index e18515c..34e7b35 100644 --- a/experiments/needs_editing/pgvector_rust_HNSW_single_node_laion-768-5m-ip.json +++ b/experiments/needs_editing/pgvector_rust_HNSW_single_node_laion-768-5m-ip.json @@ -53,8 +53,12 @@ "parallel": 16, "batch_size": 64, "index_params": { - "m": 12, - "ef_construction": 100 + "indexing": { + "hnsw": { + "m": 12, + "ef_construction": 100 + } + } }, "index_type": "hnsw", "engine_type": "rust" diff --git a/experiments/needs_editing/pgvector_rust_HNSW_single_node_laion-768-5m-probability-ip.json b/experiments/needs_editing/pgvector_rust_HNSW_single_node_laion-768-5m-probability-ip.json index 7850413..c890020 100644 --- a/experiments/needs_editing/pgvector_rust_HNSW_single_node_laion-768-5m-probability-ip.json +++ b/experiments/needs_editing/pgvector_rust_HNSW_single_node_laion-768-5m-probability-ip.json @@ -105,8 +105,12 @@ "parallel": 16, "batch_size": 64, "index_params": { - "m": 12, - "ef_construction": 100 + "indexing": { + "hnsw": { + "m": 12, + "ef_construction": 100 + } + } }, "index_type": "hnsw", "engine_type": "rust"