diff --git a/pyproject.toml b/pyproject.toml index aafe70750..03349ebdf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,10 +35,11 @@ dependencies = [ "psutil", "polars", "plotly", - "environs", + "environs<14.1.0", "pydantic PgVectorIndexParam: index_parameters = {"lists": self.lists} @@ -215,6 +217,8 @@ class PgVectorHNSWConfig(PgVectorIndexConfig): reranking: bool | None = None quantized_fetch_limit: int | None = None reranking_metric: str | None = None + create_index_before_load: bool | None = True + create_index_after_load: bool | None = False def index_param(self) -> PgVectorIndexParam: index_parameters = {"m": self.m, "ef_construction": self.ef_construction} diff --git a/vectordb_bench/backend/clients/pgvector/pgvector.py b/vectordb_bench/backend/clients/pgvector/pgvector.py index 4164461fb..61f030cde 100644 --- a/vectordb_bench/backend/clients/pgvector/pgvector.py +++ b/vectordb_bench/backend/clients/pgvector/pgvector.py @@ -374,11 +374,11 @@ def _create_table(self, dim: int): "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));", ).format(table_name=sql.Identifier(self.table_name), dim=dim), ) - self.cursor.execute( - sql.SQL( - "ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;", - ).format(table_name=sql.Identifier(self.table_name)), - ) + # self.cursor.execute( + # sql.SQL( + # "ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;", + # ).format(table_name=sql.Identifier(self.table_name)), + # ) self.conn.commit() except Exception as e: log.warning(f"Failed to create pgvector table: {self.table_name} error: {e}") diff --git a/vectordb_bench/backend/dataset.py b/vectordb_bench/backend/dataset.py index 62700b0fa..b53dabfa9 100644 --- a/vectordb_bench/backend/dataset.py +++ b/vectordb_bench/backend/dataset.py @@ -202,6 +202,7 @@ def prepare( self, source: DatasetSource = DatasetSource.S3, filters: float | str | None = None, + load_train_data: bool = True, ) -> bool: """Download the dataset from DatasetSource url = f"{source}/{self.data.dir_name}" @@ -210,6 +211,7 @@ def prepare( source(DatasetSource): S3 or AliyunOSS, default as S3 filters(Optional[int | float | str]): combined with dataset's with_gt to compose the correct ground_truth file + load_train_data(bool): whether to download train files, default True Returns: bool: whether the dataset is successfully prepared @@ -217,15 +219,19 @@ def prepare( """ file_count, use_shuffled = self.data.file_count, self.data.use_shuffled - train_files = utils.compose_train_files(file_count, use_shuffled) - all_files = train_files + all_files = [] + + # Only include train files if load_train_data is True + if load_train_data: + train_files = utils.compose_train_files(file_count, use_shuffled) + all_files.extend(train_files) gt_file, test_file = None, None if self.data.with_gt: gt_file, test_file = utils.compose_gt_file(filters), "test.parquet" all_files.extend([gt_file, test_file]) - if not self.data.is_custom: + if not self.data.is_custom and all_files: source.reader().read( dataset=self.data.dir_name.lower(), files=all_files, diff --git a/vectordb_bench/backend/task_runner.py b/vectordb_bench/backend/task_runner.py index 2a583b4f5..de7228614 100644 --- a/vectordb_bench/backend/task_runner.py +++ b/vectordb_bench/backend/task_runner.py @@ -96,7 +96,9 @@ def init_db(self, drop_old: bool = True) -> None: def _pre_run(self, drop_old: bool = True): try: self.init_db(drop_old) - self.ca.dataset.prepare(self.dataset_source, filters=self.ca.filter_rate) + # Only download train data if LOAD stage is enabled + load_train_data = TaskStage.LOAD in self.config.stages + self.ca.dataset.prepare(self.dataset_source, filters=self.ca.filter_rate, load_train_data=load_train_data) except ModuleNotFoundError as e: log.warning(f"pre run case error: please install client for db: {self.config.db}, error={e}") raise e from None