diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py index fe2e554f..d34a8ea4 100644 --- a/vectordb_bench/backend/clients/api.py +++ b/vectordb_bench/backend/clients/api.py @@ -23,7 +23,7 @@ class IndexType(str, Enum): IVFSQ8 = "IVF_SQ8" Flat = "FLAT" AUTOINDEX = "AUTOINDEX" - ES_HNSW = "hnsw" + ES_HNSW = "ybhnsw" ES_IVFFlat = "ivfflat" GPU_IVF_FLAT = "GPU_IVF_FLAT" GPU_IVF_PQ = "GPU_IVF_PQ" diff --git a/vectordb_bench/backend/clients/pgvector/cli.py b/vectordb_bench/backend/clients/pgvector/cli.py index ef8914be..5c2920f9 100644 --- a/vectordb_bench/backend/clients/pgvector/cli.py +++ b/vectordb_bench/backend/clients/pgvector/cli.py @@ -118,7 +118,28 @@ class PgVectorTypedDict(CommonTypedDict): callback=set_default_quantized_fetch_limit, ) ] - + create_index_before_load: Annotated[ + Optional[bool], + click.option( + "--create_index_before_load", + type=bool, + help="Create index before load", + default=True, + required=False, + show_default=True, + ), + ] + create_index_after_load: Annotated[ + Optional[bool], + click.option( + "--create_index_after_load", + type=bool, + help="Create index after load", + default=False, + required=False, + show_default=True, + ), + ] class PgVectorIVFFlatTypedDict(PgVectorTypedDict, IVFFlatTypedDict): @@ -151,6 +172,8 @@ def PgVectorIVFFlat( reranking=parameters["reranking"], reranking_metric=parameters["reranking_metric"], quantized_fetch_limit=parameters["quantized_fetch_limit"], + create_index_before_load=parameters["create_index_before_load"], + create_index_after_load=parameters["create_index_after_load"], ), **parameters, ) @@ -188,6 +211,8 @@ def PgVectorHNSW( reranking=parameters["reranking"], reranking_metric=parameters["reranking_metric"], quantized_fetch_limit=parameters["quantized_fetch_limit"], + create_index_before_load=parameters["create_index_before_load"], + create_index_after_load=parameters["create_index_after_load"], ), **parameters, ) diff --git a/vectordb_bench/backend/clients/pgvector/config.py b/vectordb_bench/backend/clients/pgvector/config.py index 16d54744..a21dc222 100644 --- a/vectordb_bench/backend/clients/pgvector/config.py +++ b/vectordb_bench/backend/clients/pgvector/config.py @@ -74,7 +74,7 @@ def parse_metric(self) -> str: return "vector_l2_ops" elif self.metric_type == MetricType.IP: return "vector_ip_ops" - return "vector_cosine_ops" + return "vector_l2_ops" def parse_metric_fun_op(self) -> LiteralString: if self.quantization_type == "bit": @@ -171,6 +171,8 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig): reranking: Optional[bool] = None quantized_fetch_limit: Optional[int] = None reranking_metric: Optional[str] = None + create_index_before_load: Optional[bool] = True + create_index_after_load: Optional[bool] = True def index_param(self) -> PgVectorIndexParam: index_parameters = {"lists": self.lists} @@ -221,6 +223,8 @@ class PgVectorHNSWConfig(PgVectorIndexConfig): reranking: Optional[bool] = None quantized_fetch_limit: Optional[int] = None reranking_metric: Optional[str] = None + create_index_before_load: Optional[bool] = True + create_index_after_load: Optional[bool] = False def index_param(self) -> PgVectorIndexParam: index_parameters = {"m": self.m, "ef_construction": self.ef_construction} diff --git a/vectordb_bench/backend/clients/pgvector/pgvector.py b/vectordb_bench/backend/clients/pgvector/pgvector.py index 069b8938..587291b6 100644 --- a/vectordb_bench/backend/clients/pgvector/pgvector.py +++ b/vectordb_bench/backend/clients/pgvector/pgvector.py @@ -387,11 +387,11 @@ def _create_table(self, dim: int): "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));" ).format(table_name=sql.Identifier(self.table_name), dim=dim) ) - self.cursor.execute( - sql.SQL( - "ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;" - ).format(table_name=sql.Identifier(self.table_name)) - ) + # self.cursor.execute( + # sql.SQL( + # "ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;" + # ).format(table_name=sql.Identifier(self.table_name)) + # ) self.conn.commit() except Exception as e: log.warning(