Fix a bug in model settings across threads (#64)
mahaloz authored Nov 12, 2024
1 parent 0f98eca commit 30b6cf6
Showing 4 changed files with 32 additions and 10 deletions.
2 changes: 1 addition & 1 deletion dailalib/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "3.10.3"
+__version__ = "3.10.4"
 
 import os
 # stop LiteLLM from querying at all to the remote server
33 changes: 25 additions & 8 deletions dailalib/api/litellm/litellm_api.py
@@ -5,6 +5,9 @@
 
 from ..ai_api import AIAPI
 
+active_model = None
+active_prompt_style = None
+
 
 class LiteLLMAIAPI(AIAPI):
     prompts_by_name = []
@@ -14,11 +17,10 @@ class LiteLLMAIAPI(AIAPI):
+        # TODO: update the token values for o1
         "o1-mini": 8_000,
         "o1-preview": 8_000,
-        "gpt-4-turbo": 128_000,
-        "gpt-4": 8_000,
         "gpt-4o": 8_000,
-        "gpt-3.5-turbo": 4_096,
-        "claude-2": 200_000,
+        "gpt-4o-mini": 16_000,
+        "gpt-4-turbo": 128_000,
         "claude-3-5-sonnet-20240620": 200_000,
         "gemini/gemini-pro": 12_288,
         "vertex_ai_beta/gemini-pro": 12_288,
     }
@@ -49,6 +51,11 @@ def __init__(
         prompts = prompts + PROMPTS if prompts else PROMPTS
         self.prompts_by_name = {p.name: p for p in prompts}
 
+        # update the globals (for threading hacks)
+        global active_model, active_prompt_style
+        active_model = self.model
+        active_prompt_style = self.prompt_style
+
     def __dir__(self):
         return list(super().__dir__()) + list(self.prompts_by_name.keys())
 
@@ -161,8 +168,9 @@ def ask_prompt_style(self, *args, **kwargs):
 
             prompt_style = self.prompt_style
             style_choices = ALL_STYLES.copy()
-            style_choices.remove(self.prompt_style)
-            style_choices = [self.prompt_style] + style_choices
+            if self.prompt_style:
+                style_choices.remove(self.prompt_style)
+                style_choices = [self.prompt_style] + style_choices
 
             p_style = self._dec_interface.gui_ask_for_choice(
                 "What prompting style would you like to use?",
@@ -177,11 +185,16 @@
             self.prompt_style = p_style
             self._dec_interface.info(f"Prompt style set to {p_style}")
 
+            # update global
+            global active_prompt_style
+            active_prompt_style = p_style
+
     def ask_model(self, *args, **kwargs):
         if self._dec_interface is not None:
             model_choices = list(LiteLLMAIAPI.MODEL_TO_TOKENS.keys())
-            model_choices.remove(self.model)
-            model_choices = [self.model] + model_choices
+            if self.model:
+                model_choices.remove(self.model)
+                model_choices = [self.model] + model_choices
 
             model = self._dec_interface.gui_ask_for_choice(
                 "What LLM model would you like to use?",
@@ -190,3 +203,7 @@
             )
             self.model = model
             self._dec_interface.info(f"Model set to {model}")
+
+            # update global
+            global active_model
+            active_model = model
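Note: active_model and active_prompt_style are module-level globals precisely so that a setting changed in the GUI thread becomes visible to worker threads; all threads in a Python process share module state. A minimal standalone sketch of the pattern (names are illustrative, not DAILA's API):

import threading

active_model = None  # module-level global, shared by every thread

def set_model(model):
    # the GUI/settings thread writes the new choice
    global active_model
    active_model = model

def worker():
    # a query thread reads the global at call time and sees the update
    print(f"querying with model: {active_model}")

set_model("gpt-4o")
t = threading.Thread(target=worker)
t.start()
t.join()  # prints: querying with model: gpt-4o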
5 changes: 5 additions & 0 deletions dailalib/api/litellm/prompts/prompt.py
@@ -58,6 +58,11 @@ def query_model(self, *args, context=None, function=None, dec_text=None, use_dec
         if self.ai_api is None:
             raise Exception("api must be set before querying!")
 
+        # this is a hack to get the active model and prompt style in many threads in IDA Pro
+        from ..litellm_api import active_model, active_prompt_style
+        self.ai_api.model = active_model
+        self.ai_api.prompt_style = active_prompt_style
+
         @AIAPI.requires_function
         def _query_model(ai_api=self.ai_api, function=function, dec_text=dec_text, **_kwargs) -> Union[Dict, str]:
             if not ai_api:
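Note: the import above happens inside query_model on purpose: "from module import name" copies the value at the moment the statement runs, so importing at the top of prompt.py would permanently bind the initial None. A small standalone sketch of that semantics (stand-in module, not the real one):

import types

# stand-in (hypothetical) for dailalib.api.litellm.litellm_api
litellm_api = types.ModuleType("litellm_api")
litellm_api.active_model = None

frozen = litellm_api.active_model  # a top-of-file-style copy; never sees updates

def query_model():
    # re-fetching inside the call (what the in-function import does)
    # observes the latest value of the module-level global
    return litellm_api.active_model

litellm_api.active_model = "gpt-4o-mini"
print(frozen)         # None
print(query_model())  # gpt-4o-mini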
2 changes: 1 addition & 1 deletion setup.cfg
@@ -17,7 +17,7 @@ install_requires =
     litellm>=1.44.27
     tiktoken
     Jinja2
-    libbs>=2.2.1
+    libbs>=2.6.0
 
 python_requires = >= 3.10
 include_package_data = True
