generated from SverreNystad/template_python_application
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #71 from CogitoNTNU/vector-database-tensor-sierra
Vector database tensor sierra
- Loading branch information
Showing
7 changed files
with
247 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
from abc import ABC, abstractmethod | ||
from flashcards.knowledge_base.embeddings import cosine_similarity | ||
from config import Config | ||
from pymongo import MongoClient | ||
|
||
|
||
class DatabaseInterface(ABC):
    """
    Abstract interface for a curriculum store that is searched by embedding.

    Concrete backends must implement ``get_curriculum`` and ``post_curriculum``.
    """

    # NOTE: ``isinstance()`` and ``issubclass()`` look these hooks up on the
    # metaclass, not on the class itself, so the two classmethods below are
    # only executed when they are called explicitly, e.g.
    # ``DatabaseInterface.__subclasscheck__(SomeBackend)``.
    @classmethod
    def __instancecheck__(cls, instance: object) -> bool:
        """
        Check whether the type of an object provides the interface methods
        Args:
            instance (object): The object to inspect
        Returns:
            bool: Result of ``__subclasscheck__`` applied to ``type(instance)``
        """
        return cls.__subclasscheck__(type(instance))

    @classmethod
    def __subclasscheck__(cls, subclass: type) -> bool:
        """
        Check whether a class provides the interface methods
        Args:
            subclass (type): The class to inspect
        Returns:
            bool: True if both ``get_curriculum`` and ``post_curriculum`` exist
                on the class and are callable, False otherwise
        """
        required_methods = ("get_curriculum", "post_curriculum")
        # A missing attribute resolves to None, which is not callable.
        return all(
            callable(getattr(subclass, method_name, None))
            for method_name in required_methods
        )

    @abstractmethod
    def get_curriculum(self, embedding: list[float]) -> list[str]:
        """
        Get the curriculum from the database
        Args:
            embedding (list[float]): The embedding of the question
        Returns:
            list[str]: The curriculum related to the question
        """
        pass

    @abstractmethod
    def post_curriculum(
        self, curriculum: str, page_num: int, paragraph_num: int, embedding: list[float]
    ) -> bool:
        """
        Post the curriculum to the database
        Args:
            curriculum (str): The curriculum to be posted
            page_num (int): The page number the curriculum was taken from
            paragraph_num (int): The paragraph number the curriculum was taken from
            embedding (list[float]): The embedding of the curriculum text
        Returns:
            bool: True if the curriculum was posted, False otherwise
        """
        pass
|
||
|
||
class MongoDB(DatabaseInterface):
    """
    MongoDB-backed curriculum store queried through a ``$vectorSearch`` stage.
    """

    def __init__(self):
        """
        Connect with ``Config().MONGODB_URI`` and select the curriculum collection.
        """
        self.client = MongoClient(Config().MONGODB_URI)
        self.db = self.client["test-curriculum-database"]
        self.collection = self.db["test-curriculum-collection"]
        # Minimum cosine similarity a hit needs in order to be returned
        # by get_curriculum.
        self.similarity_threshold = 0.83

    def get_curriculum(self, embedding: list[float]) -> list[str]:
        """
        Get the curriculum texts most similar to an embedding
        Args:
            embedding (list[float]): The embedding of the question
        Returns:
            list[str]: The ``text`` field of every search hit whose cosine
                similarity to ``embedding`` reaches ``similarity_threshold``;
                empty when nothing qualifies
        Raises:
            ValueError: If ``embedding`` is None or empty
        """
        if not embedding:
            raise ValueError("Embedding cannot be None")

        # Define the MongoDB query that utilizes the search index "embeddings".
        query = {
            "$vectorSearch": {
                "index": "embeddings",
                "path": "embedding",
                "queryVector": embedding,
                "numCandidates": 30,  # MongoDB suggests using numCandidates=10*limit or numCandidates=20*limit
                "limit": 3,
            }
        }

        # aggregate() returns a cursor object rather than a list, so its
        # truthiness says nothing about whether there are results. Materialize
        # it before working with the hits.
        documents = list(self.collection.aggregate([query]))

        # Build a new list instead of calling remove() while iterating:
        # removing during iteration skips the element that follows each
        # removed one, which let low-similarity documents slip through.
        return [
            document["text"]
            for document in documents
            if cosine_similarity(embedding, document["embedding"])
            >= self.similarity_threshold
        ]

    def post_curriculum(
        self, curriculum: str, page_num: int, paragraph_num: int, embedding: list[float]
    ) -> bool:
        """
        Post the curriculum to the database
        Args:
            curriculum (str): The curriculum to be posted
            page_num (int): The page number stored as ``pageNum``
            paragraph_num (int): The paragraph number stored as ``paragraphNum``
            embedding (list[float]): The embedding stored with the curriculum
        Returns:
            bool: True if the document was inserted, False if the insert raised
        Raises:
            ValueError: If ``curriculum`` or ``embedding`` is None or empty, or
                if ``page_num`` or ``paragraph_num`` is None
        """
        if not curriculum:
            raise ValueError("Curriculum cannot be None")

        # Compare against None explicitly so that 0 remains a valid number.
        if page_num is None:
            raise ValueError("Page number cannot be None")

        if paragraph_num is None:
            raise ValueError("Paragraph number cannot be None")

        if not embedding:
            raise ValueError("Embedding cannot be None")

        try:
            # Insert the curriculum into the database with metadata
            self.collection.insert_one(
                {
                    "text": curriculum,
                    "pageNum": page_num,
                    "paragraphNum": paragraph_num,
                    "embedding": embedding,
                }
            )
        except Exception:
            # Keep the best-effort contract (False on failure) without also
            # swallowing KeyboardInterrupt/SystemExit like a bare except does.
            return False
        return True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from abc import ABC, abstractmethod | ||
import openai | ||
from config import Config | ||
import numpy as np | ||
|
||
|
||
class EmbeddingsInterface(ABC):
    """
    Abstract interface for models that turn text into an embedding vector
    """

    @abstractmethod
    def get_embedding(self, text: str) -> list[float]:
        """
        Get the embedding of the text
        Args:
            text (str): The text to be embedded
        Returns:
            list[float]: The embedding of the text
        """
|
||
|
||
class OpenAIEmbedding(EmbeddingsInterface):
    """
    Embeddings model backed by the OpenAI embeddings endpoint
    """

    def __init__(self, model_name: str = "text-embedding-ada-002"):
        """
        Create the OpenAI client from ``Config().API_KEY``
        Args:
            model_name (str): The embedding model requested on every call
        """
        api_key = Config().API_KEY
        self.client = openai.Client(api_key=api_key)
        self.model_name = model_name

    def get_embedding(self, text: str) -> list[float]:
        """
        Get the embedding of the text using the configured model
        Args:
            text (str): The text to be embedded
        Returns:
            list[float]: The embedding of the text
        """
        # Newlines are replaced with spaces before the text is sent.
        text = text.replace("\n", " ")
        # `model` is a required argument of embeddings.create(); without it
        # the call fails and the configured model name is never used.
        response = self.client.embeddings.create(input=text, model=self.model_name)
        return response.data[0].embedding
|
||
|
||
def cosine_similarity(embedding1: list[float], embedding2: list[float]) -> float:
    """
    Calculate the cosine similarity between two embeddings
    Args:
        embedding1 (list[float]): The first embedding
        embedding2 (list[float]): The second embedding
    Returns:
        float: The cosine similarity between the two embeddings, or 0.0 when
            either embedding has zero magnitude (including empty input)
    Raises:
        ValueError: If the embeddings do not have the same length
    """
    if len(embedding1) != len(embedding2):
        # Without this check a longer second embedding was silently truncated
        # and a shorter one failed with an unrelated IndexError.
        raise ValueError(
            "Embeddings must have the same length, got "
            f"{len(embedding1)} and {len(embedding2)}"
        )

    dot_product = sum(a * b for a, b in zip(embedding1, embedding2))
    squared_norm1 = sum(a**2 for a in embedding1)
    squared_norm2 = sum(b**2 for b in embedding2)
    norm_product = np.sqrt(squared_norm1 * squared_norm2)

    # A zero vector has no direction; report "no similarity" instead of
    # dividing by zero and returning NaN.
    if norm_product == 0:
        return 0.0
    return float(dot_product / norm_product)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from flashcards.knowledge_base.db_interface import DatabaseInterface, MongoDB | ||
from flashcards.knowledge_base.embeddings import EmbeddingsInterface, OpenAIEmbedding | ||
|
||
|
||
def create_database(database_system: str = "mongodb") -> DatabaseInterface:
    """
    Create the database backend selected by name
    Args:
        database_system (str): Name of the backend, matched case-insensitively
    Returns:
        DatabaseInterface: A new ``MongoDB`` instance for "mongodb"
    Raises:
        ValueError: If the name is not a supported database system
    """
    requested_system = database_system.lower()
    if requested_system == "mongodb":
        return MongoDB()
    raise ValueError(f"Database system {database_system} not supported")
|
||
|
||
def create_embeddings_model(embeddings_model: str = "openai") -> EmbeddingsInterface:
    """
    Create the embeddings model selected by name
    Args:
        embeddings_model (str): Name of the model provider, matched
            case-insensitively
    Returns:
        EmbeddingsInterface: A new ``OpenAIEmbedding`` instance for "openai"
    Raises:
        ValueError: If the name is not a supported embeddings model
    """
    requested_model = embeddings_model.lower()
    if requested_model == "openai":
        return OpenAIEmbedding()
    raise ValueError(f"Embeddings model {embeddings_model} not supported")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
""" Retrieval Augmented Generation Service """ | ||
|
||
from flashcards.knowledge_base.db_interface import DatabaseInterface | ||
from flashcards.knowledge_base.embeddings import EmbeddingsInterface | ||
|
||
|
||
def get_context(
    query: str, db: DatabaseInterface, embeddings: EmbeddingsInterface
) -> list[str]:
    """
    Look up the curriculum context for a query
    Args:
        query (str): The query to embed and search with
        db (DatabaseInterface): The database whose ``get_curriculum`` is called
        embeddings (EmbeddingsInterface): The model whose ``get_embedding`` is
            called on the query
    Returns:
        list[str]: Whatever ``db.get_curriculum`` returns for the query embedding
    """
    return db.get_curriculum(embeddings.get_embedding(query))
|
||
|
||
def post_context(
    context: str,
    page_num: int,
    paragraph_num: int,
    db: DatabaseInterface,
    embeddings: EmbeddingsInterface,
) -> bool:
    """
    Embed a piece of context and store it in the database
    Args:
        context (str): The text to embed and store
        page_num (int): The page number passed on to the database
        paragraph_num (int): The paragraph number passed on to the database
        db (DatabaseInterface): The database whose ``post_curriculum`` is called
        embeddings (EmbeddingsInterface): The model whose ``get_embedding`` is
            called on the context
    Returns:
        bool: The value returned by ``db.post_curriculum``
    """
    context_vector = embeddings.get_embedding(context)
    was_posted = db.post_curriculum(context, page_num, paragraph_num, context_vector)
    return was_posted
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters