Skip to content

Commit

Permalink
Add Mixtral 8x22B
Browse files Browse the repository at this point in the history
  • Loading branch information
kaarthik108 committed Apr 10, 2024
1 parent 84f66a7 commit e3272fb
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 25 deletions.
52 changes: 36 additions & 16 deletions chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class ModelConfig(BaseModel):

@validator("model_type", pre=True, always=True)
def validate_model_type(cls, v):
if v not in ["gpt", "mistral", "gemini"]:
if v not in ["gpt", "mixtral8x22b", "claude", "mixtral8x7b"]:
raise ValueError(f"Unsupported model type: {v}")
return v

Expand All @@ -52,23 +52,26 @@ def __init__(self, config: ModelConfig):
def setup(self):
if self.model_type == "gpt":
self.setup_gpt()
elif self.model_type == "gemini":
self.setup_gemini()
elif self.model_type == "mistral":
self.setup_mixtral()
elif self.model_type == "claude":
self.setup_claude()
elif self.model_type == "mixtral8x7b":
self.setup_mixtral_8x7b()
elif self.model_type == "mixtral8x22b":
self.setup_mixtral_8x22b()


def setup_gpt(self):
self.llm = ChatOpenAI(
model_name="gpt-3.5-turbo-0125",
model_name="gpt-3.5-turbo",
temperature=0.2,
api_key=self.secrets["OPENAI_API_KEY"],
max_tokens=1000,
callbacks=[self.callback_handler],
streaming=True,
base_url=self.gateway_url,
# base_url=self.gateway_url,
)

def setup_mixtral(self):
def setup_mixtral_8x7b(self):
self.llm = ChatOpenAI(
model_name="mixtral-8x7b-32768",
temperature=0.2,
Expand All @@ -79,12 +82,27 @@ def setup_mixtral(self):
base_url="https://api.groq.com/openai/v1",
)

def setup_gemini(self):
def setup_claude(self):
self.llm = ChatOpenAI(
model_name="google/gemini-pro",
temperature=0.2,
model_name="anthropic/claude-3-haiku",
temperature=0.1,
api_key=self.secrets["OPENROUTER_API_KEY"],
max_tokens=700,
callbacks=[self.callback_handler],
streaming=True,
base_url="https://openrouter.ai/api/v1",
default_headers={
"HTTP-Referer": "https://snowchat.streamlit.app/",
"X-Title": "Snowchat",
},
)

def setup_mixtral_8x22b(self):
self.llm = ChatOpenAI(
model_name="mistralai/mixtral-8x22b",
temperature=0.1,
api_key=self.secrets["OPENROUTER_API_KEY"],
max_tokens=1200,
max_tokens=700,
callbacks=[self.callback_handler],
streaming=True,
base_url="https://openrouter.ai/api/v1",
Expand Down Expand Up @@ -133,10 +151,12 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):

if "GPT-3.5" in model_name:
model_type = "gpt"
elif "mistral" in model_name.lower():
model_type = "mistral"
elif "gemini" in model_name.lower():
model_type = "gemini"
elif "mixtral 8x7b" in model_name.lower():
model_type = "mixtral8x7b"
elif "claude" in model_name.lower():
model_type = "claude"
elif "mixtral 8x22b" in model_name.lower():
model_type = "mixtral8x22b"
else:
raise ValueError(f"Unsupported model name: {model_name}")

Expand Down
15 changes: 13 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
st.caption("Talk your way through data")
model = st.radio(
"",
options=["GPT-3.5 - OpenAI", "Gemini 1.5 - Openrouter", "Mistral 8x7B - Groq"],
options=["Claude-3 Haiku", "Mixtral 8x7B", "Mixtral 8x22B", "GPT-3.5"],
index=0,
horizontal=True,
)
Expand All @@ -43,12 +43,20 @@
if "toast_shown" not in st.session_state:
st.session_state["toast_shown"] = False

if "rate-limit" not in st.session_state:
st.session_state["rate-limit"] = False

# Show the toast only if it hasn't been shown before
if not st.session_state["toast_shown"]:
st.toast("The snowflake data retrieval is disabled for now.", icon="👋")
st.session_state["toast_shown"] = True

if st.session_state["model"] == "👑 Mistral 8x7B - Groq":
# Show a warning if the model is rate-limited
if st.session_state['rate-limit']:
st.toast("Probably rate limited.. Go easy folks", icon="⚠️")
st.session_state['rate-limit'] = False

if st.session_state["model"] == "Mixtral 8x7B":
st.warning("This is highly rate-limited. Please use it sparingly", icon="⚠️")

INITIAL_MESSAGE = [
Expand Down Expand Up @@ -173,6 +181,9 @@ def execute_sql(query, conn, retries=2):
)
append_message(result.content)

if st.session_state["model"] == "Mixtral 8x7B" and st.session_state['messages'][-1]['content'] == "":
st.session_state['rate-limit'] = True

# if get_sql(result):
# conn = SnowflakeConnection().get_session()
# df = execute_sql(get_sql(result), conn)
Expand Down
13 changes: 11 additions & 2 deletions template.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
When asked about your capabilities, provide a general overview of your ability to assist with data analysis tasks using Snowflake SQL, instead of performing specific SQL queries.
Based on the question provided, if it pertains to data analysis or SQL tasks, generate SQL code that is compatible with the Snowflake environment. Additionally, offer a brief explanation about how you arrived at the SQL code. If the required column isn't explicitly stated in the context, suggest an alternative using available columns, but do not assume the existence of any columns that are not mentioned. Also, do not modify the database in any way (no insert, update, or delete operations). You are only allowed to query the database. Refrain from using the information schema.
(CONTEXT IS NOT KNOWN TO USER) it is provided to you as a reference to generate SQL code.
Based on the question provided, if it pertains to data analysis or SQL tasks, generate SQL code based on the Context provided. Make sure that is compatible with the Snowflake environment. Additionally, offer a brief explanation about how you arrived at the SQL code. If the required column isn't explicitly stated in the context, suggest an alternative using available columns, but do not assume the existence of any columns that are not mentioned. Also, do not modify the database in any way (no insert, update, or delete operations). You are only allowed to query the database. Refrain from using the information schema.
**You are only required to write one SQL query per question.**
If the question or context does not clearly involve SQL or data analysis tasks, respond appropriately without generating SQL queries.
Expand All @@ -28,7 +30,14 @@
Write your response in markdown format.
User: {question}
Do not worry about access to the database or the schema details. The context provided is sufficient to generate the SQL code. The Sql code is not expected to run on any database.
User Question: \n {question}
\n
Context - (Schema Details):
\n
{context}
Assistant:
Expand Down
2 changes: 1 addition & 1 deletion ui/sidebar.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ SnowChat is an intuitive and user-friendly application that allows you to intera

Here are some example queries you can try with SnowChat:

- Show me the total revenue for each product category.
- Write SQL code to show me the total revenue for each product category.
- Who are the top 10 customers by sales?
- What is the average order value for each region?
- How many orders were placed last week?
Expand Down
11 changes: 7 additions & 4 deletions utils/snowchat_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,18 @@
image_url
+ "png-transparent-openai-chatgpt-logo-thumbnail.png?t=2024-03-01T07%3A41%3A47.586Z"
)

user_url = image_url + "cat-with-sunglasses.png"
claude_url = image_url + "Claude.png?t=2024-03-13T23%3A47%3A16.824Z"

def get_model_url(model_name):
if "gpt" in model_name.lower():
return openai_url
elif "claude" in model_name.lower():
return claude_url
elif "mixtral" in model_name.lower():
return mistral_url
elif "gemini" in model_name.lower():
return gemini_url
elif "mistral" in model_name.lower():
return mistral_url
return mistral_url


Expand Down Expand Up @@ -57,7 +60,7 @@ def message_func(text, is_user=False, is_df=False, model="gpt"):

avatar_url = model_url
if is_user:
avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=ShortHairShortFlat&accessoriesType=Prescription01&hairColor=Auburn&facialHairType=BeardLight&facialHairColor=Black&clotheType=Hoodie&clotheColor=PastelBlue&eyeType=Squint&eyebrowType=DefaultNatural&mouthType=Smile&skinColor=Tanned"
avatar_url = user_url
message_alignment = "flex-end"
message_bg_color = "linear-gradient(135deg, #00B2FF 0%, #006AFF 100%)"
avatar_class = "user-avatar"
Expand Down

0 comments on commit e3272fb

Please sign in to comment.