Skip to content

Commit d6ba9f0

Browse files
authored
Merge pull request #106 from DefangLabs/eric/make-rebuild-run-in-background
2 parents f2c419e + a46c3d8 commit d6ba9f0

File tree

4 files changed

+179
-33
lines changed

4 files changed

+179
-33
lines changed

app/app.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os
1616
import segment.analytics as analytics
1717
import uuid
18+
import threading
1819

1920
import logging
2021
import redis
@@ -117,25 +118,15 @@ def v1_ask():
117118
return jsonify({"error": "Invalid or missing Ask Token"}), 401
118119

119120

120-
@app.route("/trigger-rebuild", methods=["POST"])
121-
@csrf.exempt
122-
def trigger_rebuild():
123-
token = request.args.get("token")
124-
if token != os.getenv("REBUILD_TOKEN"):
125-
return jsonify({"error": "Unauthorized"}), 401
121+
def run_rebuild():
126122
try:
127123
print("Running get_knowledge_base.py script...")
128124
result = subprocess.run(
129125
["python3", "get_knowledge_base.py"], capture_output=True, text=True
130126
)
131127
if result.returncode != 0:
132128
print(f"Error running get_knowledge_base.py script: {result.stderr}")
133-
return jsonify(
134-
{
135-
"error": "Error running get_knowledge_base.py script",
136-
"details": result.stderr,
137-
}
138-
), 500
129+
return
139130

140131
print("Finished running get_knowledge_base.py script.")
141132

@@ -146,12 +137,7 @@ def trigger_rebuild():
146137
)
147138
if result.returncode != 0:
148139
print(f"Error running get_samples_examples.py script: {result.stderr}")
149-
return jsonify(
150-
{
151-
"error": "Error running get_samples_examples.py script",
152-
"details": result.stderr,
153-
}
154-
), 500
140+
return
155141

156142
print("Finished running get_samples_examples.py script.")
157143

@@ -160,14 +146,28 @@ def trigger_rebuild():
160146
app.rag_system.rebuild()
161147
except Exception as e:
162148
logging.error(f"Error rebuilding embeddings: {str(e)}")
163-
return jsonify({"error": "Error rebuilding embeddings"}), 500
149+
return
164150

165151
logging.info("Finished rebuilding embeddings.")
166-
return jsonify({"status": "Rebuild triggered successfully"}), 200
167152

168153
except Exception as e:
169-
print(f"Error in /trigger-rebuild endpoint: {e}")
170-
return jsonify({"error": "Internal Server Error"}), 500
154+
logging.error(f"Error in rebuild process: {e}")
155+
156+
157+
@app.route("/trigger-rebuild", methods=["POST"])
@csrf.exempt
def trigger_rebuild():
    """Authenticate the caller and start a rebuild in a background thread.

    The request must carry a ``token`` query parameter that matches the
    ``REBUILD_TOKEN`` environment variable.

    Returns:
        A ``(json, status)`` pair: 401 when the token is missing or wrong, or
        when no ``REBUILD_TOKEN`` is configured; otherwise 202 as soon as the
        background thread running ``run_rebuild`` has been started. The
        outcome of the rebuild itself is not reported to the caller.
    """
    import hmac  # function-scope import keeps this block self-contained

    expected_token = os.getenv("REBUILD_TOKEN")
    token = request.args.get("token")

    # Fail closed: with a plain `!=`, an unset REBUILD_TOKEN and a request
    # without a token compare as None == None and would be authorised.
    # compare_digest avoids leaking how much of the token matched via timing;
    # both sides are encoded so non-ASCII input cannot raise TypeError.
    authorised = (
        bool(expected_token)
        and token is not None
        and hmac.compare_digest(
            token.encode("utf-8"), expected_token.encode("utf-8")
        )
    )
    if not authorised:
        return jsonify({"error": "Unauthorized"}), 401

    # Start the rebuild in a background thread; daemon=True so the thread
    # dies when the main process dies.
    thread = threading.Thread(target=run_rebuild, daemon=True)
    thread.start()

    # Return immediately; 202 signals "accepted, still processing".
    return jsonify({"status": "Rebuild started successfully"}), 202
171171

172172

173173
@app.route("/data/<path:name>")

app/rag_system.py

Lines changed: 82 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,17 @@
88
import numpy as np
99
from sklearn.metrics.pairwise import cosine_similarity
1010
import traceback
11+
from atomicwrites import atomic_write
1112

1213
openai.api_base = os.getenv("OPENAI_BASE_URL")
1314
openai.api_key = os.getenv("OPENAI_API_KEY")
1415

1516

1617
class RAGSystem:
18+
# Cache file paths
19+
DOC_EMBEDDINGS_PATH = "./data/doc_embeddings.npy"
20+
DOC_ABOUT_EMBEDDINGS_PATH = "./data/doc_about_embeddings.npy"
21+
1722
def __init__(self, knowledge_base_path="./data/knowledge_base.json"):
1823
self.knowledge_base_path = knowledge_base_path
1924

@@ -23,29 +28,58 @@ def __init__(self, knowledge_base_path="./data/knowledge_base.json"):
2328
# load existing embeddings if available
2429
logging.info("Embedding knowledge base...")
2530

26-
if os.path.exists("./data/doc_about_embeddings.npy") and os.path.exists(
27-
"./data/doc_embeddings.npy"
31+
if os.path.exists(self.DOC_ABOUT_EMBEDDINGS_PATH) and os.path.exists(
32+
self.DOC_EMBEDDINGS_PATH
2833
):
29-
self.doc_about_embeddings = np.load("./data/doc_about_embeddings.npy")
34+
self.doc_about_embeddings = np.load(self.DOC_ABOUT_EMBEDDINGS_PATH)
3035
logging.info("Loaded existing about document about embeddings from disk.")
31-
self.doc_embeddings = np.load("./data/doc_embeddings.npy")
36+
self.doc_embeddings = np.load(self.DOC_EMBEDDINGS_PATH)
3237
logging.info("Loaded existing document embeddings from disk.")
38+
39+
# Save file timestamps when loading cache
40+
self.doc_embeddings_timestamp = os.path.getmtime(self.DOC_EMBEDDINGS_PATH)
41+
self.doc_about_embeddings_timestamp = os.path.getmtime(
42+
self.DOC_ABOUT_EMBEDDINGS_PATH
43+
)
44+
logging.info(
45+
f"Cache loaded - doc_embeddings timestamp: {self.doc_embeddings_timestamp}, doc_about_embeddings timestamp: {self.doc_about_embeddings_timestamp}"
46+
)
3347
else:
3448
self.rebuild_embeddings()
3549

3650
logging.info("Knowledge base embeddings created")
3751
self.conversation_history = []
3852

53+
def _atomic_save_numpy(self, file_path, data):
    """Serialize ``data`` with ``np.save`` through ``atomic_write``.

    Args:
        file_path: Destination path of the ``.npy`` file; an existing file
            is overwritten (``overwrite=True``).
        data: Array-like object accepted by ``np.save``.
    """
    # NOTE(review): atomic_write is expected to publish the file only when
    # the with-block exits cleanly -- confirm against the atomicwrites docs.
    writer = atomic_write(file_path, mode="wb", overwrite=True)
    with writer as target:
        np.save(target, data)
56+
3957
def rebuild_embeddings(self):
40-
logging.info("No existing document embeddings found, creating new embeddings.")
41-
self.doc_embeddings = self.embed_knowledge_base()
42-
self.doc_about_embeddings = self.embed_knowledge_base_about()
43-
# cache doc_embeddings to disk
44-
np.save("./data/doc_embeddings.npy", self.doc_embeddings.cpu().numpy())
45-
np.save(
46-
"./data/doc_about_embeddings.npy", self.doc_about_embeddings.cpu().numpy()
58+
logging.info("Rebuilding document embeddings...")
59+
60+
new_doc_embeddings = self.embed_knowledge_base()
61+
new_about_embeddings = self.embed_knowledge_base_about()
62+
63+
# Atomic saves with guaranteed order
64+
self._atomic_save_numpy(
65+
self.DOC_EMBEDDINGS_PATH, new_doc_embeddings.cpu().numpy()
66+
)
67+
self._atomic_save_numpy(
68+
self.DOC_ABOUT_EMBEDDINGS_PATH, new_about_embeddings.cpu().numpy()
4769
)
4870

71+
# Update in-memory embeddings only after successful saves
72+
self.doc_embeddings = new_doc_embeddings
73+
self.doc_about_embeddings = new_about_embeddings
74+
75+
# Update file timestamps after successful saves
76+
self.doc_embeddings_timestamp = os.path.getmtime(self.DOC_EMBEDDINGS_PATH)
77+
self.doc_about_embeddings_timestamp = os.path.getmtime(
78+
self.DOC_ABOUT_EMBEDDINGS_PATH
79+
)
80+
81+
logging.info("Embeddings rebuilt successfully.")
82+
4983
def load_knowledge_base(self):
5084
with open(self.knowledge_base_path, "r") as kb_file:
5185
return json.load(kb_file)
@@ -102,6 +136,43 @@ def compute_document_scores(
102136

103137
return result
104138

139+
def cache_check(func):
    """Decorator that checks the on-disk embedding cache before calling ``func``.

    Before every call the modification times of ``DOC_EMBEDDINGS_PATH`` and
    ``DOC_ABOUT_EMBEDDINGS_PATH`` are compared with the timestamps recorded
    on the instance. If they differ, ``self._reload_cache()`` is called. If
    the files cannot be inspected or reloaded (any ``OSError``),
    ``self.rebuild_embeddings()`` is called instead. ``func`` is then invoked
    with the original arguments and its result returned.

    Args:
        func: Method taking ``self`` as its first argument.

    Returns:
        The wrapped method, carrying ``func``'s name and docstring.
    """
    import functools  # function-scope import keeps this block self-contained

    @functools.wraps(func)  # keep func.__name__/__doc__ on the wrapper
    def wrapper(self, *args, **kwargs):
        try:
            current_times = [
                os.path.getmtime(self.DOC_EMBEDDINGS_PATH),
                os.path.getmtime(self.DOC_ABOUT_EMBEDDINGS_PATH),
            ]
            stored_times = [
                self.doc_embeddings_timestamp,
                self.doc_about_embeddings_timestamp,
            ]

            # Reload if the files changed since our last load.
            if current_times != stored_times:
                self._reload_cache()

        # FileNotFoundError and PermissionError are subclasses of OSError,
        # so catching OSError alone covers the same cases.
        except OSError:
            logging.warning("Cache files inaccessible, rebuilding...")
            self.rebuild_embeddings()

        return func(self, *args, **kwargs)

    return wrapper
164+
165+
def _reload_cache(self):
166+
self.doc_embeddings = np.load(self.DOC_EMBEDDINGS_PATH)
167+
self.doc_about_embeddings = np.load(self.DOC_ABOUT_EMBEDDINGS_PATH)
168+
169+
# update our timestamps of the cached files
170+
self.doc_embeddings_timestamp = os.path.getmtime(self.DOC_EMBEDDINGS_PATH)
171+
self.doc_about_embeddings_timestamp = os.path.getmtime(
172+
self.DOC_ABOUT_EMBEDDINGS_PATH
173+
)
174+
175+
@cache_check
105176
def retrieve(
106177
self, query, similarity_threshold=0.4, high_match_threshold=0.8, max_docs=5
107178
):

app/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ PyYAML==6.0.2
1313
GitPython==3.1.44
1414
redis==6.2.0
1515
fakeredis==2.30.1
16+
atomicwrites==1.4.1
1617

1718
# linter
1819
ruff>=0.12.5

app/test_rag_system.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22

33
from rag_system import RAGSystem
4+
import os
45

56

67
class TestRAGSystem(unittest.TestCase):
@@ -123,6 +124,79 @@ def test_compute_document_scores(self):
123124

124125
print("Test for compute_document_scores passed successfully!")
125126

127+
def test_cache_check_reload_cache(self):
    """A changed cache-file mtime must make a decorated call reload the cache."""
    original_doc_embeddings_timestamp = self.rag_system.doc_embeddings_timestamp
    original_doc_about_embeddings_timestamp = (
        self.rag_system.doc_about_embeddings_timestamp
    )

    # Capture the real callables before anything is patched.
    real_getmtime = os.path.getmtime
    real_reload_cache = self.rag_system._reload_cache

    def fake_getmtime(path):
        # Report a newer mtime for the two cache files only; every other
        # path keeps its real behaviour.
        if path == self.rag_system.DOC_EMBEDDINGS_PATH:
            return original_doc_embeddings_timestamp + 1
        if path == self.rag_system.DOC_ABOUT_EMBEDDINGS_PATH:
            return original_doc_about_embeddings_timestamp + 1
        return real_getmtime(path)

    def fake_reload_cache():
        # Record the call, then delegate so the instance state stays valid.
        self.rag_system._reload_cache_called = True
        real_reload_cache()

    self.rag_system._reload_cache_called = False
    os.path.getmtime = fake_getmtime
    self.rag_system._reload_cache = fake_reload_cache
    try:
        # Call a cache_check-decorated method.
        self.rag_system.retrieve("test query")

        self.assertTrue(
            self.rag_system._reload_cache_called,
            "Cache reload was not triggered when timestamps changed.",
        )
    finally:
        # Always undo the patches, even when the assertion fails, so the
        # global os.path.getmtime does not leak into other tests.
        os.path.getmtime = real_getmtime
        self.rag_system._reload_cache = real_reload_cache
    print("Test for cache_check reload_cache passed successfully!")
167+
168+
def test_cache_check_rebuild_embeddings_on_error(self):
    """An OSError while inspecting the cache files must trigger a rebuild."""
    # Capture the real callables BEFORE installing the fakes. Capturing
    # afterwards would bind the fake itself, so "restoring" would re-install
    # the fake and a fake that delegates would recurse without bound.
    real_getmtime = os.path.getmtime
    real_rebuild_embeddings = self.rag_system.rebuild_embeddings

    def raise_oserror(path):
        raise OSError("Simulated error")

    def fake_rebuild_embeddings():
        # Only record the call. The real rebuild reads the cache-file mtimes
        # after saving, which would raise again while getmtime is patched.
        self.rag_system._rebuild_embeddings_called = True

    self.rag_system._rebuild_embeddings_called = False
    os.path.getmtime = raise_oserror
    self.rag_system.rebuild_embeddings = fake_rebuild_embeddings
    try:
        # Call a cache_check-decorated method.
        self.rag_system.retrieve("test query")

        self.assertTrue(
            self.rag_system._rebuild_embeddings_called,
            "rebuild_embeddings was not triggered on cache access error.",
        )
    finally:
        # Always undo the patches, even when the assertion fails.
        os.path.getmtime = real_getmtime
        self.rag_system.rebuild_embeddings = real_rebuild_embeddings
    print("Test for cache_check rebuild_embeddings on error passed successfully!")
199+
126200

127201
if __name__ == "__main__":
128202
unittest.main()

0 commit comments

Comments
 (0)