8
8
import numpy as np
9
9
from sklearn .metrics .pairwise import cosine_similarity
10
10
import traceback
11
+ from atomicwrites import atomic_write
11
12
12
13
openai .api_base = os .getenv ("OPENAI_BASE_URL" )
13
14
openai .api_key = os .getenv ("OPENAI_API_KEY" )
14
15
15
16
16
17
class RAGSystem :
18
+ # Cache file paths
19
+ DOC_EMBEDDINGS_PATH = "./data/doc_embeddings.npy"
20
+ DOC_ABOUT_EMBEDDINGS_PATH = "./data/doc_about_embeddings.npy"
21
+
17
22
def __init__ (self , knowledge_base_path = "./data/knowledge_base.json" ):
18
23
self .knowledge_base_path = knowledge_base_path
19
24
@@ -23,29 +28,58 @@ def __init__(self, knowledge_base_path="./data/knowledge_base.json"):
23
28
# load existing embeddings if available
24
29
logging .info ("Embedding knowledge base..." )
25
30
26
- if os .path .exists ("./data/doc_about_embeddings.npy" ) and os .path .exists (
27
- "./data/doc_embeddings.npy"
31
+ if os .path .exists (self . DOC_ABOUT_EMBEDDINGS_PATH ) and os .path .exists (
32
+ self . DOC_EMBEDDINGS_PATH
28
33
):
29
- self .doc_about_embeddings = np .load ("./data/doc_about_embeddings.npy" )
34
+ self .doc_about_embeddings = np .load (self . DOC_ABOUT_EMBEDDINGS_PATH )
30
35
logging .info ("Loaded existing about document about embeddings from disk." )
31
- self .doc_embeddings = np .load ("./data/doc_embeddings.npy" )
36
+ self .doc_embeddings = np .load (self . DOC_EMBEDDINGS_PATH )
32
37
logging .info ("Loaded existing document embeddings from disk." )
38
+
39
+ # Save file timestamps when loading cache
40
+ self .doc_embeddings_timestamp = os .path .getmtime (self .DOC_EMBEDDINGS_PATH )
41
+ self .doc_about_embeddings_timestamp = os .path .getmtime (
42
+ self .DOC_ABOUT_EMBEDDINGS_PATH
43
+ )
44
+ logging .info (
45
+ f"Cache loaded - doc_embeddings timestamp: { self .doc_embeddings_timestamp } , doc_about_embeddings timestamp: { self .doc_about_embeddings_timestamp } "
46
+ )
33
47
else :
34
48
self .rebuild_embeddings ()
35
49
36
50
logging .info ("Knowledge base embeddings created" )
37
51
self .conversation_history = []
38
52
53
+ def _atomic_save_numpy (self , file_path , data ):
54
+ with atomic_write (file_path , mode = "wb" , overwrite = True ) as f :
55
+ np .save (f , data )
56
+
39
57
def rebuild_embeddings (self ):
40
- logging .info ("No existing document embeddings found, creating new embeddings." )
41
- self .doc_embeddings = self .embed_knowledge_base ()
42
- self .doc_about_embeddings = self .embed_knowledge_base_about ()
43
- # cache doc_embeddings to disk
44
- np .save ("./data/doc_embeddings.npy" , self .doc_embeddings .cpu ().numpy ())
45
- np .save (
46
- "./data/doc_about_embeddings.npy" , self .doc_about_embeddings .cpu ().numpy ()
58
+ logging .info ("Rebuilding document embeddings..." )
59
+
60
+ new_doc_embeddings = self .embed_knowledge_base ()
61
+ new_about_embeddings = self .embed_knowledge_base_about ()
62
+
63
+ # Atomic saves with guaranteed order
64
+ self ._atomic_save_numpy (
65
+ self .DOC_EMBEDDINGS_PATH , new_doc_embeddings .cpu ().numpy ()
66
+ )
67
+ self ._atomic_save_numpy (
68
+ self .DOC_ABOUT_EMBEDDINGS_PATH , new_about_embeddings .cpu ().numpy ()
47
69
)
48
70
71
+ # Update in-memory embeddings only after successful saves
72
+ self .doc_embeddings = new_doc_embeddings
73
+ self .doc_about_embeddings = new_about_embeddings
74
+
75
+ # Update file timestamps after successful saves
76
+ self .doc_embeddings_timestamp = os .path .getmtime (self .DOC_EMBEDDINGS_PATH )
77
+ self .doc_about_embeddings_timestamp = os .path .getmtime (
78
+ self .DOC_ABOUT_EMBEDDINGS_PATH
79
+ )
80
+
81
+ logging .info ("Embeddings rebuilt successfully." )
82
+
49
83
def load_knowledge_base (self ):
50
84
with open (self .knowledge_base_path , "r" ) as kb_file :
51
85
return json .load (kb_file )
@@ -102,6 +136,43 @@ def compute_document_scores(
102
136
103
137
return result
104
138
139
+ def cache_check (func ):
140
+ """Decorator to automatically check cache consistency"""
141
+
142
+ def wrapper (self , * args , ** kwargs ):
143
+ try :
144
+ current_times = [
145
+ os .path .getmtime (self .DOC_EMBEDDINGS_PATH ),
146
+ os .path .getmtime (self .DOC_ABOUT_EMBEDDINGS_PATH ),
147
+ ]
148
+ stored_times = [
149
+ self .doc_embeddings_timestamp ,
150
+ self .doc_about_embeddings_timestamp ,
151
+ ]
152
+
153
+ # update cache if timestamps are different from out last load
154
+ if current_times != stored_times :
155
+ self ._reload_cache ()
156
+
157
+ except (OSError , FileNotFoundError , PermissionError ):
158
+ logging .warning ("Cache files inaccessible, rebuilding..." )
159
+ self .rebuild_embeddings ()
160
+
161
+ return func (self , * args , ** kwargs )
162
+
163
+ return wrapper
164
+
165
+ def _reload_cache (self ):
166
+ self .doc_embeddings = np .load (self .DOC_EMBEDDINGS_PATH )
167
+ self .doc_about_embeddings = np .load (self .DOC_ABOUT_EMBEDDINGS_PATH )
168
+
169
+ # update our timestamps of the cached files
170
+ self .doc_embeddings_timestamp = os .path .getmtime (self .DOC_EMBEDDINGS_PATH )
171
+ self .doc_about_embeddings_timestamp = os .path .getmtime (
172
+ self .DOC_ABOUT_EMBEDDINGS_PATH
173
+ )
174
+
175
+ @cache_check
105
176
def retrieve (
106
177
self , query , similarity_threshold = 0.4 , high_match_threshold = 0.8 , max_docs = 5
107
178
):
0 commit comments