From acd43d0797e77febaeecefd946e1cf30568793f6 Mon Sep 17 00:00:00 2001 From: Zion Leonahenahe Basque Date: Mon, 4 Dec 2023 16:29:52 -0700 Subject: [PATCH] Major Refactor: Restructure backend to depend on LibBS (#22) * Running script in decompiler stubs * More changes * Working abstraction tested in IDA. Still need an `askKey` api * Fixed the Binja side code * Add some headless Ghidra code * Ghidra, IDA, Binja working. Ghidra still needs a way to start UI without terminal * Update the README, and GHIDRA WORKS!!! --- .gitignore | 44 +++ README.md | 93 +++-- dailalib/__init__.py | 32 +- dailalib/__main__.py | 26 +- dailalib/api/__init__.py | 2 + dailalib/api/ai_api.py | 99 ++++++ dailalib/api/openai/__init__.py | 2 + dailalib/api/openai/openai_api.py | 106 ++++++ dailalib/api/openai/prompt.py | 84 +++++ dailalib/api/openai/prompts.py | 134 +++++++ dailalib/binsync_plugin/ai_bs_user.py | 4 +- dailalib/binsync_plugin/openai_bs_user.py | 4 +- dailalib/controller_server.py | 122 ------- dailalib/daila_plugin.py | 47 +++ dailalib/installer.py | 74 ++-- dailalib/interfaces/__init__.py | 7 - dailalib/interfaces/generic_ai_interface.py | 27 -- dailalib/interfaces/openai_interface.py | 366 -------------------- dailalib/utils.py | 2 - plugins/daila_binja.py | 176 ---------- plugins/daila_ghidra.py | 184 ---------- plugins/daila_ida.py | 229 ------------ pyproject.toml | 18 +- setup.cfg | 32 -- setup.py | 52 --- 25 files changed, 666 insertions(+), 1300 deletions(-) create mode 100644 .gitignore create mode 100644 dailalib/api/__init__.py create mode 100644 dailalib/api/ai_api.py create mode 100644 dailalib/api/openai/__init__.py create mode 100644 dailalib/api/openai/openai_api.py create mode 100644 dailalib/api/openai/prompt.py create mode 100644 dailalib/api/openai/prompts.py delete mode 100644 dailalib/controller_server.py create mode 100644 dailalib/daila_plugin.py delete mode 100644 dailalib/interfaces/__init__.py delete mode 100644 
dailalib/interfaces/generic_ai_interface.py delete mode 100644 dailalib/interfaces/openai_interface.py delete mode 100644 dailalib/utils.py delete mode 100644 plugins/daila_binja.py delete mode 100644 plugins/daila_ghidra.py delete mode 100644 plugins/daila_ida.py delete mode 100644 setup.cfg delete mode 100644 setup.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..58270b7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,44 @@ +*.pyc +.idea/ +*.egg-info/ +testing/ +*.o +*.so +*.a +.gdb_history +*.i64 +*.idb +*.id0 +*.id1 +*.id2 +*.nam +*.til +*.swp +*.dll +*.obj +*.lib +*.exp +*.pdb +*.ilk +angr/tests/*.png +screenlog.0 +angr/tests/screenlog.0 +angr/screenlog.0 +.idea +*.egg-info +/build +/tags +MANIFEST +dist +.eggs +.vscode/ +*.db +.DS_Store +.pytest_cache/ +binsync/decompilers/ghidra/client/build/ +binsync/decompilers/ghidra/client/dist/ +binsync/decompilers/ghidra/client/bin/ +binsync/decompilers/ghidra/client/.gradle/ +binsync/decompilers/ghidra/client/.classpath +binsync/decompilers/ghidra/client/.project +binsync/decompilers/ghidra/client/Ghidra/ diff --git a/README.md b/README.md index e7c6e8f..83f0e84 100644 --- a/README.md +++ b/README.md @@ -1,44 +1,73 @@ # DAILA -Decompiler Artificial Intelligence Language Assistant - Built on OpenAI. -Utilize OpenAI to improve your decompilation experience in most modern decompilers. +The Decompiler Artificial Intelligence Language Assistant (DAILA) is a unified interface for AI systems to be used in decompilers. +Using DAILA, you can utilize various AI systems, like local and remote LLMs, all in the same scripting and GUI interfaces. -![](./assets/ida_daila.png) +DAILA context menu + +DAILA's main purpose is to provide a unified interface for AI systems to be used in decompilers. +To accomplish this, DAILA provides a lifted interface, relying on the BinSync library [LibBS](https://github.com/binsync/libbs) to abstract away the decompiler. 
+**All decompilers supported in LibBS are supported in DAILA, which currently includes IDA, Ghidra, Binja, and angr-management.** ## Installation -Clone down this repo and pip install and use the daila installer: +Install our library backend through pip and our decompiler plugin through our installer: ```bash -pip3 install -e . && dailalib --install +pip3 install dailalib && daila --install ``` -Depending on your decompiler, this will attempt to copy the script files into your decompiler and install -the DAILA core to your current Python. If you are using Binja or IDA, make sure your Python is the same -as the one you are using in your decompiler. - -If you are using Ghidra, you may be required to enable the `$USER_HOME/ghidra_scripts` as a valid -scripts path. +### Ghidra Extras +You need to do a few extra steps to get Ghidra working. +Next, enable the DAILA plugin: +1. Start Ghidra and open a binary +2. Go to the `Windows > Script Manager` menu +3. Search for `daila` and enable the script -If your decompiler does not have access to the `OPENAI_API_KEY`, then you must use the decompiler option from -DAILA to set the API key. A popup will appear for you to enter your key. +You must have `python3` in your path for the Ghidra version to work. We quite literally call it from inside Python 2. +You may also need to enable the `$USER_HOME/ghidra_scripts` as a valid scripts path in Ghidra. -### Manual Install +### Manual Install (if above fails) If the above fails, you will need to manually install. -To manually install, first `pip3 install -e .` on the repo, then copy the python file for your decompiler in your -decompilers plugins/scripts folder. +To manually install, first `pip3 install dailalib`, then copy the [daila_plugin.py](./dailalib/daila_plugin.py) file to your decompiler's plugin directory. -### Ghidra Gotchas -You must have `python3` in your path for the Ghidra version to work. We quite literally call it from inside Python 2. 
## Usage -In your decompiler you can access the DAILA options in one of two ways: -1. If you are not in Ghidra, you can right-click a function and go to `Plugins` or directly use the `DAILA ...` menu. -2. If you are in Ghidra, use `Tools->DAILA` then use the operation selector +DAILA is designed to be used in two ways: +1. As a decompiler plugin with a GUI +2. As a scripting library in your decompiler + +### Decompiler GUI +With the exception of Ghidra (see below), when you start your decompiler you will have a new context menu +which you can access when you right-click anywhere in a function: + +DAILA context menu + +If you are using Ghidra, go to `Tools->DAILA->Start DAILA Backend` to start the backend server. +After you've done this, you can use the context menu as shown above. + +### Scripting +You can use DAILA in your own scripts by importing the `dailalib` package. +Here is an example using the OpenAI API: +```python +from dailalib import OpenAIAPI +from libbs.api import DecompilerInterface + +deci = DecompilerInterface.discover_interface() +ai_api = OpenAIAPI(decompiler_interface=deci) +for function in deci.functions: + summary = ai_api.summarize_function(function) +``` -All operations that DAILA can perform can be found from the DAILA context menu, which in some decompilers may just be -the menu described above. -![](./assets/ida_show_menu_daila.png) +## Supported AI Backends +### OpenAI +DAILA supports the OpenAI API. To use the OpenAI API, you must have an OpenAI API key. +If your decompiler does not have access to the `OPENAI_API_KEY` environment variable, then you must use the decompiler option from +DAILA to set the API key. -Comments will appear in the function header with the response or an error message. 
+Currently, DAILA supports the following prompts: +- Summarize a function +- Rename variables +- Rename function +- Identify the source of a function ## Supported Decompilers - IDA @@ -48,16 +77,4 @@ Comments will appear in the function header with the response or an error messag ![](./assets/binja_daila.png) - Ghidra -![](./assets/ghidra_daila.png) - -## Features -### Function Identification -We use ChatGPT to attempt to: -1. Identify which open-source project this decompilation could be a result of -2. Find a link to that said source if it exists - -### Function Summarization -Summarizes in human-readable text what this function does - -### Vulnerability Detection -Attempts to find and describe the vulnerability in the function \ No newline at end of file +![](./assets/ghidra_daila.png) \ No newline at end of file diff --git a/dailalib/__init__.py b/dailalib/__init__.py index 67bc602..3271a37 100644 --- a/dailalib/__init__.py +++ b/dailalib/__init__.py @@ -1 +1,31 @@ -__version__ = "1.3.0" +__version__ = "2.0.0" + +from .api import AIAPI, OpenAIAPI +from libbs.api import DecompilerInterface + + +def create_plugin(*args, **kwargs): + + ai_api = OpenAIAPI(delay_init=True) + # create context menus for prompts + gui_ctx_menu_actions = { + f"DAILA/{prompt_name}": (prompt.desc, getattr(ai_api, prompt_name)) + for prompt_name, prompt in ai_api.prompts_by_name.items() + } + # create context menus for others + gui_ctx_menu_actions["DAILA/Update API Key"] = ("Update API Key", ai_api.ask_api_key) + + # create decompiler interface + force_decompiler = kwargs.pop("force_decompiler", None) + deci = DecompilerInterface.discover_interface( + force_decompiler=force_decompiler, + # decompiler-creation args + plugin_name="DAILA", + init_plugin=True, + gui_ctx_menu_actions=gui_ctx_menu_actions, + ui_init_args=args, + ui_init_kwargs=kwargs + ) + ai_api.init_decompiler_interface(decompiler_interface=deci) + + return deci.gui_plugin diff --git a/dailalib/__main__.py 
b/dailalib/__main__.py index b37eeb3..740e855 100644 --- a/dailalib/__main__.py +++ b/dailalib/__main__.py @@ -1,38 +1,38 @@ import argparse from .installer import DAILAInstaller -from .controller_server import DAILAServer +import dailalib def main(): parser = argparse.ArgumentParser( description=""" - The DAILA Command Line Util. + The DAILA CLI is used to install, run, and host the DAILA plugin. """, epilog=""" Examples: - daila --install + daila install """ ) parser.add_argument( - "--install", action="store_true", help=""" - Install the DAILA core to supported decompilers as plugins. This option will start an interactive - prompt asking for install paths for all supported decompilers. Each install path is optional and - will be skipped if not path is provided during install. - """ + "-i", "--install", action="store_true", help="Install DAILA into your decompilers" ) parser.add_argument( - "-server", action="store_true", help=""" - Starts the DAILA Server for use with Ghidra - """ + "-s", "--server", help="Run a headless server for DAILA", choices=["ghidra"] + ) + parser.add_argument( + "-v", "--version", action="version", version=f"DAILA {dailalib.__version__}" ) args = parser.parse_args() if args.install: DAILAInstaller().install() + elif args.server: + if args.server != "ghidra": + raise NotImplementedError("Only Ghidra is supported for now") - if args.server: - DAILAServer(use_py2_exceptions=True).start_xmlrpc_server() + from dailalib import create_plugin + create_plugin(force_decompiler="ghidra") if __name__ == "__main__": diff --git a/dailalib/api/__init__.py b/dailalib/api/__init__.py new file mode 100644 index 0000000..549db31 --- /dev/null +++ b/dailalib/api/__init__.py @@ -0,0 +1,2 @@ +from .ai_api import AIAPI +from .openai import OpenAIAPI, Prompt diff --git a/dailalib/api/ai_api.py b/dailalib/api/ai_api.py new file mode 100644 index 0000000..e1107b0 --- /dev/null +++ b/dailalib/api/ai_api.py @@ -0,0 +1,99 @@ +from typing import Dict, Optional
+from functools import wraps + +from libbs.api import DecompilerInterface + + +class AIAPI: + def __init__( + self, + decompiler_interface: Optional[DecompilerInterface] = None, + decompiler_name: Optional[str] = None, + use_decompiler: bool = True, + delay_init: bool = False, + # size in bytes + min_func_size: int = 0x10, + max_func_size: int = 0xffff, + model=None, + ): + # useful for initing after the creation of a decompiler interface + self._dec_interface: Optional[DecompilerInterface] = None + self._dec_name = None + if not delay_init: + self.init_decompiler_interface(decompiler_interface, decompiler_name, use_decompiler) + + self._min_func_size = min_func_size + self._max_func_size = max_func_size + self.model = model or self.__class__.__name__ + + def init_decompiler_interface( + self, + decompiler_interface: Optional[DecompilerInterface] = None, + decompiler_name: Optional[str] = None, + use_decompiler: bool = True + ): + self._dec_interface: DecompilerInterface = DecompilerInterface.discover_interface(force_decompiler=decompiler_name) \ + if use_decompiler and decompiler_interface is None else decompiler_interface + self._dec_name = decompiler_name if decompiler_interface is None else decompiler_interface.name + if self._dec_interface is None and not self._dec_name: + raise ValueError("You must either provide a decompiler name or a decompiler interface.") + + def info(self, msg): + if self._dec_interface is not None: + self._dec_interface.info(msg) + + def debug(self, msg): + if self._dec_interface is not None: + self._dec_interface.debug(msg) + + def warning(self, msg): + if self._dec_interface is not None: + self._dec_interface.warning(msg) + + def error(self, msg): + if self._dec_interface is not None: + self._dec_interface.error(msg) + + @property + def has_decompiler_gui(self): + return self._dec_interface is not None and not self._dec_interface.headless + + @staticmethod + def requires_function(f): + """ + A wrapper function to make sure an API call 
has decompilation text to operate on and possibly a Function + object. There are really two modes any API call operates in: + 1. Without Decompiler Backend: requires provided dec text + 2. With Decompiler Backend: + 2a. With UI: Function will be collected from the UI if not provided + 2b. Without UI: requires a FunctionA + + The Function collected from the UI is the one the use is currently looking at. + """ + @wraps(f) + def _requires_function(*args, ai_api: "AIAPI" = None, **kwargs): + function = kwargs.pop("function", None) + dec_text = kwargs.pop("dec_text", None) + use_dec = kwargs.pop("use_dec", True) + + if not dec_text and not use_dec: + raise ValueError("You must provide decompile text if you are not using a dec backend") + + # two mode constructions: with decompiler and without + # with decompiler backend + if use_dec: + if not ai_api.has_decompiler_gui and function is None: + raise ValueError("You must provide a Function when using this with a decompiler") + + # we must have a UI if we have no func + if function is None: + function = ai_api._dec_interface.active_context() + + # get new text with the function that is present + if dec_text is None: + dec_text = ai_api._dec_interface.decompile(function.addr) + + return f(*args, function=function, dec_text=dec_text, use_dec=use_dec, **kwargs) + + return _requires_function + diff --git a/dailalib/api/openai/__init__.py b/dailalib/api/openai/__init__.py new file mode 100644 index 0000000..9b2567d --- /dev/null +++ b/dailalib/api/openai/__init__.py @@ -0,0 +1,2 @@ +from .openai_api import OpenAIAPI +from .prompts import Prompt diff --git a/dailalib/api/openai/openai_api.py b/dailalib/api/openai/openai_api.py new file mode 100644 index 0000000..361070d --- /dev/null +++ b/dailalib/api/openai/openai_api.py @@ -0,0 +1,106 @@ +import typing +from typing import Optional +import os + +from openai import OpenAI +import tiktoken + +from dailalib.api.ai_api import AIAPI +if typing.TYPE_CHECKING: + from 
dailalib.api.openai.prompt import Prompt + + +class OpenAIAPI(AIAPI): + prompts_by_name = [] + DEFAULT_MODEL = "gpt-4" + MODEL_TO_TOKENS = { + "gpt-4": 8000, + "gpt-3.5-turbo": 4096 + } + + # replacement strings for API calls + def __init__(self, api_key: Optional[str] = None, model: str = DEFAULT_MODEL, prompts: Optional[list] = None, **kwargs): + super().__init__(**kwargs) + self._api_key = None + self._openai_client: OpenAI = None + self.api_key = api_key or os.getenv("OPENAI_API_KEY") + self.model = model + + # delay prompt import + from dailalib.api.openai.prompts import PROMPTS + prompts = prompts + PROMPTS if prompts else PROMPTS + self.prompts_by_name = {p.name: p for p in prompts} + + def __dir__(self): + return list(super().__dir__()) + list(self.prompts_by_name.keys()) + + def __getattribute__(self, item): + if item in object.__getattribute__(self, "prompts_by_name"): + prompt_obj: "Prompt" = self.prompts_by_name[item] + prompt_obj.ai_api = self + return prompt_obj.query_model + else: + return object.__getattribute__(self, item) + + @property + def api_key(self): + return self._api_key + + @api_key.setter + def api_key(self, value): + self._api_key = value + if self._api_key: + self._openai_client = OpenAI(api_key=self._api_key) + + def ask_api_key(self, *args, **kwargs): + self.api_key = self._dec_interface.gui_ask_for_string("Enter you OpenAI API Key:", title="DAILA") + + def query_model( + self, + prompt: str, + model: Optional[str] = None, + max_tokens=None, + ): + if not self._openai_client: + raise ValueError("You must provide an API key before querying the model.") + + response = self._openai_client.chat.completions.create( + model=model or self.model, + messages=[ + {"role": "user", "content": prompt} + ], + max_tokens=max_tokens, + timeout=60, + ) + + try: + answer = response.choices[0].message.content + except (KeyError, IndexError) as e: + answer = None + + return answer + + @staticmethod + def estimate_token_amount(content: str, 
model=DEFAULT_MODEL): + enc = tiktoken.encoding_for_model(model) + tokens = enc.encode(content) + return len(tokens) + + @staticmethod + def content_fits_tokens(content: str, model=DEFAULT_MODEL): + max_token_count = OpenAIAPI.MODEL_TO_TOKENS[model] + token_count = OpenAIAPI.estimate_token_amount(content, model=model) + return token_count <= max_token_count - 1000 + + @staticmethod + def fit_decompilation_to_token_max(decompilation: str, delta_step=10, model=DEFAULT_MODEL): + if OpenAIAPI.content_fits_tokens(decompilation, model=model): + return decompilation + + dec_lines = decompilation.split("\n") + last_idx = len(dec_lines) - 1 + # should be: [func_prototype] + [nop] + [mid] + [nop] + [end_of_code] + dec_lines = dec_lines[0:2] + ["// ..."] + dec_lines[delta_step:last_idx-delta_step] + ["// ..."] + dec_lines[-2:-1] + decompilation = "\n".join(dec_lines) + + return OpenAIAPI.fit_decompilation_to_token_max(decompilation, delta_step=delta_step, model=model) \ No newline at end of file diff --git a/dailalib/api/openai/prompt.py b/dailalib/api/openai/prompt.py new file mode 100644 index 0000000..b930e4d --- /dev/null +++ b/dailalib/api/openai/prompt.py @@ -0,0 +1,84 @@ +import json +import re +from typing import Optional, Union, Dict, Callable +import textwrap + +from dailalib.api import AIAPI +from dailalib.api.openai.openai_api import OpenAIAPI + +JSON_REGEX = re.compile(r"\{.*}", flags=re.DOTALL) + + +class Prompt: + DECOMP_REPLACEMENT_LABEL = "" + SNIPPET_REPLACEMENT_LABEL = "" + SNIPPET_TEXT = f"\n\"\"\"{SNIPPET_REPLACEMENT_LABEL}\"\"\"" + DECOMP_TEXT = f"\n\"\"\"{DECOMP_REPLACEMENT_LABEL}\"\"\"" + + def __init__( + self, + name: str, + text: str, + desc: str = None, + pretext_response: Optional[str] = None, + posttext_response: Optional[str] = None, + json_response: bool = False, + ai_api=None, + # callback(result, function, ai_api) + gui_result_callback: Optional[Callable] = None + ): + self.name = name + self.text = textwrap.dedent(text) + 
self._pretext_response = pretext_response + self._posttext_response = posttext_response + self._json_response = json_response + self._gui_result_callback = gui_result_callback + self.desc = desc or name + self.ai_api: "OpenAIAPI" = ai_api + + def query_model(self, *args, function=None, dec_text=None, use_dec=True, **kwargs): + if self.ai_api is None: + raise Exception("api must be set before querying!") + + @AIAPI.requires_function + def _query_model(ai_api=self.ai_api, function=function, dec_text=dec_text, use_dec=use_dec) -> Union[Dict, str]: + if not ai_api: + return {} + + ai_api.info(f"Querying {self.name} prompt with function {function}...") + response = self._pretext_response if self._pretext_response and not self._json_response else "" + # grab decompilation and replace it in the prompt, make sure to fix the decompilation for token max + query_text = self.text.replace( + self.DECOMP_REPLACEMENT_LABEL, + OpenAIAPI.fit_decompilation_to_token_max(dec_text) + ) + response += self.ai_api.query_model(query_text) + default_response = {} if self._json_response else "" + if not response: + return default_response + + # changes response type to a dict + if self._json_response: + # if the response of OpenAI gets cut off, we have an incomplete JSON + if "}" not in response: + response += "}" + + json_matches = JSON_REGEX.findall(response) + if not json_matches: + return default_response + + json_data = json_matches[0] + try: + response = json.loads(json_data) + except Exception: + response = {} + else: + response += self._posttext_response if self._pretext_response else "" + + if ai_api.has_decompiler_gui and response: + self._gui_result_callback(response, function, ai_api) + + ai_api.info(f"Response received...") + return response + return _query_model(ai_api=self.ai_api, function=function, dec_text=dec_text, use_dec=use_dec) + diff --git a/dailalib/api/openai/prompts.py b/dailalib/api/openai/prompts.py new file mode 100644 index 0000000..7845bc5 --- 
b/dailalib/api/openai/prompts.py @@ -0,0 +1,134 @@ +import typing + +from libbs.data import Comment + +from .prompt import Prompt + +if typing.TYPE_CHECKING: + from dailalib.api import AIAPI + + +def rename_function(result, function, ai_api: "AIAPI"): + new_name = list(result.values())[0] + function.name = new_name + ai_api._dec_interface.functions[function.addr] = function + + +def rename_variables(result, function, ai_api: "AIAPI"): + ai_api._dec_interface.rename_local_variables_by_names(function, result) + + +def comment_function(result, function, ai_api: "AIAPI"): + ai_api._dec_interface.comments[function.addr] = Comment(function.addr, comment=result, func_addr=function.addr) + + +PROMPTS = [ + Prompt( + "summarize", + f''' + You are C expert that summarizes code. When given code, you summarize at a high level what the function does + and you identify if known algorithms are used in the function. As an example: + """ + int sub_404000(int a0, char** a1) + {{ + int v1; // rax + v1 = sub_404100(a0[1]) % 2 == 0 + return v1; + }} + """ + + You responded with: + This function takes two arguments and implements the is_even check on second argument + + Here is another example: + {Prompt.DECOMP_TEXT} + + You responded with: + ''', + desc="Summarize the function", + gui_result_callback=comment_function + ), + Prompt( + "identify_source", + f''' + You are a C expert that identifies the original source given decompilation. Upon discovering the source, + you give a link to the code. You only respond with a json. As an example: + """ + void __fastcall __noreturn usage(int status) + {{ + v2 = program_name; + if ( status ) + {{ + v3 = dcgettext(0LL, "Try '%s --help' for more information.\n", 5); + _fprintf_chk(stderr, 1LL, v3, v2); + }} + // ... + }} + """ + + You responded with: + {{"link": "https://www.gnu.org/software/coreutils/"}} + + When you don't know, you respond with {{}}. 
+ + Here is another example: + {Prompt.DECOMP_TEXT} + + You responded with: + ''', + desc="Identify the original source", + json_response=True, + gui_result_callback=comment_function + ), + Prompt( + "rename_variables", + f''' + You are C expert that renames variables in code. When given code, you rename variables according to the + meaning of the function or its use. You only respond with a json. As an example: + + int sub_404000(int a0, char** a1) + {{ + int v1; // rax + v1 = sub_404100(a0[1]) % 2 == 0 + return v1; + }} + + You responded with: + {{"a0: argc", "a1: argv", "v1": "is_even"}} + + Here is another example: + {Prompt.DECOMP_TEXT} + + You responded with: + ''', + desc="Suggest variable names", + json_response=True, + gui_result_callback=rename_variables, + ), + Prompt( + "rename_function", + f''' + You are C expert that renames functions. When given a function, you rename it according to the + meaning of the function or its use. You only respond with a json. As an example: + + int sub_404000(int a0, char** a1) + {{ + int is_even; // rax + is_even = sub_404100(a0[1]) % 2 == 0 + return is_even; + }} + + You responded with: + {{"sub_404000": "number_is_even"}} + + Here is another example: + {Prompt.DECOMP_TEXT} + + You responded with: + ''', + desc="Suggest a function name", + json_response=True, + gui_result_callback=rename_function + ), +] +PROMPTS_BY_NAME = {p.name: p for p in PROMPTS} diff --git a/dailalib/binsync_plugin/ai_bs_user.py b/dailalib/binsync_plugin/ai_bs_user.py index 5caedca..e5435ea 100644 --- a/dailalib/binsync_plugin/ai_bs_user.py +++ b/dailalib/binsync_plugin/ai_bs_user.py @@ -16,7 +16,7 @@ QDialog, QMessageBox ) -from dailalib.interfaces import OpenAIInterface +from dailalib.api import OpenAIAPI from tqdm import tqdm _l = logging.getLogger(__name__) @@ -148,7 +148,7 @@ def _collect_decompiled_functions(self) -> Dict: callback_stub(update_amt_per_func) continue - decompiled_functions[func.addr] = 
(OpenAIInterface.fit_decompilation_to_token_max(decompilation), func) + decompiled_functions[func.addr] = (OpenAIAPI.fit_decompilation_to_token_max(decompilation), func) callback_stub(update_amt_per_func) return decompiled_functions diff --git a/dailalib/binsync_plugin/openai_bs_user.py b/dailalib/binsync_plugin/openai_bs_user.py index dff1452..39ad553 100644 --- a/dailalib/binsync_plugin/openai_bs_user.py +++ b/dailalib/binsync_plugin/openai_bs_user.py @@ -3,7 +3,7 @@ from binsync.data import Function, StackVariable, Comment, State, FunctionHeader -from dailalib.interfaces import OpenAIInterface +from dailalib.api import OpenAIAPI from dailalib.binsync_plugin.ai_bs_user import AIBSUser _l = logging.getLogger(__name__) @@ -14,7 +14,7 @@ class OpenAIBSUser(AIBSUser): def __init__(self, openai_api_key, *args, **kwargs): super().__init__(openai_api_key, *args, **kwargs) - self.ai_interface = OpenAIInterface(openai_api_key=openai_api_key, decompiler_controller=self.controller, model=self._model) + self.ai_interface = OpenAIAPI(openai_api_key=openai_api_key, decompiler_controller=self.controller, model=self._model) def run_all_ai_commands_for_dec(self, decompilation: str, func: Function, state: State): changes = 0 diff --git a/dailalib/controller_server.py b/dailalib/controller_server.py deleted file mode 100644 index ab33525..0000000 --- a/dailalib/controller_server.py +++ /dev/null @@ -1,122 +0,0 @@ -from dailalib.interfaces.openai_interface import OpenAIInterface -from xmlrpc.server import SimpleXMLRPCServer, SimpleXMLRPCRequestHandler -from functools import wraps - -class RequestHandler(SimpleXMLRPCRequestHandler): - rpc_paths = ("/RPC2",) - - -def proxy_and_catch(func): - @wraps(func) - def _proxy_and_catch(self, *args, **kwargs): - if not args or not args[0]: - return "" - - try: - output = func(self, *args, **kwargs) - except Exception as e: - if self.use_py2_exceptions: - raise BaseException(*e.args) - else: - raise e - - return output - return _proxy_and_catch 
- - -class DAILAServer: - def __init__(self, host=None, port=None, use_py2_exceptions=False): - self.host = host - self.port = port - self.running = False - self.controller = OpenAIInterface() - - self.use_py2_exceptions = use_py2_exceptions - - # - # Public API - # - @proxy_and_catch - def identify_function(self, decompilation: str): - if not decompilation: - return "" - - result = self.controller.find_source_of_function(decompilation=decompilation) - if not result: - return "" - - return result - - @proxy_and_catch - def explain_function(self, decompilation: str): - if not decompilation: - return "" - - result = self.controller.summarize_function(decompilation=decompilation) - if not result: - return "" - - return result - - @proxy_and_catch - def find_vulns_function(self, decompilation: str): - if not decompilation: - return "" - - result = self.controller.find_vulnerability_in_function(decompilation=decompilation) - if not result: - return "" - - return result - - @proxy_and_catch - def rename_func_and_vars_function(self, decompilation: str): - if not decompilation: - return "" - - result = self.controller.rename_variables_in_function(decompilation=decompilation) - if not result: - return "" - - return str(result) - - def set_api_key(self, api_key: str): - if api_key: - self.controller.api_key = api_key - - # - # XMLRPC Server - # - - def ping(self): - return True - - def shutdown(self): - self.running = False - - def start_xmlrpc_server(self, host="localhost", port=44414): - """ - Initialize the XMLRPC thread. 
- """ - host = host or self.host - port = port or self.port - - print("[+] Starting XMLRPC server: {}:{}".format(host, port)) - server = SimpleXMLRPCServer( - (host, port), - requestHandler=RequestHandler, - logRequests=False, - allow_none=True - ) - server.register_introspection_functions() - server.register_function(self.identify_function) - server.register_function(self.explain_function) - server.register_function(self.find_vulns_function) - server.register_function(self.rename_func_and_vars_function) - server.register_function(self.set_api_key) - server.register_function(self.ping) - server.register_function(self.shutdown) - print("[+] Registered decompilation server!") - self.running = True - while self.running: - server.handle_request() diff --git a/dailalib/daila_plugin.py b/dailalib/daila_plugin.py new file mode 100644 index 0000000..5b737ac --- /dev/null +++ b/dailalib/daila_plugin.py @@ -0,0 +1,47 @@ +# Run the DAILA server and open the DAILA selector UI. +# @author mahaloz +# @category AI +# @menupath Tools.DAILA.Start DAILA Backend + + +# replace this with the command to run your plugin remotely +library_command = "dailalib -s ghidra" +# replace this imporate with an import of your plugin's create_plugin function +def create_plugin(*args, **kwargs): + from dailalib import create_plugin as _create_plugin + return _create_plugin(*args, **kwargs) + +# ============================================================================= +# LibBS generic plugin loader (don't touch) +# ============================================================================= + +import sys +# Python 2 has special requirements for Ghidra, which forces us to use a different entry point +# and scope for defining plugin entry points +if sys.version[0] == "2": + # Do Ghidra Py2 entry point + import subprocess + from libbs_vendored.ghidra_bridge_server import GhidraBridgeServer + full_command = "python3 -m " + library_command + + GhidraBridgeServer.run_server(background=True) + process 
= subprocess.Popen(full_command.split(" ")) + if process.poll() is not None: + raise RuntimeError( + "Failed to run the Python3 backed. It's likely Python3 is not in your Path inside Ghidra.") +else: + # Try plugin discovery for other decompilers + try: + import idaapi + has_ida = True + except ImportError: + has_ida = False + + if not has_ida: + create_plugin() + +def PLUGIN_ENTRY(*args, **kwargs): + """ + This is the entry point for IDA to load the plugin. + """ + return create_plugin(*args, **kwargs) diff --git a/dailalib/installer.py b/dailalib/installer.py index e45710b..2abd887 100644 --- a/dailalib/installer.py +++ b/dailalib/installer.py @@ -1,19 +1,23 @@ import textwrap from pathlib import Path +import importlib.resources -import pkg_resources -from binsync.installer import Installer +from libbs.plugin_installer import PluginInstaller -class DAILAInstaller(Installer): +class DAILAInstaller(PluginInstaller): def __init__(self): super().__init__(targets=("ida", "ghidra", "binja")) - self.plugins_path = Path( - pkg_resources.resource_filename("dailalib", "plugins") - ) + self.pkg_path = Path(str(importlib.resources.files("dailalib"))).absolute() + + def _copy_plugin_to_path(self, path): + src = self.pkg_path / "daila_plugin.py" + dst = Path(path) / "daila_plugin.py" + self.link_or_copy(src, dst, symlink=True) def display_prologue(self): print(textwrap.dedent(""" + Now installing... ▄▄▄▄▄▄▄▄▄▄ ▄▄▄▄▄▄▄▄▄▄▄ ▄▄▄▄▄▄▄▄▄▄▄ ▄ ▄▄▄▄▄▄▄▄▄▄▄ ▐░░░░░░░░░░▌ ▐░░░░░░░░░░░▌▐░░░░░░░░░░░▌▐░▌ ▐░░░░░░░░░░░▌ ▐░█▀▀▀▀▀▀▀█░▌▐░█▀▀▀▀▀▀▀█░▌ ▀▀▀▀█░█▀▀▀▀ ▐░▌ ▐░█▀▀▀▀▀▀▀█░▌ @@ -25,46 +29,38 @@ def display_prologue(self): ▐░█▄▄▄▄▄▄▄█░▌▐░▌ ▐░▌ ▄▄▄▄█░█▄▄▄▄ ▐░█▄▄▄▄▄▄▄▄▄ ▐░▌ ▐░▌ ▐░░░░░░░░░░▌ ▐░▌ ▐░▌▐░░░░░░░░░░░▌▐░░░░░░░░░░░▌▐░▌ ▐░▌ ▀▀▀▀▀▀▀▀▀▀ ▀ ▀ ▀▀▀▀▀▀▀▀▀▀▀ ▀▀▀▀▀▀▀▀▀▀▀ ▀ ▀ - - Now installing DAILA... - Please input your decompiler install paths as prompted. Enter nothing to either use - the default install path if one exist, or to skip. 
+ + The Decompiler AI Language Assistant """)) - def install_ida(self, path=None): - ida_plugin_path = super().install_ida(path=path) - if ida_plugin_path is None: + def install_ida(self, path=None, interactive=True): + path = path or super().install_ida(path=path, interactive=interactive) + if not path: return - src_ida_py = self.plugins_path.joinpath("daila_ida.py").absolute() - dst_ida_py = ida_plugin_path.joinpath("daila_ida.py").absolute() - self.link_or_copy(src_ida_py, dst_ida_py) - return dst_ida_py + self._copy_plugin_to_path(path) + return path - def install_ghidra(self, path=None): - ghidra_path = self.ask_path("Ghidra Scripts Path:") if path is None else path - if not ghidra_path: - return None + def install_ghidra(self, path=None, interactive=True): + path = path or super().install_ghidra(path=path, interactive=interactive) + if not path: + return - ghidra_path: Path = ghidra_path.expanduser().absolute() - if not ghidra_path.exists(): - return None + self._copy_plugin_to_path(path) + return path - src_ghidra_py = self.plugins_path.joinpath("daila_ghidra.py").absolute() - dst_path = ghidra_path.expanduser().absolute() - if not dst_path.exists(): - return None + def install_binja(self, path=None, interactive=True): + path = path or super().install_binja(path=path, interactive=interactive) + if not path: + return - dst_ghidra_py = dst_path.joinpath(src_ghidra_py.name) - self.link_or_copy(src_ghidra_py, dst_ghidra_py) - return dst_ghidra_py + self._copy_plugin_to_path(path) + return path - def install_binja(self, path=None): - binja_plugin_path = super().install_binja(path=path) - if binja_plugin_path is None: - return None + def install_angr(self, path=None, interactive=True): + path = path or super().install_angr(path=path, interactive=interactive) + if not path: + return - src_path = self.plugins_path.joinpath("daila_binja.py") - dst_path = binja_plugin_path.joinpath("daila_binja.py") - self.link_or_copy(src_path, dst_path) - return dst_path + 
self._copy_plugin_to_path(path) + return path diff --git a/dailalib/interfaces/__init__.py b/dailalib/interfaces/__init__.py deleted file mode 100644 index 697c433..0000000 --- a/dailalib/interfaces/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .generic_ai_interface import GenericAIInterface -from .openai_interface import OpenAIInterface - -__all__ = [ - "GenericAIInterface", - "OpenAIInterface", -] diff --git a/dailalib/interfaces/generic_ai_interface.py b/dailalib/interfaces/generic_ai_interface.py deleted file mode 100644 index 2debc94..0000000 --- a/dailalib/interfaces/generic_ai_interface.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Optional - -from binsync.api import load_decompiler_controller, BSController - - -class GenericAIInterface: - def __init__(self, decompiler_controller: Optional[BSController] = None): - self.decompiler_controller = decompiler_controller or load_decompiler_controller() - - # - # decompiler interface - # - - def _cmt_func(self, func_addr: int, comment: str, **kwargs): - return False - - def _decompile(self, func_addr: int, **kwargs): - return self.decompiler_controller.decompile(func_addr) - - def _current_function_addr(self, **kwargs): - return None - - def _register_menu_item(self, name, action_string, callback_func): - return False - - def ask_api_key(self): - pass diff --git a/dailalib/interfaces/openai_interface.py b/dailalib/interfaces/openai_interface.py deleted file mode 100644 index 5575a57..0000000 --- a/dailalib/interfaces/openai_interface.py +++ /dev/null @@ -1,366 +0,0 @@ -import re -from typing import Optional, Dict -import os -import json -from functools import wraps - -from openai import OpenAI -import tiktoken - -from .generic_ai_interface import GenericAIInterface -from ..utils import HYPERLINK_REGEX - - -def addr_ctx_when_none(f): - @wraps(f) - def _addr_ctx_when_none(self: "OpenAIInterface", *args, **kwargs): - func_addr = kwargs.get("func_addr", None) - if func_addr is None: - func_addr = 
self._current_function_addr() - kwargs.update({"func_addr": func_addr, "edit_dec": True}) - return f(self, *args, **kwargs) - return _addr_ctx_when_none - -JSON_REGEX = re.compile(r"\{.*\}", flags=re.DOTALL) -QUESTION_START = "Q>" -QUESTION_REGEX = re.compile(rf"({QUESTION_START})([^?]*\?)") -ANSWER_START = "A>" - -DEFAULT_MODEL = "gpt-4" -MODEL_TO_TOKENS = { - "gpt-4": 8000, - "gpt-3.5-turbo": 4096 -} - -class OpenAIInterface(GenericAIInterface): - # API Command Constants - SUMMARIZE_CMD = "daila_summarize" - RENAME_FUNCS_CMD = "daial_rename_funcs" - RENAME_VARS_CMD = "daila_rename_vars" - RETYPE_VARS_CMD = "daila_retype_vars" - FIND_VULN_CMD = "daila_find_vuln" - ID_SOURCE_CMD = "daila_id_source" - ANSWER_QUESTION_CMD = "daila_answer_question" - AI_COMMANDS = { - RENAME_VARS_CMD: {"json_response": True}, - RETYPE_VARS_CMD: {"json_response": True}, - SUMMARIZE_CMD: {}, - RENAME_FUNCS_CMD: {"json_response": True}, - ID_SOURCE_CMD: {"increase_new_text": False}, - FIND_VULN_CMD: {}, - ANSWER_QUESTION_CMD: {"extra_handler": "answer_questions_in_decompilation"}, - } - - # replacement strings for API calls - DECOMP_REPLACEMENT_LABEL = "" - SNIPPET_REPLACEMENT_LABEL = "" - SNIPPET_TEXT = f"\n\"\"\"{SNIPPET_REPLACEMENT_LABEL}\"\"\"" - DECOMP_TEXT = f"\n\"\"\"{DECOMP_REPLACEMENT_LABEL}\"\"\"" - PROMPTS = { - RENAME_VARS_CMD: 'Analyze what the following function does. Suggest better variable names. ' - 'Reply with only a JSON array where keys are the original names and values are ' - f'the proposed names. Here is an example response: {{"v1": "buff"}} {DECOMP_TEXT}', - RETYPE_VARS_CMD: "Analyze what the following function does. Suggest better C types for the variables. 
" - "Reply with only a JSON where keys are the original names and values are the " - f"proposed types: {DECOMP_TEXT}", - SUMMARIZE_CMD: f"Please summarize the following code:{DECOMP_TEXT}", - FIND_VULN_CMD: "Can you find the vulnerability in the following function and suggest the " - f"possible way to exploit it?{DECOMP_TEXT}", - ID_SOURCE_CMD: "What open source project is this code from. Please only give me the program name and " - f"package name:{DECOMP_TEXT}", - RENAME_FUNCS_CMD: "The following code is C/C++. Rename the function according to its purpose using underscore_case. Reply with only a JSON array where keys are the " - f"original names and values are the proposed names:{DECOMP_TEXT}", - ANSWER_QUESTION_CMD: "You are a code comprehension assistant. You answer questions based on code that is " - f"provided. Here is some code: {DECOMP_TEXT}. Focus on this snippet of the code: " - f"{SNIPPET_TEXT}\n\n Answer the following question as concisely as possible, guesses " - f"are ok: " - } - - def __init__(self, openai_api_key=None, model=DEFAULT_MODEL, decompiler_controller=None): - super().__init__(decompiler_controller=decompiler_controller) - self.model = model - - self.menu_commands = { - "daila_set_key": ("Update OpenAPI Key...", self.ask_api_key), - f"{self.ID_SOURCE_CMD}": ("Identify the source code", self.find_source_of_function), - f"{self.SUMMARIZE_CMD}": ("Summarize function", self.summarize_function), - f"{self.FIND_VULN_CMD}": ("Find vulnerabilities", self.find_vulnerability_in_function), - f"{self.RENAME_FUNCS_CMD}": ("Rename functions used in function", self.rename_functions_in_function), - f"{self.RENAME_VARS_CMD}": ("Rename variables in function", self.rename_variables_in_function), - f"{self.RETYPE_VARS_CMD}": ("Retype variables in function", self.retype_variables_in_function), - f"{self.ANSWER_QUESTION_CMD}": ("Answer questions in function", self.answer_questions), - } - - for menu_str, callback_info in self.menu_commands.items(): - 
callback_str, callback_func = callback_info - self._register_menu_item(menu_str, callback_str, callback_func) - - self._api_key = os.getenv("OPENAI_API_KEY") or openai_api_key - self._openai_client = OpenAI(api_key=self._api_key) - - - @property - def api_key(self): - return self._api_key - - @api_key.setter - def api_key(self, data): - self._api_key = data - - - # - # OpenAI Interface - # - - def _query_openai_model( - self, - question: str, - model: Optional[str] = None, - temperature=0.1, - max_tokens=None, - frequency_penalty=0, - presence_penalty=0 - ): - # TODO: at some point add back frequency_penalty and presence_penalty to be used - try: - response = self._openai_client.chat.completions.create(model=model or self.model, - messages=[ - {"role": "user", "content": question} - ], - max_tokens=max_tokens, - timeout=60, - stop=['}']) - except openai.OpenAIError as e: - raise Exception(f"ChatGPT could not complete the request: {str(e)}") - - answer = None - try: - answer = response.choices[0].message.content - except (KeyError, IndexError) as e: - pass - - return answer - - def _query_openai(self, prompt: str, json_response=False, increase_new_text=True, **kwargs): - freq_penalty = 0 - pres_penalty = 0 - default_response = {} if json_response else None - - if increase_new_text: - freq_penalty = 1 - pres_penalty = 1 - - resp = self._query_openai_model(prompt, frequency_penalty=freq_penalty, presence_penalty=pres_penalty) - if resp is None: - return default_response - - if json_response: - # if the response of OpenAI gets cut off, we have an incomplete JSON - if "}" not in resp: - resp += "}" - - json_matches = JSON_REGEX.findall(resp) - if not json_matches: - return default_response - - json_data = json_matches[0] - try: - data = json.loads(json_data) - except Exception: - data = {} - - return data - - return resp - - def query_openai_for_function( - self, func_addr: int, prompt: str, decompile=True, json_response=False, increase_new_text=True, - **kwargs - ): - 
default_response = {} if json_response else None - if decompile: - decompilation = self._decompile(func_addr, **kwargs) - if not decompilation: - return default_response - - prompt = prompt.replace(self.DECOMP_REPLACEMENT_LABEL, decompilation) - else: - decompilation = kwargs.get("decompilation", None) - - if "extra_handler" in kwargs: - extra_handler = getattr(self, kwargs.pop("extra_handler")) - kwargs["decompilation"] = decompilation - return extra_handler(func_addr, prompt, **kwargs) - else: - return self._query_openai(prompt, json_response=json_response, increase_new_text=increase_new_text) - - def query_for_cmd(self, cmd, func_addr=None, decompilation=None, edit_dec=False, **kwargs): - if cmd not in self.AI_COMMANDS: - raise ValueError(f"Command {cmd} is not supported") - - kwargs.update(self.AI_COMMANDS[cmd]) - if func_addr is None and decompilation is None: - raise Exception("You must provide either a function address or decompilation!") - - prompt = self.PROMPTS[cmd] - if decompilation is not None: - prompt = prompt.replace(self.DECOMP_REPLACEMENT_LABEL, decompilation) - - return self.query_openai_for_function(func_addr, prompt, decompile=decompilation is None, decompilation=decompilation, **kwargs) - - # - # Extra Handlers - # - - def answer_questions_in_decompilation(self, func_addr, prompt: str, context_window=25, **kwargs) -> Dict[str, str]: - questions = [m for m in QUESTION_REGEX.finditer(prompt)] - if not questions: - return {} - - decompilation = kwargs.get("decompilation", None) - if decompilation is None: - print("Decompilation is required for answering questions!") - return {} - - dec_lines = decompilation.split("\n") - snippets = {} - for question in questions: - full_str = question.group(0) - for i, line in enumerate(dec_lines): - if full_str in line: - break - else: - snippets[full_str] = None - continue - - context_window_lines = dec_lines[i+1:i+1+context_window] - snippets[full_str] = "\n".join(context_window_lines) - - answers = {} - for 
question in questions: - full_str = question.group(0) - question_str = question.group(2) - snippet = snippets.get(full_str, None) - - if snippet is None: - continue - - if full_str.endswith("X?"): - normalized_prompt = question_str.replace("X?", "?") - else: - normalized_prompt = prompt.replace(self.SNIPPET_REPLACEMENT_LABEL, snippet) + question_str - - answer = self.query_openai_for_function( - func_addr, normalized_prompt, decompile=False, **kwargs - ) - if answer is not None: - answers[full_str] = f"A> {answer}" - - return answers - - # - # API Alias Wrappers - # - - @addr_ctx_when_none - def answer_questions(self, *args, func_addr=None, decompilation=None, edit_dec=False, **kwargs) -> Dict[str, str]: - resp = self.query_for_cmd(self.ANSWER_QUESTION_CMD, func_addr=func_addr, decompilation=decompilation, **kwargs) - if resp and edit_dec: - out = "" - for q, a in resp.items(): - out += f"{q}\n{a}\n\n" - - self._cmt_func(func_addr, out) - - return resp - - @addr_ctx_when_none - def summarize_function(self, *args, func_addr=None, decompilation=None, edit_dec=False, **kwargs) -> str: - resp = self.query_for_cmd(self.SUMMARIZE_CMD, func_addr=func_addr, decompilation=decompilation, **kwargs) - if resp: - resp = f"\nGuessed Summarization:\n{resp}" - - if edit_dec and resp: - self._cmt_func(func_addr, resp) - - return resp - - @addr_ctx_when_none - def find_vulnerability_in_function(self, *args, func_addr=None, decompilation=None, edit_dec=False, **kwargs) -> str: - resp = self.query_for_cmd(self.FIND_VULN_CMD, func_addr=func_addr, decompilation=decompilation, **kwargs) - if resp: - resp = f"\nGuessed Vuln:\n{resp}" - - if edit_dec and resp: - self._cmt_func(func_addr, resp) - - return resp - - @addr_ctx_when_none - def find_source_of_function(self, *args, func_addr=None, decompilation=None, edit_dec=False, **kwargs) -> str: - resp = self.query_for_cmd(self.ID_SOURCE_CMD, func_addr=func_addr, decompilation=decompilation, **kwargs) - links = re.findall(HYPERLINK_REGEX, 
resp) - if not links: - return "" - - resp = f"\nGuessed Source:\n{links}" - if edit_dec: - self._cmt_func(func_addr, resp) - - return resp - - @addr_ctx_when_none - def rename_functions_in_function(self, *args, func_addr=None, decompilation=None, edit_dec=False, **kwargs) -> dict: - resp: Dict = self.query_for_cmd(self.RENAME_FUNCS_CMD, func_addr=func_addr, decompilation=decompilation, **kwargs) - if edit_dec and resp: - # TODO: reimplement this code with self.decompiler_controller.set_function(func) - pass - - return resp - - @addr_ctx_when_none - def rename_variables_in_function(self, *args, func_addr=None, decompilation=None, edit_dec=False, **kwargs) -> dict: - resp: Dict = self.query_for_cmd(self.RENAME_VARS_CMD, func_addr=func_addr, decompilation=decompilation, **kwargs) - if edit_dec and resp: - # TODO: reimplement this code with self.decompiler_controller.set_function(func) - pass - - return resp - - @addr_ctx_when_none - def retype_variables_in_function(self, *args, func_addr=None, decompilation=None, edit_dec=False, **kwargs) -> dict: - resp: Dict = self.query_for_cmd(self.RETYPE_VARS_CMD, func_addr=func_addr, decompilation=decompilation, **kwargs) - if edit_dec and resp: - # TODO: reimplement this code with self.decompiler_controller.set_function(func) - pass - - return resp - - # - # helpers - # - - @staticmethod - def estimate_token_amount(content: str, model=DEFAULT_MODEL): - enc = tiktoken.encoding_for_model(model) - tokens = enc.encode(content) - return len(tokens) - - @staticmethod - def content_fits_tokens(content: str, model=DEFAULT_MODEL): - max_token_count = MODEL_TO_TOKENS[model] - token_count = OpenAIInterface.estimate_token_amount(content, model=model) - return token_count <= max_token_count - 1000 - - @staticmethod - def fit_decompilation_to_token_max(decompilation: str, delta_step=10, model=DEFAULT_MODEL): - if OpenAIInterface.content_fits_tokens(decompilation, model=model): - return decompilation - - dec_lines = decompilation.split("\n") 
- last_idx = len(dec_lines) - 1 - # should be: [func_prototype] + [nop] + [mid] + [nop] + [end_of_code] - dec_lines = dec_lines[0:2] + ["// ..."] + dec_lines[delta_step:last_idx-delta_step] + ["// ..."] + dec_lines[-2:-1] - decompilation = "\n".join(dec_lines) - - return OpenAIInterface.fit_decompilation_to_token_max(decompilation, delta_step=delta_step, model=model) diff --git a/dailalib/utils.py b/dailalib/utils.py deleted file mode 100644 index 3b4c601..0000000 --- a/dailalib/utils.py +++ /dev/null @@ -1,2 +0,0 @@ - -HYPERLINK_REGEX = r"""(?i)\b((?:https?:(?:/{1,3}|[a-z0-9%])|[a-z0-9.\-]+[.](?:com|net|org|edu|gov|mil|aero|asia|biz|cat|coop|info|int|jobs|mobi|museum|name|post|pro|tel|travel|xxx|ac|ad|ae|af|ag|ai|al|am|an|ao|aq|ar|as|at|au|aw|ax|az|ba|bb|bd|be|bf|bg|bh|bi|bj|bm|bn|bo|br|bs|bt|bv|bw|by|bz|ca|cc|cd|cf|cg|ch|ci|ck|cl|cm|cn|co|cr|cs|cu|cv|cx|cy|cz|dd|de|dj|dk|dm|do|dz|ec|ee|eg|eh|er|es|et|eu|fi|fj|fk|fm|fo|fr|ga|gb|gd|ge|gf|gg|gh|gi|gl|gm|gn|gp|gq|gr|gs|gt|gu|gw|gy|hk|hm|hn|hr|ht|hu|id|ie|il|im|in|io|iq|ir|is|it|je|jm|jo|jp|ke|kg|kh|ki|km|kn|kp|kr|kw|ky|kz|la|lb|lc|li|lk|lr|ls|lt|lu|lv|ly|ma|mc|md|me|mg|mh|mk|ml|mm|mn|mo|mp|mq|mr|ms|mt|mu|mv|mw|mx|my|mz|na|nc|ne|nf|ng|ni|nl|no|np|nr|nu|nz|om|pa|pe|pf|pg|ph|pk|pl|pm|pn|pr|ps|pt|pw|py|qa|re|ro|rs|ru|rw|sa|sb|sc|sd|se|sg|sh|si|sj|Ja|sk|sl|sm|sn|so|sr|ss|st|su|sv|sx|sy|sz|tc|td|tf|tg|th|tj|tk|tl|tm|tn|to|tp|tr|tt|tv|tw|tz|ua|ug|uk|us|uy|uz|va|vc|ve|vg|vi|vn|vu|wf|ws|ye|yt|yu|za|zm|zw)/)(?:[^\s()<>{}\[\]]+|\([^\s()]*?\([^\s()]+\)[^\s()]*?\)|\([^\s]+?\))+(?:\([^\s()]*?\([^\s()]+\)[^\s()]*?\)|\([^\s]+?\)|[^\s`!()\[\]{};:'".,<>?«»“”‘’])|(?:(? 
str: - bv, address = args[0:2] - self.bv = bv - return super().find_source_of_function(func_addr=address, bv=bv, edit_dec=True) - - @with_loading_popup - @addr_ctx_when_none - def summarize_function(self, *args, **kwargs) -> str: - bv, address = args[0:2] - self.bv = bv - return super().summarize_function(func_addr=address, bv=bv, edit_dec=True) - - @with_loading_popup - @addr_ctx_when_none - def find_vulnerability_in_function(self, *args, **kwargs) -> str: - bv, address = args[0:2] - self.bv = bv - return super().find_vulnerability_in_function(func_addr=address, bv=bv, edit_dec=True) - - @with_loading_popup - @addr_ctx_when_none - def answer_questions(self, *args, func_addr=None, decompilation=None, edit_dec=False, **kwargs): - bv, address = args[0:2] - self.bv = bv - return super().answer_questions(func_addr=address, bv=bv, edit_dec=True) - - @with_loading_popup - @addr_ctx_when_none - def rename_functions_in_function(self, *args, func_addr=None, decompilation=None, edit_dec=False, **kwargs): - bv, address = args[0:2] - self.bv = bv - return super().rename_functions_in_function(func_addr=address, bv=bv, edit_dec=True) - - @with_loading_popup - @addr_ctx_when_none - def rename_variables_in_function(self, *args, func_addr=None, decompilation=None, edit_dec=False, **kwargs): - bv, address = args[0:2] - self.bv = bv - return super().rename_variables_in_function(func_addr=address, bv=bv, edit_dec=True) - - - - - -class DAILAPlugin: - def __init__(self): - self.controller = BinjaOpenAIInterface(plugin=self) - - @staticmethod - def get_func(bv, address): - funcs = bv.get_functions_containing(address) - try: - func = funcs[0] - except IndexError: - return None - - return func - - @staticmethod - def is_func(bv, address): - func = DAILAPlugin.get_func(bv, address) - return func is not None - - -plugin = DAILAPlugin() diff --git a/plugins/daila_ghidra.py b/plugins/daila_ghidra.py deleted file mode 100644 index 137dc0c..0000000 --- a/plugins/daila_ghidra.py +++ /dev/null @@ 
-1,184 +0,0 @@ -# Activates DAILA where the user will be able to select from a series of possible commands -#@category AI -#@menupath Tools.DAILA: Run a DAILA operation... -#@keybinding ctrl alt shift d -import os - -import xmlrpclib -from time import sleep -import subprocess - -from ghidra.app.decompiler.flatapi import FlatDecompilerAPI -from ghidra.program.flatapi import FlatProgramAPI - -fpapi = FlatProgramAPI(getState().getCurrentProgram()) -fdapi = FlatDecompilerAPI(fpapi) - - -# -# Ghidra decompiler utils -# - -def current_func(): - func = fpapi.getFunctionContaining(currentAddress) - return func if func else None - - -def decompile_curr_func(): - func = current_func() - if func is None: - return None - - decomp = fdapi.decompile(func) - if not decomp: - return None - - return str(decomp) - - -def set_func_comment(comment): - func = current_func() - if func is None: - return False - - func.setComment(comment) - return True - - -class ServerCtx: - def __init__(self, host="http://localhost", port=44414): - self.proxy_host = "%s:%d" % (host, port) - - self.server_proc = None - self.server = None - - def __enter__(self, *args, **kwargs): - self._start_server() - key = os.getenv("OPENAI_API_KEY") - if key: - self.server.set_api_key(key) - - return self.server - - def __exit__(self, *args, **kwargs): - self._stop_server() - - def _start_server(self): - self.server_proc = subprocess.Popen(["python3", "-m", "dailalib", "-server"]) - sleep(3) - self.server = xmlrpclib.ServerProxy(self.proxy_host) - - try: - self.server.ping() - except BaseException as e: - print("[!] 
Encountered an error when creating the py3 server", e) - if self.server_proc is not None: - self.server_proc.kill() - - self.server_proc = None - self.server = None - - def _stop_server(self): - if self.server_proc is None or self.server is None: - return - - self.server.shutdown() - self.server_proc.kill() - self.server = None - self.server_proc = None - -# -# DAILA API -# - - -def identify_function(): - with ServerCtx() as server: - decomp = decompile_curr_func() - if decomp is None: - return None - - resp = server.identify_function(decomp) - if not resp: - return None - - set_func_comment(resp) - - -def explain_function(): - with ServerCtx() as server: - decomp = decompile_curr_func() - if decomp is None: - return None - - resp = server.explain_function(decomp) - if not resp: - return None - - set_func_comment(resp) - -def find_function_vuln(): - with ServerCtx() as server: - decomp = decompile_curr_func() - if decomp is None: - return None - - resp = server.find_vulns_function(decomp) - if not resp: - return None - - set_func_comment(resp) - -def rename_func_and_vars(): - with ServerCtx() as server: - decomp = decompile_curr_func() - if decomp is None: - return None - - resp = server.rename_func_and_vars_function(decomp) - if not resp: - return None - - set_func_comment(resp) - - -def set_api_key(): - old_key = os.getenv("OPENAI_API_KEY") - key = askString("DAILA", "Enter your OpenAI API Key: ", old_key) - if not key: - return - - with ServerCtx() as server: - server.set_api_key("") - - os.putenv("OPENAI_API_KEY", key) - -# -# Operation Selector -# - - -DAILA_OPS = { - "Identify function source": identify_function, - "Explain function": explain_function, - "Find function vulns": find_function_vuln, - "Rename function name and variables": rename_func_and_vars, - "Set OpenAPI Key...": set_api_key, -} - - -def select_and_run_daila_op(): - daila_op_keys = list(DAILA_OPS.keys()) - choice = askChoice("DAILA Operation Selector", "Please choose a DAILA operation", 
daila_op_keys, daila_op_keys[0]) - if not choice: - print("[+] Cancelled...") - return - - op_func = DAILA_OPS[choice] - op_func() - - -if __name__ == "__main__": - print("[+] Starting DAILA Operations selector for " + str(current_func())) - select_and_run_daila_op() - diff --git a/plugins/daila_ida.py b/plugins/daila_ida.py deleted file mode 100644 index a602586..0000000 --- a/plugins/daila_ida.py +++ /dev/null @@ -1,229 +0,0 @@ -import threading -import functools -from typing import Optional, Dict - -import idaapi -import ida_hexrays -import idc -from PyQt5.QtWidgets import QProgressDialog - -from dailalib.interfaces.openai_interface import OpenAIInterface - -controller: Optional["IDADAILAController"] = None - - -# -# IDA Threading -# - - -def is_mainthread(): - """ - Return a bool that indicates if this is the main application thread. - """ - return isinstance(threading.current_thread(), threading._MainThread) - - -def execute_sync(func, sync_type): - """ - Synchronize with the disassembler for safe database access. - Modified from https://github.com/vrtadmin/FIRST-plugin-ida - """ - - @functools.wraps(func) - def wrapper(*args, **kwargs) -> object: - output = [None] - - # - # this inline function definition is technically what will execute - # in the context of the main thread. we use this thunk to capture - # any output the function may want to return to the user. 
- # - - def thunk(): - output[0] = func(*args, **kwargs) - return 1 - - if is_mainthread(): - thunk() - else: - idaapi.execute_sync(thunk, sync_type) - - # return the output of the synchronized execution - return output[0] - return wrapper - - -def execute_read(func): - return execute_sync(func, idaapi.MFF_READ) - - -def execute_write(func): - return execute_sync(func, idaapi.MFF_WRITE) - - -def execute_ui(func): - return execute_sync(func, idaapi.MFF_FAST) - -# -# IDA PLUGIN -# - - -def PLUGIN_ENTRY(*args, **kwargs): - return DAILAPlugin(*args, **kwargs) - - -class DAILAPlugin(idaapi.plugin_t): - flags = idaapi.PLUGIN_FIX - comment = "Start the DAILA Interface" - help = "DAILA Help" - wanted_name = "Identify the current function you are looking at!" - - id_action_name = "dailalib:identify_function" - id_menu_path = "Edit/DAILA/Explain function" - - def __init__(self, *args, **kwargs): - print("[DAILA] loaded!") - self.controller = IDADAILAController() - - - def init(self): - if not ida_hexrays.init_hexrays_plugin(): - return idaapi.PLUGIN_SKIP - - self.controller.hook_menu() - return idaapi.PLUGIN_KEEP - - def run(self, arg): - pass - - def term(self): - pass - - -class ContextMenuHooks(idaapi.UI_Hooks): - def __init__(self, *args, menu_strs=None, **kwargs): - idaapi.UI_Hooks.__init__(self) - self.menu_strs = menu_strs or [] - - def finish_populating_widget_popup(self, form, popup): - # Add actions to the context menu of the Pseudocode view - if idaapi.get_widget_type(form) == idaapi.BWN_PSEUDOCODE or idaapi.get_widget_type(form) == idaapi.BWN_DISASM: - for menu_str in self.menu_strs: - #idaapi.attach_action_to_popup(form, popup, DAILAPlugin.id_action_name, "DAILA/") - idaapi.attach_action_to_popup(form, popup, menu_str, "DAILA/") - - -class GenericAction(idaapi.action_handler_t): - def __init__(self, action_target, action_function): - idaapi.action_handler_t.__init__(self) - self.action_target = action_target - self.action_function = action_function - - def 
activate(self, ctx): - if ctx is None or ctx.action != self.action_target: - return - - dec_view = ida_hexrays.get_widget_vdui(ctx.widget) - # show a thing while we query - prg = QProgressDialog("Querying AI...", "Stop", 0, 1, None) - prg.show() - - self.action_function() - - # close the panel we showed while running - prg.setValue(1) - prg.close() - - dec_view.refresh_view(False) - return 1 - - # This action is always available. - def update(self, ctx): - return idaapi.AST_ENABLE_ALWAYS - - -class IDADAILAController(OpenAIInterface): - def __init__(self): - self.menu_actions = [] - super().__init__(self) - self.menu = None - - def _decompile(self, func_addr: int, **kwargs): - try: - cfunc = ida_hexrays.decompile(func_addr) - except Exception: - return None - - return str(cfunc) - - def _current_function_addr(self, **kwargs): - curr_ea = idaapi.get_screen_ea() - if curr_ea is None: - return None - - func = idaapi.get_func(curr_ea) - if func is None: - return None - - return func.start_ea - - def _cmt_func(self, func_addr: int, comment: str, **kwargs): - idc.set_func_cmt(func_addr, comment, 0) - return True - - def _rename_variables_by_name(self, func_addr: int, names: Dict[str, str]): - dec = idaapi.decompile(func_addr) - if dec is None: - return False - - lvars = { - lvar.name: lvar for lvar in dec.get_lvars() if lvar.name - } - update = False - for name, lvar in lvars.items(): - new_name = names.get(name, None) - if new_name is None: - continue - - lvar.name = new_name - update |= True - - if update: - dec.refresh_func_ctext() - - return update - - def _register_menu_item(self, name, action_string, callback_func): - # Function explaining action - explain_action = idaapi.action_desc_t( - name, - action_string, - GenericAction(name, callback_func), - "", - action_string, - 199 - ) - idaapi.register_action(explain_action) - idaapi.attach_action_to_menu(f"Edit/DAILA/{name}", name, idaapi.SETMENU_APP) - self.menu_actions.append(name) - - def hook_menu(self): - 
self.menu = ContextMenuHooks(menu_strs=self.menu_actions) - self.menu.hook() - - def ask_api_key(self, *args, **kwargs): - resp = idaapi.ask_str(self.api_key if isinstance(self.api_key, str) else "", 0, "Enter you OpenAI API Key: ") - if resp is None: - return - - self.api_key = resp - - @execute_write - def identify_current_function(self): - return super().identify_current_function() - - @execute_write - def explain_current_function(self, **kwargs): - return super().explain_current_function(**kwargs) diff --git a/pyproject.toml b/pyproject.toml index 67001d8..3680cef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,20 +7,24 @@ name = "dailalib" classifiers = [ "License :: OSI Approved :: BSD License", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.8", ] license = {text = "BSD 2 Clause"} description = "Decompiler Artificial Intelligence Language Assistant" urls = {Homepage = "https://github.com/mahaloz/DAILA"} -requires-python = ">= 3.5" +requires-python = ">= 3.8" dependencies = [ - "openai>=0.27.4", - "binsync", - "PySide6-Essentials", + "openai>=1.0.0", + "libbs", "tiktoken", ] dynamic = ["version"] +[project.optional-dependencies] +ghidra = [ + "libbs[ghidra]", +] + [project.readme] file = "README.md" content-type = "text/markdown" @@ -35,8 +39,6 @@ license-files = ["LICENSE"] [tool.setuptools.packages] find = {namespaces = false} -[tool.setuptools.package-data] -daila = ["plugins/*.py"] - [tool.setuptools.dynamic] version = {attr = "dailalib.__version__"} + diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 1221e04..0000000 --- a/setup.cfg +++ /dev/null @@ -1,32 +0,0 @@ -[metadata] -name = dailalib -version = attr: dailalib.__version__ -url = https://github.com/mahaloz/DAILA -classifiers = - License :: OSI Approved :: BSD License - Programming Language :: Python :: 3 - Programming Language :: Python :: 3.6 -license = BSD 2 Clause -license_files = LICENSE -description 
= Decompiler Artificial Intelligence Language Assistant -long_description = file: README.md -long_description_content_type = text/markdown - -[options] -install_requires = - openai>=0.27.4 - binsync - PySide6-Essentials - tiktoken - -python_requires = >= 3.5 -include_package_data = True -packages = find: - -[options.package_data] -daila = - plugins/*.py - -[options.entry_points] -console_scripts = - daila = dailalib.__main__:main diff --git a/setup.py b/setup.py deleted file mode 100644 index c1968e2..0000000 --- a/setup.py +++ /dev/null @@ -1,52 +0,0 @@ -# pylint: disable=missing-class-docstring -import os -import shutil -from pathlib import Path -from distutils.command.build import build as st_build - -from setuptools import setup -from setuptools.command.develop import develop as st_develop - - -def _copy_plugins(): - local_plugins = Path("plugins").absolute() - daila_loc = Path("dailalib").absolute() - pip_e_plugins = daila_loc.joinpath("plugins").absolute() - - # clean the install location of symlink or folder - shutil.rmtree(pip_e_plugins, ignore_errors=True) - try: - os.unlink(pip_e_plugins) - except: - pass - - # first attempt a symlink, if it works, exit early - try: - os.symlink(local_plugins, pip_e_plugins, target_is_directory=True) - return - except: - pass - - # copy if symlinking is not available on target system - try: - shutil.copytree("plugins", "dailalib/plugins") - except: - pass - -class build(st_build): - def run(self, *args): - self.execute(_copy_plugins, (), msg="Copying plugins...") - super().run(*args) - -class develop(st_develop): - def run(self, *args): - self.execute(_copy_plugins, (), msg="Linking or copying local plugins folder...") - super().run(*args) - - -cmdclass = { - "build": build, - "develop": develop, -} - -setup(cmdclass=cmdclass)