Skip to content

Commit

Permalink
update models
Browse files Browse the repository at this point in the history
  • Loading branch information
kaarthik108 committed Aug 7, 2024
1 parent 6405509 commit 4e43954
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 54 deletions.
3 changes: 1 addition & 2 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
"titleBar.activeBackground": "#51103e",
"titleBar.activeForeground": "#e7e7e7",
"titleBar.inactiveBackground": "#51103e99",
"titleBar.inactiveForeground": "#e7e7e799",
"tab.activeBorder": "#7c185f"
"titleBar.inactiveForeground": "#e7e7e799"
},
"peacock.color": "#51103e"
}
94 changes: 51 additions & 43 deletions chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_anthropic import ChatAnthropic

DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

Expand All @@ -31,13 +32,6 @@ class ModelConfig(BaseModel):
secrets: Dict[str, Any]
callback_handler: Optional[Callable] = None

@validator("model_type", pre=True, always=True)
def validate_model_type(cls, v):
valid_model_types = ["qwen", "llama", "claude", "mixtral8x7b", "arctic"]
if v not in valid_model_types:
raise ValueError(f"Unsupported model type: {v}")
return v


class ModelWrapper:
def __init__(self, config: ModelConfig):
Expand All @@ -48,47 +42,61 @@ def __init__(self, config: ModelConfig):

def _setup_llm(self):
model_config = {
"qwen": {
"model_name": "qwen/qwen-2-72b-instruct",
"api_key": self.secrets["OPENROUTER_API_KEY"],
"base_url": "https://openrouter.ai/api/v1",
},
"claude": {
"model_name": "anthropic/claude-3-haiku",
"api_key": self.secrets["OPENROUTER_API_KEY"],
"base_url": "https://openrouter.ai/api/v1",
"gpt-4o-mini": {
"model_name": "gpt-4o-mini",
"api_key": self.secrets["OPENAI_API_KEY"],
},
"mixtral8x7b": {
"model_name": "mixtral-8x7b-32768",
"gemma2-9b": {
"model_name": "gemma2-9b-it",
"api_key": self.secrets["GROQ_API_KEY"],
"base_url": "https://api.groq.com/openai/v1",
},
"llama": {
"model_name": "meta-llama/llama-3-70b-instruct",
"api_key": self.secrets["OPENROUTER_API_KEY"],
"base_url": "https://openrouter.ai/api/v1",
"claude3-haiku": {
"model_name": "claude-3-haiku-20240307",
"api_key": self.secrets["ANTHROPIC_API_KEY"],
},
"mixtral-8x22b": {
"model_name": "accounts/fireworks/models/mixtral-8x22b-instruct",
"api_key": self.secrets["FIREWORKS_API_KEY"],
"base_url": "https://api.fireworks.ai/inference/v1",
},
"arctic": {
"model_name": "snowflake/snowflake-arctic-instruct",
"api_key": self.secrets["OPENROUTER_API_KEY"],
"base_url": "https://openrouter.ai/api/v1",
"llama-3.1-405b": {
"model_name": "accounts/fireworks/models/llama-v3p1-405b-instruct",
"api_key": self.secrets["FIREWORKS_API_KEY"],
"base_url": "https://api.fireworks.ai/inference/v1",
},
}

config = model_config[self.model_type]

return ChatOpenAI(
model_name=config["model_name"],
temperature=0.1,
api_key=config["api_key"],
max_tokens=700,
callbacks=[self.callback_handler],
streaming=True,
base_url=config["base_url"],
default_headers={
"HTTP-Referer": "https://snowchat.streamlit.app/",
"X-Title": "Snowchat",
},
return (
ChatOpenAI(
model_name=config["model_name"],
temperature=0.1,
api_key=config["api_key"],
max_tokens=700,
callbacks=[self.callback_handler],
streaming=True,
base_url=config["base_url"]
if config["model_name"] != "gpt-4o-mini"
else None,
default_headers={
"HTTP-Referer": "https://snowchat.streamlit.app/",
"X-Title": "Snowchat",
},
)
if config["model_name"] != "claude-3-haiku-20240307"
else (
ChatAnthropic(
model=config["model_name"],
temperature=0.1,
max_tokens=700,
timeout=None,
max_retries=2,
callbacks=[self.callback_handler],
streaming=True,
)
)
)

def get_chain(self, vectorstore):
Expand Down Expand Up @@ -129,11 +137,11 @@ def load_chain(model_name="qwen", callback_handler=None):
)

model_type_mapping = {
"qwen 2-72b": "qwen",
"mixtral 8x7b": "mixtral8x7b",
"claude-3 haiku": "claude",
"llama 3-70b": "llama",
"snowflake arctic": "arctic",
"gpt-4o-mini": "gpt-4o-mini",
"gemma2-9b": "gemma2-9b",
"claude3-haiku": "claude3-haiku",
"mixtral-8x22b": "mixtral-8x22b",
"llama-3.1-405b": "llama-3.1-405b",
}

model_type = model_type_mapping.get(model_name.lower())
Expand Down
20 changes: 12 additions & 8 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,19 @@
st.markdown(gradient_text_html, unsafe_allow_html=True)

st.caption("Talk your way through data")

model_options = {
"gpt-4o-mini": "GPT-4o Mini",
"llama-3.1-405b": "Llama 3.1 405B",
"gemma2-9b": "Gemma 2 9B",
"claude3-haiku": "Claude 3 Haiku",
"mixtral-8x22b": "Mixtral 8x22B",
}

model = st.radio(
"",
options=[
"Claude-3 Haiku",
"Mixtral 8x7B",
"Llama 3-70B",
"Qwen 2-72B",
"Snowflake Arctic",
],
"Choose your AI Model:",
options=list(model_options.keys()),
format_func=lambda x: model_options[x],
index=0,
horizontal=True,
)
Expand Down
4 changes: 3 additions & 1 deletion utils/snowchat_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ def get_model_url(model_name):
return claude_url
elif "llama" in model_name.lower():
return meta_url
elif "gemini" in model_name.lower():
elif "gemma" in model_name.lower():
return gemini_url
elif "arctic" in model_name.lower():
return snow_url
elif "gpt" in model_name.lower():
return openai_url
return mistral_url


Expand Down

0 comments on commit 4e43954

Please sign in to comment.