From 30b6cf6a25bfc0096c4fbc72a823ee299d116e93 Mon Sep 17 00:00:00 2001 From: Zion Leonahenahe Basque Date: Mon, 11 Nov 2024 17:24:27 -0700 Subject: [PATCH] Fix a bug in model settings across threads (#64) --- dailalib/__init__.py | 2 +- dailalib/api/litellm/litellm_api.py | 33 +++++++++++++++++++------- dailalib/api/litellm/prompts/prompt.py | 5 ++++ setup.cfg | 2 +- 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/dailalib/__init__.py b/dailalib/__init__.py index d20d6e0..48737e4 100644 --- a/dailalib/__init__.py +++ b/dailalib/__init__.py @@ -1,4 +1,4 @@ -__version__ = "3.10.3" +__version__ = "3.10.4" import os # stop LiteLLM from querying at all to the remote server diff --git a/dailalib/api/litellm/litellm_api.py b/dailalib/api/litellm/litellm_api.py index eecda27..660da8d 100644 --- a/dailalib/api/litellm/litellm_api.py +++ b/dailalib/api/litellm/litellm_api.py @@ -5,6 +5,9 @@ from ..ai_api import AIAPI +active_model = None +active_prompt_style = None + class LiteLLMAIAPI(AIAPI): prompts_by_name = [] @@ -14,11 +17,10 @@ class LiteLLMAIAPI(AIAPI): # TODO: update the token values for o1 "o1-mini": 8_000, "o1-preview": 8_000, - "gpt-4-turbo": 128_000, - "gpt-4": 8_000, "gpt-4o": 8_000, - "gpt-3.5-turbo": 4_096, - "claude-2": 200_000, + "gpt-4o-mini": 16_000, + "gpt-4-turbo": 128_000, + "claude-3-5-sonnet-20240620": 200_000, "gemini/gemini-pro": 12_288, "vertex_ai_beta/gemini-pro": 12_288, } @@ -49,6 +51,11 @@ def __init__( prompts = prompts + PROMPTS if prompts else PROMPTS self.prompts_by_name = {p.name: p for p in prompts} + # update the globals (for threading hacks) + global active_model, active_prompt_style + active_model = self.model + active_prompt_style = self.prompt_style + def __dir__(self): return list(super().__dir__()) + list(self.prompts_by_name.keys()) @@ -161,8 +168,9 @@ def ask_prompt_style(self, *args, **kwargs): prompt_style = self.prompt_style style_choices = ALL_STYLES.copy() - style_choices.remove(self.prompt_style) - style_choices = [self.prompt_style] + style_choices + if self.prompt_style: + style_choices.remove(self.prompt_style) + style_choices = [self.prompt_style] + style_choices p_style = self._dec_interface.gui_ask_for_choice( "What prompting style would you like to use?", @@ -177,11 +185,16 @@ def ask_prompt_style(self, *args, **kwargs): self.prompt_style = p_style self._dec_interface.info(f"Prompt style set to {p_style}") + # update global + global active_prompt_style + active_prompt_style = p_style + def ask_model(self, *args, **kwargs): if self._dec_interface is not None: model_choices = list(LiteLLMAIAPI.MODEL_TO_TOKENS.keys()) - model_choices.remove(self.model) - model_choices = [self.model] + model_choices + if self.model: + model_choices.remove(self.model) + model_choices = [self.model] + model_choices model = self._dec_interface.gui_ask_for_choice( "What LLM model would you like to use?", @@ -190,3 +203,7 @@ def ask_model(self, *args, **kwargs): ) self.model = model self._dec_interface.info(f"Model set to {model}") + + # update global + global active_model + active_model = model diff --git a/dailalib/api/litellm/prompts/prompt.py b/dailalib/api/litellm/prompts/prompt.py index 14deba7..13adbf1 100644 --- a/dailalib/api/litellm/prompts/prompt.py +++ b/dailalib/api/litellm/prompts/prompt.py @@ -58,6 +58,11 @@ def query_model(self, *args, context=None, function=None, dec_text=None, use_dec if self.ai_api is None: raise Exception("api must be set before querying!") + # this is a hack to get the active model and prompt style in many threads in IDA Pro + from ..litellm_api import active_model, active_prompt_style + self.ai_api.model = active_model + self.ai_api.prompt_style = active_prompt_style + @AIAPI.requires_function def _query_model(ai_api=self.ai_api, function=function, dec_text=dec_text, **_kwargs) -> Union[Dict, str]: if not ai_api: diff --git a/setup.cfg b/setup.cfg index 8610688..1bdbe4c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,7 +17,7 @@ install_requires = litellm>=1.44.27 tiktoken Jinja2 - libbs>=2.2.1 + libbs>=2.6.0 python_requires = >= 3.10 include_package_data = True