From 709b3344342239314659ba472b769d2774391d8f Mon Sep 17 00:00:00 2001 From: kaarthik108 Date: Thu, 8 Feb 2024 18:17:14 +0530 Subject: [PATCH] Add mistral medium --- README.md | 6 +++--- chain.py | 14 +++++++------- main.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index accad71..afe0593 100644 --- a/README.md +++ b/README.md @@ -15,9 +15,9 @@ ## Supported LLM's -- GPT-3.5-turbo-16k -- Claude-instant-v1 -- Mixtral 8x7B +- GPT-3.5-turbo-0125 +- CodeLlama-70B +- Mistral Medium # diff --git a/chain.py b/chain.py index b5b64df..7505f36 100644 --- a/chain.py +++ b/chain.py @@ -34,7 +34,7 @@ class ModelConfig(BaseModel): @validator("model_type", pre=True, always=True) def validate_model_type(cls, v): - if v not in ["gpt", "codellama", "mixtral"]: + if v not in ["gpt", "codellama", "mistral"]: raise ValueError(f"Unsupported model type: {v}") return v @@ -55,7 +55,7 @@ def setup(self): self.setup_gpt() elif self.model_type == "codellama": self.setup_codellama() - elif self.model_type == "mixtral": + elif self.model_type == "mistral": self.setup_mixtral() def setup_gpt(self): @@ -71,13 +71,13 @@ def setup_gpt(self): def setup_mixtral(self): self.llm = ChatOpenAI( - model_name="mistralai/Mixtral-8x7B-Instruct-v0.1", + model_name="mistralai/mistral-medium", temperature=0.2, - api_key=self.secrets["MIXTRAL_API_KEY"], + api_key=self.secrets["OPENROUTER_API_KEY"], max_tokens=500, callbacks=[self.callback_handler], streaming=True, - base_url="https://api.together.xyz/v1", + base_url="https://openrouter.ai/api/v1", ) def setup_codellama(self): @@ -157,8 +157,8 @@ def load_chain(model_name="GPT-3.5", callback_handler=None): model_type = "codellama" elif "GPT-3.5" in model_name: model_type = "gpt" - elif "mixtral" in model_name.lower(): - model_type = "mixtral" + elif "mistral" in model_name.lower(): + model_type = "mistral" else: raise ValueError(f"Unsupported model name: {model_name}") diff --git a/main.py b/main.py index d803b84..157e63a 100644 --- a/main.py +++ b/main.py @@ -18,7 +18,7 @@ st.caption("Talk your way through data") model = st.radio( "", - options=["✨ GPT-3.5", "♾️ codellama", "⛰️ Mixtral"], + options=["✨ GPT-3.5", "♾️ codellama", "👑 Mistral"], index=0, horizontal=True, )