Skip to content

Commit

Permalink
Merge pull request #71 from CogitoNTNU/vector-database-tensor-sierra
Browse files Browse the repository at this point in the history
Vector database tensor sierra
  • Loading branch information
SverreNystad authored Mar 14, 2024
2 parents 8813bea + c66bc76 commit 316f7fa
Show file tree
Hide file tree
Showing 7 changed files with 247 additions and 0 deletions.
1 change: 1 addition & 0 deletions backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ def __init__(self, path='.env', gpt_model="gpt-3.5-turbo"):
self.GPT_MODEL = os.getenv(key='GPT_MODEL', default='gpt-4')
load_dotenv(dotenv_path=path)
self.API_KEY = os.getenv('OPENAI_API_KEY')
self.MONGODB_URI = os.getenv('MONGODB_URI')


Empty file.
125 changes: 125 additions & 0 deletions backend/flashcards/knowledge_base/db_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from abc import ABC, abstractmethod
from flashcards.knowledge_base.embeddings import cosine_similarity
from config import Config
from pymongo import MongoClient


class DatabaseInterface(ABC):
    """
    Abstract interface for a curriculum vector database.

    Besides normal inheritance, any class that provides callable
    ``get_curriculum`` and ``post_curriculum`` attributes is accepted as a
    virtual subclass (structural / duck typing) via ``__subclasshook__``.
    """

    @classmethod
    def __subclasshook__(cls, subclass: type) -> bool:
        # FIX: the original defined __instancecheck__/__subclasscheck__ on the
        # class itself, where Python never calls them (those hooks live on the
        # metaclass). __subclasshook__ is the ABC-sanctioned hook: ABCMeta
        # consults it for both issubclass() and isinstance().
        if cls is DatabaseInterface:
            return (
                hasattr(subclass, "get_curriculum")
                and callable(subclass.get_curriculum)
                and hasattr(subclass, "post_curriculum")
                and callable(subclass.post_curriculum)
            )
        # Defer to the default mechanism for subclasses of the ABC.
        return NotImplemented

    @abstractmethod
    def get_curriculum(self, embedding: list[float]) -> list[str]:
        """
        Get the curriculum from the database
        Args:
            embedding (list[float]): The embedding of the question
        Returns:
            list[str]: The curriculum related to the question
        """

    @abstractmethod
    def post_curriculum(
        self, curriculum: str, page_num: int, paragraph_num: int, embedding: list[float]
    ) -> bool:
        """
        Post the curriculum to the database
        Args:
            curriculum (str): The curriculum to be posted
            page_num (int): The page number of the curriculum
            paragraph_num (int): The paragraph number of the curriculum
            embedding (list[float]): The embedding of the question
        Returns:
            bool: True if the curriculum was posted, False otherwise
        """


class MongoDB(DatabaseInterface):
    """
    MongoDB Atlas implementation of DatabaseInterface using the
    `$vectorSearch` aggregation stage for similarity retrieval.
    """

    def __init__(self):
        # Connection string is read from the MONGODB_URI env var (see Config).
        self.client = MongoClient(Config().MONGODB_URI)
        self.db = self.client["test-curriculum-database"]
        self.collection = self.db["test-curriculum-collection"]
        # Candidates scoring below this cosine similarity are discarded.
        self.similarity_threshold = 0.83

    def get_curriculum(self, embedding: list[float]) -> list[str]:
        """
        Get the curriculum text most similar to the given embedding.
        Args:
            embedding (list[float]): The embedding of the question
        Returns:
            list[str]: The curriculum related to the question
        Raises:
            ValueError: If the embedding is empty/None or no documents match.
        """
        if not embedding:
            raise ValueError("Embedding cannot be None")

        # Atlas Vector Search stage against the search index "embeddings".
        query = {
            "$vectorSearch": {
                "index": "embeddings",
                "path": "embedding",
                "queryVector": embedding,
                "numCandidates": 30,  # MongoDB suggests numCandidates=10*limit or numCandidates=20*limit
                "limit": 3,
            }
        }

        # FIX: aggregate() returns a cursor, which is always truthy, so the
        # old emptiness check before list() could never fire. Materialize
        # first, then test.
        documents = list(self.collection.aggregate([query]))
        if not documents:
            raise ValueError("No documents found")

        # FIX: the original called documents.remove() while iterating the
        # same list, which skips the element after each removal. Build a new
        # list instead.
        relevant = [
            document
            for document in documents
            if cosine_similarity(embedding, document["embedding"])
            >= self.similarity_threshold
        ]

        # Return only the text content of the surviving documents.
        return [document["text"] for document in relevant]

    def post_curriculum(
        self, curriculum: str, page_num: int, paragraph_num: int, embedding: list[float]
    ) -> bool:
        """
        Post the curriculum to the database
        Args:
            curriculum (str): The curriculum to be posted
            page_num (int): The page number of the curriculum
            paragraph_num (int): The paragraph number of the curriculum
            embedding (list[float]): The embedding of the question
        Returns:
            bool: True if the curriculum was posted, False otherwise
        Raises:
            ValueError: If any required argument is missing.
        """
        if not curriculum:
            raise ValueError("Curriculum cannot be None")

        # FIX: explicit None checks — page/paragraph number 0 is a legitimate
        # value that `if not page_num:` would wrongly reject.
        if page_num is None:
            raise ValueError("Page number cannot be None")

        if paragraph_num is None:
            raise ValueError("Paragraph number cannot be None")

        if not embedding:
            raise ValueError("Embedding cannot be None")

        try:
            # Insert the curriculum into the database with metadata
            self.collection.insert_one(
                {
                    "text": curriculum,
                    "pageNum": page_num,
                    "paragraphNum": paragraph_num,
                    "embedding": embedding,
                }
            )
            return True
        except Exception:
            # FIX: narrowed from a bare `except:` which also swallowed
            # SystemExit/KeyboardInterrupt. Insert failures report as False.
            return False
54 changes: 54 additions & 0 deletions backend/flashcards/knowledge_base/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from abc import ABC, abstractmethod
import openai
from config import Config
import numpy as np


class EmbeddingsInterface(ABC):
    """Contract for text-embedding providers."""

    @abstractmethod
    def get_embedding(self, text: str) -> list[float]:
        """
        Embed the given text into a dense vector.

        Args:
            text (str): The text to be embedded
        Returns:
            list[float]: The embedding of the text
        """


class OpenAIEmbedding(EmbeddingsInterface):
    """Text embeddings backed by the OpenAI embeddings API."""

    def __init__(self, model_name: str = "text-embedding-ada-002"):
        # API key is read from OPENAI_API_KEY via Config.
        api_key = Config().API_KEY
        self.client = openai.Client(api_key=api_key)
        self.model_name = model_name

    def get_embedding(self, text: str) -> list[float]:
        """
        Embed the given text with the configured OpenAI model.
        Args:
            text (str): The text to be embedded
        Returns:
            list[float]: The embedding of the text
        """
        # OpenAI's embedding guidance recommends collapsing newlines.
        text = text.replace("\n", " ")
        # FIX: the original never passed the model, so model_name was dead
        # state — `model` is a required parameter of embeddings.create().
        response = self.client.embeddings.create(input=text, model=self.model_name)
        return response.data[0].embedding


def cosine_similarity(embedding1: list[float], embedding2: list[float]) -> float:
    """
    Calculate the cosine similarity between two embeddings
    Args:
        embedding1 (list[float]): The first embedding
        embedding2 (list[float]): The second embedding
    Returns:
        float: The cosine similarity between the two embeddings, or 0.0 when
        either vector has zero magnitude (the similarity is undefined there;
        the original returned NaN from a 0/0 division).
    """
    # Vectorize with numpy (already imported by this module) instead of the
    # original element-by-element Python loop.
    vec1 = np.asarray(embedding1, dtype=float)
    vec2 = np.asarray(embedding2, dtype=float)
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if norm_product == 0.0:
        return 0.0
    return float(np.dot(vec1, vec2) / norm_product)
18 changes: 18 additions & 0 deletions backend/flashcards/knowledge_base/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from flashcards.knowledge_base.db_interface import DatabaseInterface, MongoDB
from flashcards.knowledge_base.embeddings import EmbeddingsInterface, OpenAIEmbedding


def create_database(database_system: str = "mongodb") -> DatabaseInterface:
    """
    Factory for DatabaseInterface implementations.

    Args:
        database_system (str): Name of the backing store (case-insensitive).
    Returns:
        DatabaseInterface: A ready-to-use database client.
    Raises:
        ValueError: If the requested database system is not supported.
    """
    system = database_system.lower()
    if system == "mongodb":
        return MongoDB()
    raise ValueError(f"Database system {database_system} not supported")


def create_embeddings_model(embeddings_model: str = "openai") -> EmbeddingsInterface:
    """
    Factory for EmbeddingsInterface implementations.

    Args:
        embeddings_model (str): Name of the embeddings provider (case-insensitive).
    Returns:
        EmbeddingsInterface: A ready-to-use embeddings client.
    Raises:
        ValueError: If the requested embeddings model is not supported.
    """
    provider = embeddings_model.lower()
    if provider == "openai":
        return OpenAIEmbedding()
    raise ValueError(f"Embeddings model {embeddings_model} not supported")
47 changes: 47 additions & 0 deletions backend/flashcards/rag_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
""" Retrieval Augmented Generation Service """

from flashcards.knowledge_base.db_interface import DatabaseInterface
from flashcards.knowledge_base.embeddings import EmbeddingsInterface


def get_context(
    query: str, db: DatabaseInterface, embeddings: EmbeddingsInterface
) -> list[str]:
    """
    Retrieve curriculum passages relevant to the query.

    Args:
        query (str): The query to get the context of
        db (DatabaseInterface): The database to be used
        embeddings (EmbeddingsInterface): The embeddings to be used
    Returns:
        list[str]: The context of the query
    """
    # Embed the query, then look up nearby curriculum in the vector store.
    query_embedding = embeddings.get_embedding(query)
    return db.get_curriculum(query_embedding)


def post_context(
    context: str,
    page_num: int,
    paragraph_num: int,
    db: DatabaseInterface,
    embeddings: EmbeddingsInterface,
) -> bool:
    """
    Store a context passage (with its location metadata) in the database.

    Args:
        context (str): The context to be posted
        page_num (int): The page number of the context
        paragraph_num (int): The paragraph number of the context
        db (DatabaseInterface): The database to be used
        embeddings (EmbeddingsInterface): The embeddings to be used
    Returns:
        bool: True if the context was posted, False otherwise
    """
    # Embed the passage, then persist text + metadata + vector together.
    context_embedding = embeddings.get_embedding(context)
    stored = db.post_curriculum(context, page_num, paragraph_num, context_embedding)
    return stored
2 changes: 2 additions & 0 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ httpcore==1.0.2
httpx==0.26.0
idna==3.6
inflection==0.5.1
numpy==1.26.3
openai==1.6.1
packaging==23.2
pdfminer.six==20231228
Expand All @@ -28,6 +29,7 @@ pycparser==2.21
pydantic==2.5.3
pydantic_core==2.14.6
PyJWT==2.8.0
pymongo==4.6.2
PyPDF2==1.26.0
python-dotenv==1.0.0
pytz==2023.3.post1
Expand Down

0 comments on commit 316f7fa

Please sign in to comment.