From 4e43954c9c6fa2d1cfbf537fc45a28925921c831 Mon Sep 17 00:00:00 2001
From: kaarthik108
Date: Wed, 7 Aug 2024 14:46:42 +1200
Subject: [PATCH] update models

---
 .vscode/settings.json |  3 +-
 chain.py              | 94 +++++++++++++++++++++++--------------------
 main.py               | 20 +++++----
 utils/snowchat_ui.py  |  4 +-
 4 files changed, 67 insertions(+), 54 deletions(-)

diff --git a/.vscode/settings.json b/.vscode/settings.json
index 9068edb..48d8756 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -20,8 +20,7 @@
         "titleBar.activeBackground": "#51103e",
         "titleBar.activeForeground": "#e7e7e7",
         "titleBar.inactiveBackground": "#51103e99",
-        "titleBar.inactiveForeground": "#e7e7e799",
-        "tab.activeBorder": "#7c185f"
+        "titleBar.inactiveForeground": "#e7e7e799"
     },
     "peacock.color": "#51103e"
 }
\ No newline at end of file
diff --git a/chain.py b/chain.py
index 1896620..a16b862 100644
--- a/chain.py
+++ b/chain.py
@@ -18,6 +18,7 @@
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables import RunnableParallel, RunnablePassthrough
 from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+from langchain_anthropic import ChatAnthropic
 
 DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
 
@@ -31,13 +32,6 @@ class ModelConfig(BaseModel):
     secrets: Dict[str, Any]
     callback_handler: Optional[Callable] = None
 
-    @validator("model_type", pre=True, always=True)
-    def validate_model_type(cls, v):
-        valid_model_types = ["qwen", "llama", "claude", "mixtral8x7b", "arctic"]
-        if v not in valid_model_types:
-            raise ValueError(f"Unsupported model type: {v}")
-        return v
-
 
 class ModelWrapper:
     def __init__(self, config: ModelConfig):
@@ -48,47 +42,61 @@ def __init__(self, config: ModelConfig):
     def _setup_llm(self):
         model_config = {
-            "qwen": {
-                "model_name": "qwen/qwen-2-72b-instruct",
-                "api_key": self.secrets["OPENROUTER_API_KEY"],
-                "base_url": "https://openrouter.ai/api/v1",
-            },
-            "claude": {
-                "model_name": "anthropic/claude-3-haiku",
-                "api_key": self.secrets["OPENROUTER_API_KEY"],
-                "base_url": "https://openrouter.ai/api/v1",
+            "gpt-4o-mini": {
+                "model_name": "gpt-4o-mini",
+                "api_key": self.secrets["OPENAI_API_KEY"],
             },
-            "mixtral8x7b": {
-                "model_name": "mixtral-8x7b-32768",
+            "gemma2-9b": {
+                "model_name": "gemma2-9b-it",
                 "api_key": self.secrets["GROQ_API_KEY"],
                 "base_url": "https://api.groq.com/openai/v1",
             },
-            "llama": {
-                "model_name": "meta-llama/llama-3-70b-instruct",
-                "api_key": self.secrets["OPENROUTER_API_KEY"],
-                "base_url": "https://openrouter.ai/api/v1",
+            "claude3-haiku": {
+                "model_name": "claude-3-haiku-20240307",
+                "api_key": self.secrets["ANTHROPIC_API_KEY"],
+            },
+            "mixtral-8x22b": {
+                "model_name": "accounts/fireworks/models/mixtral-8x22b-instruct",
+                "api_key": self.secrets["FIREWORKS_API_KEY"],
+                "base_url": "https://api.fireworks.ai/inference/v1",
             },
-            "arctic": {
-                "model_name": "snowflake/snowflake-arctic-instruct",
-                "api_key": self.secrets["OPENROUTER_API_KEY"],
-                "base_url": "https://openrouter.ai/api/v1",
+            "llama-3.1-405b": {
+                "model_name": "accounts/fireworks/models/llama-v3p1-405b-instruct",
+                "api_key": self.secrets["FIREWORKS_API_KEY"],
+                "base_url": "https://api.fireworks.ai/inference/v1",
             },
         }
 
         config = model_config[self.model_type]
-        return ChatOpenAI(
-            model_name=config["model_name"],
-            temperature=0.1,
-            api_key=config["api_key"],
-            max_tokens=700,
-            callbacks=[self.callback_handler],
-            streaming=True,
-            base_url=config["base_url"],
-            default_headers={
-                "HTTP-Referer": "https://snowchat.streamlit.app/",
-                "X-Title": "Snowchat",
-            },
+        return (
+            ChatOpenAI(
+                model_name=config["model_name"],
+                temperature=0.1,
+                api_key=config["api_key"],
+                max_tokens=700,
+                callbacks=[self.callback_handler],
+                streaming=True,
+                base_url=config["base_url"]
+                if config["model_name"] != "gpt-4o-mini"
+                else None,
+                default_headers={
+                    "HTTP-Referer": "https://snowchat.streamlit.app/",
+                    "X-Title": "Snowchat",
+                },
+            )
+            if config["model_name"] != "claude-3-haiku-20240307"
+            else (
+                ChatAnthropic(
+                    model=config["model_name"],
+                    temperature=0.1,
+                    max_tokens=700,
+                    timeout=None,
+                    max_retries=2,
+                    callbacks=[self.callback_handler],
+                    streaming=True,
+                )
+            )
         )
 
     def get_chain(self, vectorstore):
@@ -129,11 +137,11 @@ def load_chain(model_name="qwen", callback_handler=None):
     )
 
     model_type_mapping = {
-        "qwen 2-72b": "qwen",
-        "mixtral 8x7b": "mixtral8x7b",
-        "claude-3 haiku": "claude",
-        "llama 3-70b": "llama",
-        "snowflake arctic": "arctic",
+        "gpt-4o-mini": "gpt-4o-mini",
+        "gemma2-9b": "gemma2-9b",
+        "claude3-haiku": "claude3-haiku",
+        "mixtral-8x22b": "mixtral-8x22b",
+        "llama-3.1-405b": "llama-3.1-405b",
     }
 
     model_type = model_type_mapping.get(model_name.lower())
diff --git a/main.py b/main.py
index e193f6e..1a491be 100644
--- a/main.py
+++ b/main.py
@@ -32,15 +32,19 @@
 st.markdown(gradient_text_html, unsafe_allow_html=True)
 
 st.caption("Talk your way through data")
+
+model_options = {
+    "gpt-4o-mini": "GPT-4o Mini",
+    "llama-3.1-405b": "Llama 3.1 405B",
+    "gemma2-9b": "Gemma 2 9B",
+    "claude3-haiku": "Claude 3 Haiku",
+    "mixtral-8x22b": "Mixtral 8x22B",
+}
+
 model = st.radio(
-    "",
-    options=[
-        "Claude-3 Haiku",
-        "Mixtral 8x7B",
-        "Llama 3-70B",
-        "Qwen 2-72B",
-        "Snowflake Arctic",
-    ],
+    "Choose your AI Model:",
+    options=list(model_options.keys()),
+    format_func=lambda x: model_options[x],
     index=0,
     horizontal=True,
 )
diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py
index 11ed874..03f5f58 100644
--- a/utils/snowchat_ui.py
+++ b/utils/snowchat_ui.py
@@ -29,10 +29,12 @@ def get_model_url(model_name):
         return claude_url
     elif "llama" in model_name.lower():
         return meta_url
-    elif "gemini" in model_name.lower():
+    elif "gemma" in model_name.lower():
         return gemini_url
    elif "arctic" in model_name.lower():
         return snow_url
+    elif "gpt" in model_name.lower():
+        return openai_url
     return mistral_url