Skip to content

Commit 5409914

Browse files
authored
Fix: dynamically increase query params for higher k values (#131)
* fix: return enough results if k > ncells*32 * fix: increase both ndocs and ncells to match k * chore: prepare release * linting * chore: saner ncells for larger datasets * linting
1 parent 9ef207d commit 5409914

File tree

3 files changed

+26
-6
lines changed

3 files changed

+26
-6
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "RAGatouille"
3-
version = "0.0.6c1"
3+
version = "0.0.6c2"
44
description = "Library to facilitate the use of state-of-the-art retrieval models in common RAG contexts."
55
authors = ["Benjamin Clavie <[email protected]>"]
66
license = "Apache-2.0"

ragatouille/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.0.6c1"
1+
__version__ = "0.0.6c2"
22
from .RAGPretrainedModel import RAGPretrainedModel
33
from .RAGTrainer import RAGTrainer
44

ragatouille/models/colbert.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -422,14 +422,14 @@ def _load_searcher(
422422
)
423423

424424
if not force_fast:
425+
self.searcher.configure(ndocs=1024)
426+
self.searcher.configure(ncells=16)
425427
if len(self.searcher.collection) < 10000:
426-
self.searcher.configure(ncells=4)
428+
self.searcher.configure(ncells=8)
427429
self.searcher.configure(centroid_score_threshold=0.4)
428-
self.searcher.configure(ndocs=512)
429430
elif len(self.searcher.collection) < 100000:
430-
self.searcher.configure(ncells=2)
431+
self.searcher.configure(ncells=4)
431432
self.searcher.configure(centroid_score_threshold=0.45)
432-
self.searcher.configure(ndocs=1024)
433433
# Otherwise, use defaults for k
434434
else:
435435
# Use fast settingss
@@ -459,6 +459,22 @@ def search(
459459
for doc_id in doc_ids:
460460
pids.extend(self.docid_pid_map[doc_id])
461461

462+
base_ncells = self.searcher.config.ncells
463+
base_ndocs = self.searcher.config.ndocs
464+
465+
if k > len(self.searcher.collection):
466+
print(
467+
"WARNING: k value is larger than the number of documents in the index!",
468+
f"Lowering k to {len(self.searcher.collection)}...",
469+
)
470+
k = len(self.searcher.collection)
471+
472+
# For smaller collections, we need a higher ncells value to ensure we return enough results
473+
if k > (32 * self.searcher.config.ncells):
474+
self.searcher.configure(ncells=min((k // 32 + 2), base_ncells))
475+
476+
self.searcher.configure(ndocs=max(k * 4, base_ndocs))
477+
462478
if isinstance(query, str):
463479
results = [self._search(query, k, pids)]
464480
else:
@@ -487,6 +503,10 @@ def search(
487503

488504
to_return.append(result_for_query)
489505

506+
# Restore original ncells&ndocs if it had to be changed for large k values
507+
self.searcher.configure(ncells=base_ncells)
508+
self.searcher.configure(ndocs=base_ndocs)
509+
490510
if len(to_return) == 1:
491511
return to_return[0]
492512
return to_return

0 commit comments

Comments
 (0)