Skip to content

Commit d27b693

Browse files
authored
Feat/indexing faissless (#173)
* fix: fix searcher always being reloaded * feat: implement torch kmeans * chore: lower cutoff * chore: move warning * chore: higher kmeans batch size * chore: argument support * chore: restore all default behaviour when `use_faiss` is True after having been false * chore: lint * chore: print exception if one occurs when using pytorch indexing * chore: make _original_train_kmeans robust to subsequent calls * nit: comment feat: rework kmeans to be closer to FAISS chore: store kmeans functions as class attributes fix: method assignment chore: more memory efficient lint chore: lower bsize, results unaffected feat: better batching, slower max doc count chore: batch size safe for 8gb GPUs chore: more elaborate warning chore: use external lib to support minibatching, revert to homebrew later * poetry lock * lint * chore: better batch size * 0.0.8 dependency prep
1 parent f8c53cb commit d27b693

File tree

8 files changed

+1954
-1163
lines changed

8 files changed

+1954
-1163
lines changed

.github/workflows/test.yml

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,14 @@ jobs:
1717
with:
1818
python-version: 3.9
1919

20-
- name: Cache Poetry virtualenv
21-
uses: actions/cache@v3
22-
with:
23-
path: ~/.cache/pypoetry/virtualenvs
24-
key: ${{ runner.os }}-poetry-${{ hashFiles('**/poetry.lock') }}
25-
restore-keys: |
26-
${{ runner.os }}-poetry-
27-
2820
- name: Install Poetry
2921
uses: snok/[email protected]
3022

23+
- name: Clean poetry
24+
run: rm poetry.lock
25+
3126
- name: Install dependencies
32-
run: poetry install --with dev
27+
run: poetry install --with dev --no-cache
3328

3429
- name: Run tests
3530
run: poetry run pytest tests/

poetry.lock

Lines changed: 1540 additions & 1121 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "RAGatouille"
3-
version = "0.0.7post11"
3+
version = "0.0.8"
44
description = "Library to facilitate the use of state-of-the-art retrieval models in common RAG contexts."
55
authors = ["Benjamin Clavie <[email protected]>"]
66
license = "Apache-2.0"
@@ -9,20 +9,19 @@ packages = [{include = "ragatouille"}]
99
repository = "https://github.com/bclavie/ragatouille"
1010

1111
[tool.poetry.dependencies]
12-
python = ">=3.8.1,<4.0"
13-
ruff = "^0.1.9"
12+
python = ">=3.9,<4.0"
1413
faiss-cpu = "^1.7.4"
1514
transformers = "^4.36.2"
1615
voyager = "^2.0.2"
17-
aiohttp = "3.9.1"
1816
sentence-transformers = "^2.2.2"
19-
torch = "^2.0.1"
20-
llama-index = "^0.9.24"
17+
torch = ">=1.13"
18+
llama-index = ">=0.7"
2119
langchain_core = "^0.1.4"
2220
colbert-ai = "0.2.19"
2321
langchain = "^0.1.0"
2422
onnx = "^1.15.0"
2523
srsly = "2.4.8"
24+
fast-pytorch-kmeans= "0.2.0.1"
2625

2726
[tool.poetry.group.dev.dependencies]
2827
pytest = "^7.4.0"

ragatouille/RAGPretrainedModel.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def index(
180180
document_splitter_fn: Optional[Callable] = llama_index_sentence_splitter,
181181
preprocessing_fn: Optional[Union[Callable, list[Callable]]] = None,
182182
bsize: int = 32,
183+
use_faiss: bool = False,
183184
):
184185
"""Build an index from a list of documents.
185186
@@ -215,6 +216,7 @@ def index(
215216
max_document_length=max_document_length,
216217
overwrite=overwrite_index,
217218
bsize=bsize,
219+
use_faiss=use_faiss,
218220
)
219221

220222
def add_to_index(
@@ -227,6 +229,7 @@ def add_to_index(
227229
document_splitter_fn: Optional[Callable] = llama_index_sentence_splitter,
228230
preprocessing_fn: Optional[Union[Callable, list[Callable]]] = None,
229231
bsize: int = 32,
232+
use_faiss: bool = False,
230233
):
231234
"""Add documents to an existing index.
232235
@@ -258,6 +261,7 @@ def add_to_index(
258261
new_docid_metadata_map=new_docid_metadata_map,
259262
index_name=index_name,
260263
bsize=bsize,
264+
use_faiss=use_faiss,
261265
)
262266

263267
def delete_from_index(

ragatouille/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.0.7post11"
1+
__version__ = "0.0.8"
22
from .RAGPretrainedModel import RAGPretrainedModel
33
from .RAGTrainer import RAGTrainer
44

ragatouille/models/colbert.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(
3535
self.pid_docid_map = None
3636
self.docid_pid_map = None
3737
self.docid_metadata_map = None
38-
self.base_model_max_tokens = 512
38+
self.base_model_max_tokens = 510
3939
if n_gpu == -1:
4040
n_gpu = 1 if torch.cuda.device_count() == 0 else torch.cuda.device_count()
4141

@@ -86,7 +86,7 @@ def __init__(
8686
)
8787
self.base_model_max_tokens = (
8888
self.inference_ckpt.bert.config.max_position_embeddings
89-
)
89+
) - 4
9090

9191
self.run_context = Run().context(self.run_config)
9292
self.run_context.__enter__() # Manually enter the context
@@ -125,6 +125,7 @@ def add_to_index(
125125
new_docid_metadata_map: Optional[List[dict]] = None,
126126
index_name: Optional[str] = None,
127127
bsize: int = 32,
128+
use_faiss: bool = False,
128129
):
129130
self.index_name = index_name if index_name is not None else self.index_name
130131
if self.index_name is None:
@@ -181,6 +182,7 @@ def add_to_index(
181182
new_collection,
182183
verbose=self.verbose != 0,
183184
bsize=bsize,
185+
use_faiss=use_faiss,
184186
)
185187
self.config = self.model_index.config
186188

@@ -294,6 +296,7 @@ def index(
294296
max_document_length: int = 256,
295297
overwrite: Union[bool, str] = "reuse",
296298
bsize: int = 32,
299+
use_faiss: bool = False,
297300
):
298301
self.collection = collection
299302
self.config.doc_maxlen = max_document_length
@@ -341,6 +344,7 @@ def index(
341344
overwrite,
342345
verbose=self.verbose != 0,
343346
bsize=bsize,
347+
use_faiss=use_faiss,
344348
)
345349
self.config = self.model_index.config
346350
self._save_index_metadata()
@@ -494,7 +498,11 @@ def _set_inference_max_tokens(
494498
not hasattr(self, "inference_ckpt_len_set")
495499
or self.inference_ckpt_len_set is False
496500
):
497-
if max_tokens == "auto" or max_tokens > self.base_model_max_tokens:
501+
if max_tokens == "auto":
502+
max_tokens = self.base_model_max_tokens
503+
else:
504+
max_tokens = int(max_tokens)
505+
if max_tokens > self.base_model_max_tokens:
498506
max_tokens = self.base_model_max_tokens
499507
percentile_90 = np.percentile(
500508
[len(x.split(" ")) for x in documents], 90
@@ -504,6 +512,7 @@ def _set_inference_max_tokens(
504512
self.base_model_max_tokens,
505513
)
506514
max_tokens = max(256, max_tokens)
515+
507516
if max_tokens > 300:
508517
print(
509518
f"Your documents are roughly {percentile_90} tokens long at the 90th percentile!",

ragatouille/models/index.py

Lines changed: 68 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
from abc import ABC, abstractmethod
2+
from copy import deepcopy
23
from pathlib import Path
34
from time import time
45
from typing import Any, List, Literal, Optional, TypeVar, Union
56

67
import srsly
78
import torch
89
from colbert import Indexer, IndexUpdater, Searcher
10+
from colbert.indexing.collection_indexer import CollectionIndexer
911
from colbert.infra import ColBERTConfig
1012

13+
from ragatouille.models import torch_kmeans
14+
1115
IndexType = Literal["FLAT", "HNSW", "PLAID"]
1216

1317

@@ -126,6 +130,8 @@ class HNSWModelIndex(ModelIndex):
126130
class PLAIDModelIndex(ModelIndex):
127131
_DEFAULT_INDEX_BSIZE = 32
128132
index_type = "PLAID"
133+
faiss_kmeans = staticmethod(deepcopy(CollectionIndexer._train_kmeans))
134+
pytorch_kmeans = staticmethod(torch_kmeans._train_kmeans)
129135

130136
def __init__(self, config: ColBERTConfig) -> None:
131137
super().__init__(config)
@@ -168,21 +174,6 @@ def build(
168174
bsize = kwargs.get("bsize", PLAIDModelIndex._DEFAULT_INDEX_BSIZE)
169175
assert isinstance(bsize, int)
170176

171-
if torch.cuda.is_available():
172-
import faiss
173-
174-
if not hasattr(faiss, "StandardGpuResources"):
175-
print(
176-
"________________________________________________________________________________\n"
177-
"WARNING! You have a GPU available, but only `faiss-cpu` is currently installed.\n",
178-
"This means that indexing will be slow. To make use of your GPU.\n"
179-
"Please install `faiss-gpu` by running:\n"
180-
"pip uninstall --y faiss-cpu & pip install faiss-gpu\n",
181-
"________________________________________________________________________________",
182-
)
183-
print("Will continue with CPU indexing in 5 seconds...")
184-
time.sleep(5)
185-
186177
nbits = 2
187178
if len(collection) < 5000:
188179
nbits = 8
@@ -192,22 +183,76 @@ def build(
192183
self.config, ColBERTConfig(nbits=nbits, index_bsize=bsize)
193184
)
194185

186+
# Instruct colbert-ai to disable forking if nranks == 1
187+
self.config.avoid_fork_if_possible = True
188+
195189
if len(collection) > 100000:
196190
self.config.kmeans_niters = 4
197191
elif len(collection) > 50000:
198192
self.config.kmeans_niters = 10
199193
else:
200194
self.config.kmeans_niters = 20
201195

202-
# Instruct colbert-ai to disable forking if nranks == 1
203-
self.config.avoid_fork_if_possible = True
204-
indexer = Indexer(
205-
checkpoint=checkpoint,
206-
config=self.config,
207-
verbose=verbose,
196+
# Monkey-patch colbert-ai to avoid using FAISS
197+
monkey_patching = (
198+
len(collection) < 100000 and kwargs.get("use_faiss", False) is False
208199
)
209-
indexer.configure(avoid_fork_if_possible=True)
210-
indexer.index(name=index_name, collection=collection, overwrite=overwrite)
200+
if monkey_patching:
201+
print(
202+
"---- WARNING! You are using PLAID with an experimental replacement for FAISS for greater compatibility ----"
203+
)
204+
print("This is a behaviour change from RAGatouille 0.8.0 onwards.")
205+
print(
206+
"This works fine for most users and smallish datasets, but can be considerably slower than FAISS and could cause worse results in some situations."
207+
)
208+
print(
209+
"If you're confident with FAISS working on your machine, pass use_faiss=True to revert to the FAISS-using behaviour."
210+
)
211+
print("--------------------")
212+
CollectionIndexer._train_kmeans = self.pytorch_kmeans
213+
214+
# Try to keep runtime stable -- these are values that empirically didn't degrade performance at all on 3 benchmarks.
215+
# More tests required before warning can be removed.
216+
try:
217+
indexer = Indexer(
218+
checkpoint=checkpoint,
219+
config=self.config,
220+
verbose=verbose,
221+
)
222+
indexer.configure(avoid_fork_if_possible=True)
223+
indexer.index(
224+
name=index_name, collection=collection, overwrite=overwrite
225+
)
226+
except Exception as err:
227+
print(
228+
f"PyTorch-based indexing did not succeed with error: {err}",
229+
"! Reverting to using FAISS and attempting again...",
230+
)
231+
monkey_patching = False
232+
if monkey_patching is False:
233+
CollectionIndexer._train_kmeans = self.faiss_kmeans
234+
if torch.cuda.is_available():
235+
import faiss
236+
237+
if not hasattr(faiss, "StandardGpuResources"):
238+
print(
239+
"________________________________________________________________________________\n"
240+
"WARNING! You have a GPU available, but only `faiss-cpu` is currently installed.\n",
241+
"This means that indexing will be slow. To make use of your GPU.\n"
242+
"Please install `faiss-gpu` by running:\n"
243+
"pip uninstall --y faiss-cpu & pip install faiss-gpu\n",
244+
"________________________________________________________________________________",
245+
)
246+
print("Will continue with CPU indexing in 5 seconds...")
247+
time.sleep(5)
248+
indexer = Indexer(
249+
checkpoint=checkpoint,
250+
config=self.config,
251+
verbose=verbose,
252+
)
253+
indexer.configure(avoid_fork_if_possible=True)
254+
indexer.index(name=index_name, collection=collection, overwrite=overwrite)
255+
211256
return self
212257

213258
def _load_searcher(

0 commit comments

Comments
 (0)