Skip to content

Commit

Permalink
Merge pull request #21 from decodingml/module-5
Browse files Browse the repository at this point in the history
Module 5
  • Loading branch information
alexandruvesa authored May 26, 2024
2 parents d649889 + e0c6fcf commit 2ef453d
Show file tree
Hide file tree
Showing 55 changed files with 6,752 additions and 106 deletions.
6 changes: 6 additions & 0 deletions course/module-3/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ QDRANT_DATABASE_HOST="localhost"
QDRANT_DATABASE_PORT=6333
QDRANT_APIKEY=str
USE_QDRANT_CLOUD=False

# MQ config
RABBITMQ_DEFAULT_USERNAME="guest"
RABBITMQ_DEFAULT_PASSWORD="guest"
Expand All @@ -21,3 +22,8 @@ RABBITMQ_PORT= 5673

# Retrieval config
OPENAI_API_KEY="str"

# Comet ML config
COMET_API_KEY = "str"
COMET_WORKSPACE = "str"
COMET_PROJECT = "llm-twin-course"
4 changes: 2 additions & 2 deletions course/module-3/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ local-insert-data-mongo: #Insert data to mongodb
local-bytewax: # Run bytewax pipeline
RUST_BACKTRACE=full poetry run python -m bytewax.run data_flow/bytewax_pipeline

generate-dataset: # Generate dataset for finetuning and version it in CometML
python finetuning/generate_data.py
generate-dataset: # Generate dataset for finetuning and version it in Comet ML
python -m finetuning.generate_data

local-test-retriever: # Test retriever
poetry run python retriever.py
Expand Down
4 changes: 3 additions & 1 deletion course/module-3/data_flow/bytewax_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
EmbeddingDispatcher,
RawDispatcher,
)
from db.qdrant import connection
from db.qdrant import QdrantDatabaseConnector

connection = QdrantDatabaseConnector()

flow = Dataflow("Streaming ingestion pipeline")
stream = op.input("input", flow, RabbitMQSource())
Expand Down
34 changes: 22 additions & 12 deletions course/module-3/db/qdrant.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from qdrant_client import QdrantClient
import logger_utils
from qdrant_client import QdrantClient, models
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import Batch, Distance, VectorParams

import logger_utils

from settings import settings

logger = logger_utils.get_logger(__name__)


class QdrantDatabaseConnector:
_instance: QdrantClient = None

def __init__(self):
_instance: QdrantClient | None = None
def __init__(self) -> None:
if self._instance is None:
try:
if settings.USE_QDRANT_CLOUD:
Expand All @@ -24,12 +25,12 @@ def __init__(self):
host=settings.QDRANT_DATABASE_HOST,
port=settings.QDRANT_DATABASE_PORT,
)

except UnexpectedResponse:
logger.exception(
"Couldn't connect to the database.",
"Couldn't connect to Qdrant.",
host=settings.QDRANT_DATABASE_HOST,
port=settings.QDRANT_DATABASE_PORT,
url=settings.QDRANT_CLOUD_URL,
)

raise
Expand All @@ -38,12 +39,16 @@ def get_collection(self, collection_name: str):
return self._instance.get_collection(collection_name=collection_name)

def create_non_vector_collection(self, collection_name: str):
self._instance.create_collection(collection_name=collection_name, vectors_config={})
self._instance.create_collection(
collection_name=collection_name, vectors_config={}
)

def create_vector_collection(self, collection_name: str):
self._instance.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(size=settings.EMBEDDING_SIZE, distance=Distance.COSINE),
vectors_config=VectorParams(
size=settings.EMBEDDING_SIZE, distance=Distance.COSINE
),
)

def write_data(self, collection_name: str, points: Batch):
Expand All @@ -53,6 +58,14 @@ def write_data(self, collection_name: str, points: Batch):
logger.exception("An error occurred while inserting data.")

raise

def search(self, collection_name: str, query_vector: list, query_filter: models.Filter, limit: int) -> list:
return self._instance.search(
collection_name=collection_name,
query_vector=query_vector,
query_filter=query_filter,
limit=limit,
)

def scroll(self, collection_name: str, limit: int):
return self._instance.scroll(collection_name=collection_name, limit=limit)
Expand All @@ -62,6 +75,3 @@ def close(self):
self._instance.close()

logger.info("Connected to database has been closed.")


connection = QdrantDatabaseConnector()
64 changes: 36 additions & 28 deletions course/module-3/finetuning/generate_data.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,55 @@
import json
import logging

# sys.path.append(str(Path(__file__).resolve().parent.parent))
from comet_ml import Artifact, Experiment

from db.qdrant import connection as client
import logger_utils
from db.qdrant import QdrantDatabaseConnector
from finetuning.file_handler import FileHandler
from finetuning.llm_communication import GptCommunicator
from settings import settings

logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logger_utils.get_logger(__name__)

data_type = "posts"
USER_PROMPT = (
f"I will give you batches of contents of {data_type}. Please generate me exactly 1 instruction for each of them. The {data_type} text "
f"for which you have to generate the instructions is under Content number x lines. Please structure the answer in json format,"
f"ready to be loaded by json.loads(), a list of objects only with fields called instruction and content. For the content field, copy the number of the content only!."
f"Please do not add any extra characters and make sure it is a list with objects in valid json format!\n"
)

client = QdrantDatabaseConnector()


class DataFormatter:
@classmethod
def get_system_prompt(cls, data_type: str) -> str:
return (
f"I will give you batches of contents of {data_type}. Please generate me exactly 1 instruction for each of them. The {data_type} text "
f"for which you have to generate the instructions is under Content number x lines. Please structure the answer in json format,"
f"ready to be loaded by json.loads(), a list of objects only with fields called instruction and content. For the content field, copy the number of the content only!."
f"Please do not add any extra characters and make sure it is a list with objects in valid json format!\n"
)

@classmethod
def format_data(cls, data_points: list, is_example: bool, start_index: int) -> str:
text = ""
for index, data_point in enumerate(data_points):
if not is_example:
text += f"Content number {start_index + index }\n"
text += str(data_point) + "\n"

return text

@classmethod
def format_batch(cls, context_msg: str, data_points: list, start_index: int) -> str:
delimiter_msg = context_msg
delimiter_msg += cls.format_data(data_points, False, start_index)

return delimiter_msg

@classmethod
def format_prompt(cls, inference_posts: list, start_index: int):
initial_prompt = USER_PROMPT
def format_prompt(cls, inference_posts: list, data_type: str, start_index: int) -> str:
initial_prompt = cls.get_system_prompt(data_type)
initial_prompt += f"You must generate exactly a list of {len(inference_posts)} json objects, using the contents provided under CONTENTS FOR GENERATION\n"
initial_prompt += cls.format_batch(
"\nCONTENTS FOR GENERATION: \n", inference_posts, start_index
)

return initial_prompt


Expand All @@ -59,21 +64,21 @@ def __init__(
self.api_communicator = api_communicator
self.data_formatter = data_formatter

def generate_training_data(self, collection_name: str, batch_size: int = 1):
def generate_training_data(self, collection_name: str, data_type: str, batch_size: int = 1):
all_contents = self.fetch_all_cleaned_content(collection_name)
response = []
for i in range(0, len(all_contents), batch_size):
batch = all_contents[i : i + batch_size]
initial_prompt = data_formatter.format_prompt(batch, i)
response += self.api_communicator.send_prompt(initial_prompt)
prompt = data_formatter.format_prompt(batch, data_type, i)
response += self.api_communicator.send_prompt(prompt)
for j in range(i, i + batch_size):
response[j]["content"] = all_contents[j]

self.push_to_comet(response, collection_name)
self.push_to_comet(response, data_type, collection_name)

def push_to_comet(self, data: list, collection_name: str):
def push_to_comet(self, data: list, data_type: str, collection_name: str):
try:
logging.info(f"Starting to push data to Comet: {collection_name}")
logger.info(f"Starting to push data to Comet: {collection_name}")

# Assuming the settings module has been properly configured with the required attributes
experiment = Experiment(
Expand All @@ -88,18 +93,18 @@ def push_to_comet(self, data: list, collection_name: str):
with open(file_name, "w") as f:
json.dump(data, f)

logging.info("Data written to file successfully")
logger.info("Data written to file successfully")

artifact = Artifact(collection_name)
artifact = Artifact(f"{data_type}-instruct-dataset")
artifact.add(file_name)
logging.info(f"Artifact created and file added: {file_name}")
logger.info(f"Artifact created and file added: {file_name}")

experiment.log_artifact(artifact)
experiment.end()
logging.info("Data pushed to Comet successfully and experiment ended")
logger.info("Data pushed to Comet successfully and experiment ended")

except Exception as e:
logging.error(f"Failed to push data to Comet: {e}", exc_info=True)
logger.error(f"Failed to push data to Comet: {e}", exc_info=True)

def fetch_all_cleaned_content(self, collection_name: str) -> list:
all_cleaned_contents = []
Expand All @@ -116,12 +121,15 @@ def fetch_all_cleaned_content(self, collection_name: str) -> list:


if __name__ == "__main__":
collection_names = ["cleaned_articles", "cleaned_posts"]
file_handler = FileHandler()
api_communicator = GptCommunicator()
data_formatter = DataFormatter()
dataset_generator = DatasetGenerator(file_handler, api_communicator, data_formatter)
for collection in collection_names:

collections = [("cleaned_articles", "articles"), ("cleaned_posts", "posts")]
for (collection_name, data_type) in collections:
logger.info("Generating training data.", collection_name=collection_name, data_type=data_type)

dataset_generator.generate_training_data(
collection_name=collection, batch_size=1
collection_name=collection_name, data_type=data_type, batch_size=1
)
18 changes: 11 additions & 7 deletions course/module-3/finetuning/llm_communication.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import json
import logging

from openai import OpenAI

from finetuning.exceptions import APICommunicationError
import logger_utils
from settings import settings

MAX_LENGTH = 16384
SYSTEM_PROMPT = "You are a technical writer handing someone's account to post about AI and MLOps."
SYSTEM_PROMPT = (
"You are a technical writer handing someone's account to post about AI and MLOps."
)

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logger_utils.get_logger(__name__)


class GptCommunicator:
Expand All @@ -20,7 +21,7 @@ def __init__(self, gpt_model: str = "gpt-3.5-turbo"):
def send_prompt(self, prompt: str) -> list:
try:
client = OpenAI(api_key=self.api_key)
logging.info("Sending batch to LLM")
logger.info("Sending batch to LLM")
chat_completion = client.chat.completions.create(
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
Expand All @@ -30,8 +31,11 @@ def send_prompt(self, prompt: str) -> list:
)
response = chat_completion.choices[0].message.content
return json.loads(self.clean_response(response))
except Exception as e:
logging.error(f"Skipping batch! An error occurred while communicating with API: {e}")
except Exception:
logger.exception(
f"Skipping batch! An error occurred while communicating with API."
)

return []

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion course/module-3/insert_data_mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def insert_posts(file_name: str, author_id: str) -> None:


def insert_articles(file_name: str, author_id: str) -> None:
file_name = "/Users/vesaalexandru/Workspaces/decodeML/llm-twin-course/course/module-3/dataset/articles_paul_iusztin.json"
file_name = file_name
try:
with open(file_name, "r") as file:
articles: list[dict] = json.load(file)
Expand Down
Loading

0 comments on commit 2ef453d

Please sign in to comment.