3
3
import os
4
4
import sys
5
5
import logging
6
+ import threading
6
7
from datetime import date
7
8
from sentence_transformers import SentenceTransformer
8
9
import numpy as np
9
10
from sklearn .metrics .pairwise import cosine_similarity
10
11
import traceback
11
12
from atomicwrites import atomic_write
12
13
14
+
13
15
# Configure the OpenAI client from the environment: OPENAI_BASE_URL supplies the
# endpoint and OPENAI_API_KEY the credential. os.getenv returns None when unset.
openai .api_base = os .getenv ("OPENAI_BASE_URL" )
14
16
openai .api_key = os .getenv ("OPENAI_API_KEY" )
15
17
@@ -20,9 +22,10 @@ class RAGSystem:
20
22
DOC_ABOUT_EMBEDDINGS_PATH = "./data/doc_about_embeddings.npy"
21
23
22
24
def __init__ (self , knowledge_base_path = "./data/knowledge_base.json" ):
25
+ self ._update_lock = threading .Lock ()
23
26
self .knowledge_base_path = knowledge_base_path
24
27
25
- self . knowledge_base = self .load_knowledge_base ()
28
+ knowledge_base = self .load_knowledge_base ()
26
29
self .model = SentenceTransformer ("all-MiniLM-L6-v2" )
27
30
28
31
# load existing embeddings if available
@@ -31,21 +34,27 @@ def __init__(self, knowledge_base_path="./data/knowledge_base.json"):
31
34
if os .path .exists (self .DOC_ABOUT_EMBEDDINGS_PATH ) and os .path .exists (
32
35
self .DOC_EMBEDDINGS_PATH
33
36
):
34
- self .doc_about_embeddings = np .load (self .DOC_ABOUT_EMBEDDINGS_PATH )
35
- logging .info ("Loaded existing about document about embeddings from disk." )
36
- self .doc_embeddings = np .load (self .DOC_EMBEDDINGS_PATH )
37
- logging .info ("Loaded existing document embeddings from disk." )
37
+ with self ._update_lock :
38
+ self .doc_about_embeddings = np .load (self .DOC_ABOUT_EMBEDDINGS_PATH )
39
+ logging .info (
40
+ "Loaded existing about document about embeddings from disk."
41
+ )
42
+ self .doc_embeddings = np .load (self .DOC_EMBEDDINGS_PATH )
43
+ logging .info ("Loaded existing document embeddings from disk." )
44
+ self .knowledge_base = knowledge_base
38
45
39
- # Save file timestamps when loading cache
40
- self .doc_embeddings_timestamp = os .path .getmtime (self .DOC_EMBEDDINGS_PATH )
41
- self .doc_about_embeddings_timestamp = os .path .getmtime (
42
- self .DOC_ABOUT_EMBEDDINGS_PATH
43
- )
44
- logging .info (
45
- f"Cache loaded - doc_embeddings timestamp: { self .doc_embeddings_timestamp } , doc_about_embeddings timestamp: { self .doc_about_embeddings_timestamp } "
46
- )
46
+ # Save file timestamps when loading cache
47
+ self .doc_embeddings_timestamp = os .path .getmtime (
48
+ self .DOC_EMBEDDINGS_PATH
49
+ )
50
+ self .doc_about_embeddings_timestamp = os .path .getmtime (
51
+ self .DOC_ABOUT_EMBEDDINGS_PATH
52
+ )
53
+ logging .info (
54
+ f"Cache loaded - doc_embeddings timestamp: { self .doc_embeddings_timestamp } , doc_about_embeddings timestamp: { self .doc_about_embeddings_timestamp } "
55
+ )
47
56
else :
48
- self .rebuild_embeddings ()
57
+ self .rebuild_embeddings (knowledge_base )
49
58
50
59
logging .info ("Knowledge base embeddings created" )
51
60
self .conversation_history = []
@@ -54,43 +63,53 @@ def _atomic_save_numpy(self, file_path, data):
54
63
with atomic_write (file_path , mode = "wb" , overwrite = True ) as f :
55
64
np .save (f , data )
56
65
57
def rebuild_embeddings(self, knowledge_base):
    """Recompute, persist, and publish the embeddings for ``knowledge_base``.

    Both embedding sets are computed first, without holding the lock. If
    their lengths disagree with each other or with ``knowledge_base``, the
    mismatch is logged and the update is abandoned: no file is written and
    no attribute on ``self`` is changed.

    Otherwise, while holding ``self._update_lock``, the two ``.npy`` files
    are written, then ``knowledge_base``, ``doc_embeddings``,
    ``doc_about_embeddings`` and both file timestamps are replaced together.

    Args:
        knowledge_base: Sequence of document dicts; it is handed unchanged
            to ``embed_knowledge_base`` and ``embed_knowledge_base_about``.

    Returns:
        None.
    """
    logging.info("Rebuilding document embeddings...")

    new_doc_embeddings = self.embed_knowledge_base(knowledge_base)
    new_about_embeddings = self.embed_knowledge_base_about(knowledge_base)

    # Defensive check: every document needs exactly one embedding of each kind.
    doc_count = len(new_doc_embeddings)
    about_count = len(new_about_embeddings)
    kb_count = len(knowledge_base)
    if not (doc_count == about_count == kb_count):
        # Each label is paired with the count it actually describes
        # (text_* -> document embeddings, about_* -> "about" embeddings).
        logging.error(
            "rebuild embeddings Array size mismatch detected: "
            "text_similarities=%s, about_similarities=%s, knowledge_base=%s",
            doc_count,
            about_count,
            kb_count,
        )
        # NOTE(review): when this is the initial build triggered from
        # __init__, abandoning here leaves knowledge_base / doc_embeddings
        # unset on the instance -- confirm callers can cope with that.
        return  # Abandon update

    # Convert before taking the lock so the critical section stays short.
    doc_array = new_doc_embeddings.cpu().numpy()
    about_array = new_about_embeddings.cpu().numpy()

    # Update files, in-memory cache, and timestamps as one locked step.
    with self._update_lock:
        self._atomic_save_numpy(self.DOC_EMBEDDINGS_PATH, doc_array)
        self._atomic_save_numpy(self.DOC_ABOUT_EMBEDDINGS_PATH, about_array)
        self.knowledge_base = knowledge_base
        # The in-memory cache keeps the encoder output itself, not the
        # NumPy copies that were written to disk.
        self.doc_embeddings = new_doc_embeddings
        self.doc_about_embeddings = new_about_embeddings
        self.doc_embeddings_timestamp = os.path.getmtime(self.DOC_EMBEDDINGS_PATH)
        self.doc_about_embeddings_timestamp = os.path.getmtime(
            self.DOC_ABOUT_EMBEDDINGS_PATH
        )

    logging.info("Embeddings rebuilt successfully.")
82
101
83
102
def load_knowledge_base(self):
    """Read and parse the JSON knowledge base at ``self.knowledge_base_path``.

    Returns:
        The deserialized JSON content of the file.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Decode as UTF-8 explicitly: without it, open() falls back to the
    # platform locale encoding and non-ASCII text can be misread.
    with open(self.knowledge_base_path, "r", encoding="utf-8") as kb_file:
        return json.load(kb_file)
86
105
87
def embed_knowledge_base(self, knowledge_base):
    """Encode every document as its "about" and "text" fields combined.

    Args:
        knowledge_base: Iterable of dicts with ``"about"`` and ``"text"`` keys.

    Returns:
        Whatever ``self.model.encode`` returns for the combined strings when
        called with ``convert_to_tensor=True``.
    """
    # One passage per document, in the form "<about>. <text>".
    passages = [
        "{}. {}".format(entry["about"], entry["text"]) for entry in knowledge_base
    ]
    encoded = self.model.encode(passages, convert_to_tensor=True)
    return encoded
90
109
91
def embed_knowledge_base_about(self, knowledge_base):
    """Encode only the "about" field of every document.

    Args:
        knowledge_base: Iterable of dicts that each have an ``"about"`` key.

    Returns:
        Whatever ``self.model.encode`` returns for the list of "about" values
        when called with ``convert_to_tensor=True``.
    """
    about_fields = [entry["about"] for entry in knowledge_base]
    encoded = self.model.encode(about_fields, convert_to_tensor=True)
    return encoded
95
114
96
115
def normalize_query (self , query ):
@@ -193,6 +212,7 @@ def compute_relevance_scores(
193
212
self , text_similarities , about_similarities , high_match_threshold
194
213
):
195
214
relevance_scores = []
215
+
196
216
for i , _ in enumerate (self .knowledge_base ):
197
217
about_similarity = about_similarities [i ]
198
218
text_similarity = text_similarities [i ]
@@ -321,8 +341,8 @@ def rebuild(self):
321
341
Rebuild the embeddings for the knowledge base. This should be called whenever the knowledge base is updated.
322
342
"""
323
343
print ("Rebuilding embeddings for the knowledge base..." )
324
- self . knowledge_base = self .load_knowledge_base () # Reload the knowledge base
325
- self .doc_embeddings = self . rebuild_embeddings () # Rebuild the embeddings
344
+ knowledge_base = self .load_knowledge_base () # Reload the knowledge base
345
+ self .rebuild_embeddings (knowledge_base ) # Rebuild the embeddings
326
346
print ("Embeddings have been rebuilt." )
327
347
328
348
def get_citations (self , retrieved_docs ):
0 commit comments