Skip to content

Commit 3f4dde0

Browse files
committed
fixed file-based RAG
1 parent bb92452 commit 3f4dde0

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

modules/index_func.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1-
import os
1+
import hashlib
22
import logging
3+
import os
34

4-
import hashlib
55
import PyPDF2
6+
from langchain_community.chat_models import ChatOpenAI
7+
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
8+
from langchain_community.vectorstores import FAISS
9+
from langchain_openai import OpenAIEmbeddings
610
from tqdm import tqdm
711

12+
from modules.config import local_embedding
813
from modules.presets import *
914
from modules.utils import *
10-
from modules.config import local_embedding
1115

1216

1317
def get_documents(file_src):
@@ -28,8 +32,8 @@ def get_documents(file_src):
2832
if file_type == ".pdf":
2933
logging.debug("Loading PDF...")
3034
try:
31-
from modules.pdf_func import parse_pdf
3235
from modules.config import advance_docs
36+
from modules.pdf_func import parse_pdf
3337

3438
two_column = advance_docs["pdf"].get("two_column", False)
3539
pdftext = parse_pdf(filepath, two_column).text
@@ -43,12 +47,14 @@ def get_documents(file_src):
4347
metadata={"source": filepath})]
4448
elif file_type == ".docx":
4549
logging.debug("Loading Word...")
46-
from langchain.document_loaders import UnstructuredWordDocumentLoader
50+
from langchain.document_loaders import \
51+
UnstructuredWordDocumentLoader
4752
loader = UnstructuredWordDocumentLoader(filepath)
4853
texts = loader.load()
4954
elif file_type == ".pptx":
5055
logging.debug("Loading PowerPoint...")
51-
from langchain.document_loaders import UnstructuredPowerPointLoader
56+
from langchain.document_loaders import \
57+
UnstructuredPowerPointLoader
5258
loader = UnstructuredPowerPointLoader(filepath)
5359
texts = loader.load()
5460
elif file_type == ".epub":
@@ -93,9 +99,6 @@ def construct_index(
9399
separator=" ",
94100
load_from_cache_if_possible=True,
95101
):
96-
from langchain.chat_models import ChatOpenAI
97-
from langchain.vectorstores import FAISS
98-
99102
if api_key:
100103
os.environ["OPENAI_API_KEY"] = api_key
101104
else:
@@ -109,11 +112,9 @@ def construct_index(
109112
index_name = get_file_hash(file_src)
110113
index_path = f"./index/{index_name}"
111114
if local_embedding:
112-
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
113115
embeddings = HuggingFaceEmbeddings(
114116
model_name="sentence-transformers/distiluse-base-multilingual-cased-v2")
115117
else:
116-
from langchain.embeddings import OpenAIEmbeddings
117118
if os.environ.get("OPENAI_API_TYPE", "openai") == "openai":
118119
embeddings = OpenAIEmbeddings(openai_api_base=os.environ.get(
119120
"OPENAI_API_BASE", None), openai_api_key=os.environ.get("OPENAI_EMBEDDING_API_KEY", api_key))
@@ -122,7 +123,7 @@ def construct_index(
122123
model=os.environ["AZURE_EMBEDDING_MODEL_NAME"], openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"], openai_api_type="azure")
123124
if os.path.exists(index_path) and load_from_cache_if_possible:
124125
logging.info(i18n("找到了缓存的索引文件,加载中……"))
125-
return FAISS.load_local(index_path, embeddings)
126+
return FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)
126127
else:
127128
documents = get_documents(file_src)
128129
logging.debug(i18n("构建索引中……"))

0 commit comments

Comments
 (0)