Skip to content

Commit

Permalink
Update to new openai api
Browse files Browse the repository at this point in the history
  • Loading branch information
mahaloz committed Nov 15, 2023
1 parent d0a4f5f commit 7e4b7f7
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 17 deletions.
2 changes: 1 addition & 1 deletion dailalib/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.2.2"
__version__ = "1.3.0"
8 changes: 6 additions & 2 deletions dailalib/binsync_plugin/ai_bs_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,16 +166,20 @@ def _function_is_large_enough(self, func: Function):
def commit_ai_changes_to_state(self, state: State, decompiled_functions):
ai_initiated_changes = 0
update_cnt = 0
round_updates = 0
for func_addr, (decompilation, func) in tqdm(decompiled_functions.items(), desc=f"Querying AI for {len(decompiled_functions)} funcs..."):
ai_initiated_changes += self.run_all_ai_commands_for_dec(decompilation, func, state)
round_changes = self.run_all_ai_commands_for_dec(decompilation, func, state)
ai_initiated_changes += round_changes
if ai_initiated_changes:
update_cnt += 1
round_updates += round_changes

if update_cnt >= 1:
update_cnt = 0
self.controller.client.commit_state(state, msg="AI Initiated change to functions")
self.controller.client.push()
_l.info(f"Pushed some changes to user {self.username}...")
_l.info(f"Pushed {round_updates} changes to user {self.username}...")
round_updates = 0

return ai_initiated_changes

Expand Down
3 changes: 2 additions & 1 deletion dailalib/binsync_plugin/openai_bs_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def run_all_ai_commands_for_dec(self, decompilation: str, func: Function, state:

try:
resp = self.ai_interface.query_for_cmd(cmd, decompilation=decompilation)
except Exception:
except Exception as e:
_l.error(f"Failed to query for cmd {cmd} with error {e}")
continue

if not resp:
Expand Down
25 changes: 12 additions & 13 deletions dailalib/interfaces/openai_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json
from functools import wraps

import openai
from openai import OpenAI
import tiktoken

from .generic_ai_interface import GenericAIInterface
Expand Down Expand Up @@ -97,7 +97,8 @@ def __init__(self, openai_api_key=None, model=DEFAULT_MODEL, decompiler_controll
self._register_menu_item(menu_str, callback_str, callback_func)

self._api_key = os.getenv("OPENAI_API_KEY") or openai_api_key
openai.api_key = self._api_key
self._openai_client = OpenAI(api_key=self._api_key)


@property
def api_key(self):
Expand All @@ -106,7 +107,7 @@ def api_key(self):
@api_key.setter
def api_key(self, data):
self._api_key = data
openai.api_key = self._api_key


#
# OpenAI Interface
Expand All @@ -123,21 +124,19 @@ def _query_openai_model(
):
# TODO: at some point add back frequency_penalty and presence_penalty to be used
try:
response = openai.ChatCompletion.create(
model=model or self.model,
messages=[
{"role": "user", "content": question}
],
max_tokens=max_tokens,
timeout=60,
stop=['}'],
)
response = self._openai_client.chat.completions.create(model=model or self.model,
messages=[
{"role": "user", "content": question}
],
max_tokens=max_tokens,
timeout=60,
stop=['}'])
except openai.OpenAIError as e:
raise Exception(f"ChatGPT could not complete the request: {str(e)}")

answer = None
try:
answer = response.choices[0]["message"]["content"]
answer = response.choices[0].message.content
except (KeyError, IndexError) as e:
pass

Expand Down

0 comments on commit 7e4b7f7

Please sign in to comment.