-
Notifications
You must be signed in to change notification settings - Fork 0
/
vectorstore.py
62 lines (57 loc) · 2.35 KB
/
vectorstore.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os, zipfile, requests
from langchain_text_splitters import RecursiveCharacterTextSplitter
from qdrant_client import QdrantClient
from io import BytesIO
from dotenv import load_dotenv
# Runs before the os.getenv() calls below so that values supplied through
# python-dotenv (presumably a local .env file — confirm deployment setup) are
# visible when the client is constructed.
load_dotenv()
# Qdrant collection that ingest_embeddings() writes the document chunks into.
COLLECTION_NAME = "nemo-docs" # Name of the collection
# Repository whose archive is downloaded; also the base of every chunk's
# "source" metadata link.
GITHUB_URL = "https://github.com/NVIDIA/NeMo-Guardrails"
# Branch that is downloaded and referenced in the "source" links.
BRANCH = "develop"
# Module-level client created once at import time. os.getenv() yields None for
# either setting when the corresponding environment variable is unset.
qdrant_client = QdrantClient(
os.getenv("QDRANT_URL"),
api_key=os.getenv("QDRANT_API_KEY"),
)
def load_github_docs(document_url, branch='master'):
    """Download a GitHub repository branch as a zip and unpack it under ``docs/``.

    Args:
        document_url: Base URL of the repository, e.g.
            ``https://github.com/<owner>/<repo>``. Surrounding whitespace and
            trailing slashes are ignored.
        branch: Name of the branch whose archive is downloaded.

    Returns:
        The path of the unpacked repository folder, relative to the current
        working directory (``docs/<top-level folder of the archive>``), or
        ``None`` when the server does not answer with HTTP 200.

    Raises:
        requests.RequestException: On connection problems or when the server
            stops responding for longer than the timeout.
        zipfile.BadZipFile: If the downloaded payload is not a valid zip file.
    """
    # Normalise once so the repo name and the archive URL are built from the
    # same string; a trailing slash would otherwise yield ".../repo//archive/...".
    repo_url = document_url.strip().rstrip('/')
    repo_name = os.path.basename(repo_url)
    archive_url = repo_url + "/archive/refs/heads/" + branch + ".zip"
    extract_root = os.path.join(os.getcwd(), 'docs')

    # Without a timeout a stalled connection would block the script forever.
    response = requests.get(archive_url, timeout=60)
    if response.status_code != 200:
        return None

    with zipfile.ZipFile(BytesIO(response.content), 'r') as archive:
        member_names = archive.namelist()
        archive.extractall(extract_root)

    # Trust the archive for the name of its top-level folder instead of
    # guessing "<repo>-<branch>": the guess is wrong whenever the branch name
    # cannot be used verbatim as a directory name (e.g. it contains "/").
    if member_names:
        top_level = member_names[0].split('/', 1)[0]
    else:
        top_level = repo_name + '-' + branch
    return 'docs' + '/' + top_level
def _collect_markdown(path):
    """Read every ``.md`` file below *path* and pair it with its GitHub URL.

    Returns:
        A ``(texts, metadatas)`` tuple of equally long lists: the file contents
        and one ``{"source": <url>}`` dict per file. Files that cannot be read
        or decoded as UTF-8 are reported and skipped.
    """
    texts = []
    metadatas = []
    for root, _, files in os.walk(path):
        for file_name in files:
            if not file_name.endswith(".md"):
                continue
            file_path = os.path.join(root, file_name)
            # URLs always use forward slashes, whatever the local separator is.
            relative_path = os.path.relpath(file_path, path).replace(os.sep, "/")
            source_url = GITHUB_URL + "/blob/" + BRANCH + '/' + relative_path
            try:
                with open(file_path, "r", encoding="utf-8") as file:
                    content = file.read()
            except (OSError, UnicodeDecodeError) as error:
                print(f"Error: {error}")
                continue
            texts.append(content)
            metadatas.append({"source": source_url})
    return texts, metadatas


def ingest_embeddings(path):
    """Chunk the Markdown files under *path* and add them to the Qdrant collection.

    Each file is split on blank lines into chunks of at most 700 characters
    with 100 characters of overlap. Chunks are added to ``COLLECTION_NAME``
    with sequential ids starting at 1 and a ``source`` metadata entry pointing
    at the file on GitHub.

    Args:
        path: Root directory of the unpacked repository. ``None`` or an empty
            value is reported and ignored.

    Returns:
        None. Problems are printed rather than raised.
    """
    # A failed download hands None to this function; os.walk(None) would raise.
    if not path:
        print("Error: no documentation path given, nothing to ingest")
        return

    text, metadatas = _collect_markdown(path)
    text_splitter = RecursiveCharacterTextSplitter(separators=["\n\n"], chunk_size=700, chunk_overlap=100)
    chunked_documents = text_splitter.create_documents(text, metadatas=metadatas)

    # Unpacking zip(*[]) into three names raises ValueError, so stop early when
    # there is nothing to upload (missing directory, no or only empty .md files).
    if not chunked_documents:
        print(f"Error: no markdown content found under {path}, nothing to ingest")
        return

    chunks = [document.page_content for document in chunked_documents]
    metadata = [document.metadata for document in chunked_documents]
    ids = list(range(1, len(chunked_documents) + 1))

    try:
        qdrant_client.add(
            collection_name=COLLECTION_NAME,
            documents=chunks,
            metadata=metadata,
            ids=ids
        )
    except Exception as error:
        # Script-level boundary: report the failure instead of a traceback.
        print(f"Error: {error}")
    else:
        print("Collection created and persisted")
if __name__ == "__main__":
    # load_github_docs() returns None when the archive download fails; stop
    # with a non-zero exit status instead of passing None on to the ingestion
    # step, which expects a directory path.
    docs_path = load_github_docs(GITHUB_URL, BRANCH)
    if docs_path is None:
        raise SystemExit(f"Error: could not download {GITHUB_URL} (branch {BRANCH})")
    ingest_embeddings(docs_path)