-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
146 lines (120 loc) · 6.09 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import warnings
from typing import List, Optional, Tuple

from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
from langchain_community.vectorstores import DeepLake
from langchain_core.documents.base import Document
from langchain_core.messages import AIMessage
from langchain_core.prompts import PromptTemplate, load_prompt
from langchain_google_genai import ChatGoogleGenerativeAI
class DrakeLM:
    """Retrieval-augmented study assistant backed by Gemini and a DeepLake store.

    Answers questions by retrieving context from the vector store and prompting
    Gemini, and turns documents into Markdown notes chunk by chunk.
    """

    # Answer-formatting rules for chat prompts. Shared by ``_chat_prompt`` (the
    # prompt recorded in history) and ``ask_llm`` (the prompt sent to Gemini)
    # so the two can never drift apart.
    _CHAT_RULES = """
- If the question says answer for X number of marks, you have to provide X number of points.
- Each point has to be explained in 3-4 sentences.
- In case the context express a mathematical equation, provide the equation in LaTeX format as shown in the example.
- In case the user requests for a code snippet, provide the code snippet in the language specified in the example.
- If the user requests to summarise or use the previous message as context ignoring the explicit context given in the message.
"""

    # Formatting rules passed to the notes prompt template in ``create_notes``.
    _NOTES_RULES = """
- Follow the Markdown format for creating notes as shown in the example.
- The heading of the content should be the title of the markdown file.
- Create subheadings for each section.
- Use numbered bullet points for each point.
"""

    def __init__(self, model_path: str, db: DeepLake, config: dict):
        """
        Initialize the DrakeLM model.

        Parameters:
            model_path (str): Path to a local model. Not used by this
                implementation, which always uses Gemini; kept for interface
                compatibility.
            db (DeepLake): The DeepLake vector store that context is retrieved from.
            config (dict): Model configuration. Not used by this
                implementation; kept for interface compatibility.
        """
        self.gemini = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)
        self.retriever = db.as_retriever()
        self.chat_history = ChatMessageHistory()
        # Seed the history with the assistant's role description.
        self.chat_history.add_user_message("You are assisting a student to understand topics.")
        # Paths are relative to the current working directory.
        self.notes_prompt = load_prompt("prompt_templates/notes_prompt.yaml")
        self.chat_prompt = load_prompt("prompt_templates/chat_prompt.yaml")

    @staticmethod
    def _silence_system_message_warning() -> None:
        """Ignore the ``convert_system_message_to_human`` deprecation warning."""
        warnings.filterwarnings("ignore", message="Convert_system_message_to_human will be deprecated!")

    def _chat_prompt(self, query: str, context: str) -> Tuple[PromptTemplate, str]:
        """
        Build the chat prompt for the LLM.

        Parameters:
            query (str): The question asked by the user.
            context (str): The context retrieved from the DB.
        Returns:
            PromptTemplate: A template with no input variables that renders to
                exactly ``prompt``.
            prompt (str): The fully rendered prompt string.
        """
        prompt = """You are assisting a student to understand topics. \n\n
You have to answer the below question by utilising the below context to answer the question. \n\n
Note to follow the rules given below \n\n
Question: {query} \n\n
Context: {context} \n\n
Rules: {rules} \n\n
Answer:
"""
        prompt = prompt.format(query=query, context=context, rules=self._CHAT_RULES)
        # The rendered prompt embeds user text and retrieved documents, which may
        # contain "{" or "}" (code, LaTeX). Escape them so from_template treats
        # them as literals instead of template fields; otherwise it can raise on
        # a lone brace or invent spurious input variables.
        escaped = prompt.replace("{", "{{").replace("}", "}}")
        return PromptTemplate.from_template(escaped), prompt

    def _retrieve(self, query: str, metadata_filter: Optional[dict], k=3, distance_metric="cos") -> str:
        """
        Retrieve context for ``query`` from the DB.

        Parameters:
            query (str): The question asked by the user.
            metadata_filter (dict | None): If truthy, must contain an "id" key;
                retrieval is restricted to documents whose metadata id matches.
            k (int): The number of documents to retrieve.
            distance_metric (str): The distance metric for retrieval.
        Returns:
            str: The page contents of the retrieved documents, each prefixed
                with a newline.
        """
        search_kwargs = self.retriever.search_kwargs
        search_kwargs["distance_metric"] = distance_metric
        search_kwargs["k"] = k
        if metadata_filter:
            search_kwargs["filter"] = {"metadata": {"id": metadata_filter["id"]}}
        else:
            # search_kwargs lives on the shared retriever and persists between
            # calls; drop any filter left by an earlier filtered query so this
            # one searches the whole store.
            search_kwargs.pop("filter", None)
        retrieved_docs = self.retriever.get_relevant_documents(query)
        return "".join("\n" + doc.page_content for doc in retrieved_docs)

    def ask_llm(self, query: str, metadata_filter: Optional[dict] = None) -> str:
        """
        Ask the LLM a question, grounded in context retrieved from the DB.

        Parameters:
            query (str): The question asked by the user.
            metadata_filter (dict | None): Optional filter; see ``_retrieve``.
        Returns:
            str: The response from the LLM. The full prompt and the response are
                also appended to ``self.chat_history``.
        """
        self._silence_system_message_warning()
        context = self._retrieve(query, metadata_filter)
        print("Retrieved context")
        # Only the rendered string is used here: it is what gets recorded in
        # the chat history.
        _, prompt_string = self._chat_prompt(query, context)
        self.chat_history.add_user_message(prompt_string)
        print("Generating response...")
        # The prompt sent to Gemini is rendered from the YAML chat template.
        llm_prompt = self.chat_prompt.format(query=query, context=context, rules=self._CHAT_RULES)
        # NOTE(review): only this single prompt is sent to the model; the
        # accumulated self.chat_history is recorded but never passed to Gemini,
        # so the "use the previous message" rule cannot take effect. Confirm
        # whether history should be included in the request.
        response = self.gemini.invoke(llm_prompt)
        self.chat_history.add_ai_message(AIMessage(content=response.content))
        return response.content

    def create_notes(self, documents: List[Document]) -> str:
        """
        Create notes from documents using the LLM.

        Each document is sent to Gemini separately with the notes prompt.

        Parameters:
            documents (List[Document]): The documents to create notes from.
        Returns:
            str: The per-document notes joined with newlines, in input order.
        """
        self._silence_system_message_warning()
        notes_chunks = []
        for doc in documents:
            prompt = self.notes_prompt.format(content_chunk=doc.page_content, rules=self._NOTES_RULES)
            notes_chunks.append(self.gemini.invoke(prompt).content)
        return "\n".join(notes_chunks)