Skip to content

Commit 0371014

Browse files
fix(rate_runner): ensure thread safety for pgvector in concurrent inserts (#585)
Signed-off-by: min.tian <[email protected]>
1 parent 1bbd404 commit 0371014

File tree

1 file changed

+23
-11
lines changed

1 file changed

+23
-11
lines changed

vectordb_bench/backend/runner/rate_runner.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import multiprocessing as mp
44
import time
55
from concurrent.futures import ThreadPoolExecutor
6+
from copy import deepcopy
67

78
from vectordb_bench import config
89
from vectordb_bench.backend.clients import api
10+
from vectordb_bench.backend.clients.pgvector.pgvector import PgVector
911
from vectordb_bench.backend.dataset import DataSetIterator
1012
from vectordb_bench.backend.utils import time_it
1113

@@ -33,17 +35,27 @@ def __init__(
3335
self.executing_futures = []
3436
self.sig_idx = 0
3537

36-
def send_insert_task(self, db: api.VectorDB, emb: list[list[float]], metadata: list[str], retry_idx: int = 0):
    """Insert one batch of embeddings into ``db``, retrying when an error is reported.

    ``db.insert_embeddings`` is expected to return a pair whose second item
    is an error object, or ``None`` on success. On failure a warning is
    logged and the insert is attempted again after sleeping for as many
    seconds as the number of the upcoming retry (1s, then 2s, ...).

    Args:
        db: client used to perform the insert.
        emb: embeddings to insert.
        metadata: metadata entries passed along with ``emb``.
        retry_idx: number of retries already performed for this batch.

    Raises:
        RuntimeError: once more than ``config.MAX_INSERT_RETRY`` retries
            would be needed.
    """
    _, error = db.insert_embeddings(emb, metadata)
    if error is None:
        return

    log.warning(f"Insert Failed, try_idx={retry_idx}, Exception: {error}")

    next_try = retry_idx + 1
    if next_try > config.MAX_INSERT_RETRY:
        msg = f"Insert failed and retried more than {config.MAX_INSERT_RETRY} times"
        raise RuntimeError(msg) from None

    # Back off a little longer before each successive retry.
    time.sleep(next_try)
    self.send_insert_task(db, emb=emb, metadata=metadata, retry_idx=next_try)
38+
def send_insert_task(self, db: api.VectorDB, emb: list[list[float]], metadata: list[str]):
    """Insert one batch of embeddings into ``db``, retrying when an error is reported.

    A failed insert is logged and retried after sleeping for as many seconds
    as the number of the upcoming retry (1s, then 2s, ...). ``PgVector``
    clients are deep-copied and the copy is opened through ``init()`` for
    the duration of the call, so the insert does not go through the shared
    client object.

    Args:
        db: client used to perform the insert.
        emb: embeddings to insert.
        metadata: metadata entries passed along with ``emb``.

    Raises:
        RuntimeError: once more than ``config.MAX_INSERT_RETRY`` retries
            would be needed.
    """

    def _insert_with_retry(client: api.VectorDB, vectors: list[list[float]], labels: list[str]):
        # Loop until the client reports success or the retry budget is spent.
        attempt = 0
        while True:
            _, error = client.insert_embeddings(vectors, labels)
            if error is None:
                return
            log.warning(f"Insert Failed, try_idx={attempt}, Exception: {error}")
            attempt += 1
            if attempt > config.MAX_INSERT_RETRY:
                msg = f"Insert failed and retried more than {config.MAX_INSERT_RETRY} times"
                raise RuntimeError(msg) from None
            time.sleep(attempt)

    if not isinstance(db, PgVector):
        _insert_with_retry(db, emb, metadata)
        return

    # pgvector is not thread-safe for concurrent inserts, so work on a private
    # copy of the client: each task then runs on its own connection.
    private_db = deepcopy(db)
    with private_db.init():
        _insert_with_retry(private_db, emb, metadata)
4759

4860
@time_it
4961
def run_with_rate(self, q: mp.Queue):

0 commit comments

Comments
 (0)