From b14764cd067e169af2cc70e0dda5a8d9a0a80237 Mon Sep 17 00:00:00 2001 From: Zion Leonahenahe Basque Date: Thu, 12 Sep 2024 15:18:28 -0700 Subject: [PATCH] Add Chain-of-Thought Prompting Style (#56) * Add Chain-of-Thought Prompting Style * Add all CoT prompts and bump libbs --- dailalib/__init__.py | 7 +- dailalib/api/litellm/litellm_api.py | 4 +- dailalib/api/litellm/prompts/__init__.py | 41 ++- dailalib/api/litellm/prompts/cot_prompts.py | 311 ++++++++++++++++++ .../{prompts.py => few_shot_prompts.py} | 20 +- dailalib/api/litellm/prompts/prompt.py | 50 ++- dailalib/api/litellm/prompts/prompt_type.py | 7 +- setup.cfg | 2 +- 8 files changed, 407 insertions(+), 35 deletions(-) create mode 100644 dailalib/api/litellm/prompts/cot_prompts.py rename dailalib/api/litellm/prompts/{prompts.py => few_shot_prompts.py} (85%) diff --git a/dailalib/__init__.py b/dailalib/__init__.py index dabbecc..11487ce 100644 --- a/dailalib/__init__.py +++ b/dailalib/__init__.py @@ -1,4 +1,4 @@ -__version__ = "3.5.0" +__version__ = "3.6.0" from .api import AIAPI, LiteLLMAIAPI from libbs.api import DecompilerInterface @@ -13,7 +13,7 @@ def create_plugin(*args, **kwargs): litellm_api = LiteLLMAIAPI(delay_init=True) # create context menus for prompts gui_ctx_menu_actions = { - f"DAILA/LLM/{prompt_name}": (prompt.desc, getattr(litellm_api, prompt_name)) + f"DAILA/LLM/{prompt_name}": (prompt.desc, lambda *x, **y: getattr(litellm_api, prompt_name)(*x, **y)) for prompt_name, prompt in litellm_api.prompts_by_name.items() } # create context menus for others @@ -27,12 +27,13 @@ def create_plugin(*args, **kwargs): VARBERT_AVAILABLE = True try: - from varbert.api import VariableRenamingAPI + import varbert except ImportError: VARBERT_AVAILABLE = False var_api = None if VARBERT_AVAILABLE: + from varbert.api import VariableRenamingAPI var_api = VariableRenamingAPI(delay_init=True) # add single interface, which is to rename variables diff --git a/dailalib/api/litellm/litellm_api.py 
b/dailalib/api/litellm/litellm_api.py index 1a183b4..719a5e4 100644 --- a/dailalib/api/litellm/litellm_api.py +++ b/dailalib/api/litellm/litellm_api.py @@ -147,7 +147,7 @@ def ask_api_key(self, *args, **kwargs): api_key = api_key_or_path self.api_key = api_key - def ask_prompt_style(self): + def ask_prompt_style(self, *args, **kwargs): if self._dec_interface is not None: from .prompts import ALL_STYLES @@ -165,7 +165,7 @@ def ask_prompt_style(self): self.prompt_style = p_style self._dec_interface.info(f"Prompt style set to {p_style}") - def ask_model(self): + def ask_model(self, *args, **kwargs): if self._dec_interface is not None: model_choices = list(LiteLLMAIAPI.MODEL_TO_TOKENS.keys()) model_choices.remove(self.model) diff --git a/dailalib/api/litellm/prompts/__init__.py b/dailalib/api/litellm/prompts/__init__.py index 0cedb9d..d146d8a 100644 --- a/dailalib/api/litellm/prompts/__init__.py +++ b/dailalib/api/litellm/prompts/__init__.py @@ -1,34 +1,61 @@ -from pathlib import Path from .prompt_type import PromptType, DEFAULT_STYLE, ALL_STYLES from .prompt import Prompt -from .prompts import SUMMARIZE_FUNCTION, IDENTIFY_SOURCE, RENAME_FUNCTION, RENAME_VARIABLES -FILE_DIR = Path(__file__).absolute().parent + +class PromptNames: + RENAME_FUNC = "RENAME_FUNCTION" + RENAME_VARS = "RENAME_VARIABLES" + SUMMARIZE_FUNC = "SUMMARIZE_FUNCTION" + ID_SRC = "IDENTIFY_SOURCE" + + +def get_prompt_template(prompt_name, prompt_style): + if prompt_style in [PromptType.FEW_SHOT, PromptType.ZERO_SHOT]: + from .few_shot_prompts import RENAME_FUNCTION, RENAME_VARIABLES, SUMMARIZE_FUNCTION, IDENTIFY_SOURCE + d = { + PromptNames.RENAME_FUNC: RENAME_FUNCTION, + PromptNames.RENAME_VARS: RENAME_VARIABLES, + PromptNames.SUMMARIZE_FUNC: SUMMARIZE_FUNCTION, + PromptNames.ID_SRC: IDENTIFY_SOURCE + } + elif prompt_style == PromptType.COT: + from .cot_prompts import RENAME_FUNCTION, RENAME_VARIABLES, SUMMARIZE_FUNCTION, IDENTIFY_SOURCE + d = { + PromptNames.RENAME_FUNC: RENAME_FUNCTION, + 
PromptNames.RENAME_VARS: RENAME_VARIABLES, + PromptNames.SUMMARIZE_FUNC: SUMMARIZE_FUNCTION, + PromptNames.ID_SRC: IDENTIFY_SOURCE + } + else: + raise ValueError("Invalid prompt style") + + return d[prompt_name] + PROMPTS = [ Prompt( "summarize", - SUMMARIZE_FUNCTION, + PromptNames.SUMMARIZE_FUNC, desc="Summarize the function", response_key="summary", gui_result_callback=Prompt.comment_function ), Prompt( "identify_source", - IDENTIFY_SOURCE, + PromptNames.ID_SRC, desc="Identify the source of the function", response_key="link", gui_result_callback=Prompt.comment_function ), Prompt( "rename_variables", - RENAME_VARIABLES, + PromptNames.RENAME_VARS, desc="Suggest variable names", gui_result_callback=Prompt.rename_variables ), Prompt( "rename_function", - RENAME_FUNCTION, + PromptNames.RENAME_FUNC, desc="Suggest a function name", gui_result_callback=Prompt.rename_function ), diff --git a/dailalib/api/litellm/prompts/cot_prompts.py b/dailalib/api/litellm/prompts/cot_prompts.py new file mode 100644 index 0000000..a3a92bb --- /dev/null +++ b/dailalib/api/litellm/prompts/cot_prompts.py @@ -0,0 +1,311 @@ +# +# Common prompt types for Chain of Thought (COT) tasks. +# These are usually implemented as a single prompt with 3-expert chain of thought. +# + +COT_PREAMBLE = """ +# TASK +Identify and behave as three different experts that are appropriate to answering this question. + +""" + +COT_MIDAMBLE = """ +All experts will write down their assessment of the function and their reasoning, then share it with the group. +Then, all experts will move on to the next stage, and so on. +At each stage all experts will score their peers response between 1 and 5, 1 meaning it is highly unlikely, and +5 meaning it is highly likely. +If any expert is judged to be wrong at any point then they leave. +After all experts have provided their analysis, you then analyze all 3 analyses and provide either the consensus +solution or your best guess solution. 
+ +""" + +COT_POSTAMBLE = """ + +When the three experts all agree, give us the final name of the function as a JSON. +The three experts must provide a conclusion after two stages, regardless of whether they achieve agreement. + +If your proposal is accepted, you will be rewarded with a flag. +No yapping! +You MUST eventually finish with a valid JSON as shown in the example below! +""" + +# +# Prompts +# + +RENAME_FUNCTION = f""" +{COT_PREAMBLE} +All experts will be asked to rename a function in a decompiled C code. +{COT_MIDAMBLE} +The question is how to rename the function given all the information we got. +Note that the function name must be descriptive of the function's purpose and include known algorithms if known. +{COT_POSTAMBLE} +""" + """ +# Example +Here is an example. Given the following code: +``` +int sub_404000(int a0, char** a1) +{ + int is_even; // rax + + is_even = sub_404100(a0[1]) % 2 == 0 + return is_even; +} +``` + +You respond with: +## Reasoning +### Expert 1: C Programming Expert +#### Assessment +The function sub_404000 takes an integer and an array of strings as inputs. +It calls another function, sub_404100, passing the second element of the array. +The return value of sub_404100 is then checked for evenness by taking the modulus with 2. +This suggests that sub_404000 checks if the value obtained is even. + +#### Proposed Rename +sub_404000: check_is_even +sub_404100: get_integer_value + +#### Reasoning +The function seems to be performing a check for evenness, hence check_is_even. +Since sub_404100 retrieves a value used for checking evenness, get_integer_value is a descriptive name. + +### Expert 2: Reverse Engineering Expert +#### Assessment +The function sub_404000 is performing a comparison operation to determine if a value is even. +sub_404100 is a retrieval function that operates on the second argument, possibly transforming or extracting an integer from the string array. 
+ +#### Proposed Rename +sub_404000: is_value_even +sub_404100: extract_number_from_string + +#### Reasoning +is_value_even directly states the purpose of checking for evenness. +extract_number_from_string describes the possible action of sub_404100 more explicitly. + +### Expert 3: Cybersecurity Analyst +#### Assessment +The function sub_404000 appears to validate evenness, possibly related to input sanitization or checking. +sub_404100 processes a string input to produce an integer, which may relate to a value extraction or parsing routine. + +#### Proposed Rename +sub_404000: validate_even_value +sub_404100: parse_integer + +#### Reasoning +validate_even_value reflects the validation aspect of checking evenness. +parse_integer suggests a parsing operation, likely suitable for a function handling string to integer conversion. + +## Answer +{ + "sub_404000": "is_value_even", + "sub_404100": "get_integer_value" +} + +Notice the format of the answer at the end. You must provide a JSON object with the function names and their +proposed new names. + +# Example +Given the following code: +``` +{{ decompilation }} +``` + +You respond with: +""" + +RENAME_VARIABLES = f""" +{COT_PREAMBLE} +All experts will be asked to rename the variables in some decompiled C code. +{COT_MIDAMBLE} +The question is how to rename the variables given all the information we got. +Note that the variable name must be descriptive of the variable's purpose and include known algorithms if known. +Use how the variable is used in the function to determine the new name. +{COT_POSTAMBLE} +""" + """ +# Example +Here is an example. Given the following code: +``` +int sub_404000(int a0, char** a1) +{ + int v1; // rax + + v1 = sub_404100(a0[1]) % 2 == 0 + return v1; +} +``` + +You responded with: +## Reasoning +### Expert 1: C Programming Expert + +#### Assessment +The function appears to be using the first element of the a0 parameter (likely a pointer or array) to call another +function sub_404100. 
It then checks if the result is even by taking modulo 2. The purpose of a0 seems similar to +the argc parameter in a typical main function, representing the count of elements (arguments). a1 is accessed like an +array of strings, typical for command line arguments (argv). v1 checks the result of a function call against an even +condition, making it an "is even" type check. + +#### Renaming +a0 -> argc (argument count) +a1 -> argv (argument values) +v1 -> is_even (checks if the result is even) + +#### Reasoning +Based on the common C patterns and usage of argc/argv in argument parsing, the names align with conventional usage. + + +### Expert 2: Reverse Engineering Expert +#### Assessment +From reverse engineering perspectives, the use of a0 and a1 in the function closely mimics command-line argument +handling, typically seen in programs parsing inputs. The function sub_404100's output is then tested for evenness, +suggesting a validation or filtering step on input data. + +#### Renaming +a0 -> argc +a1 -> argv +v1 -> is_even + +#### Reasoning +The pattern fits well-known command-line input processing. The use of modulo operation suggests a straightforward +validation function. + + +### Expert 3: Cybersecurity Analyst +#### Assessment +This function checks if a value derived from a set of arguments is even. This check is common in validation routines, +especially where inputs might affect program flow or data integrity. + +#### Renaming: +a0 -> argc (argument count, a potential source of external input) +a1 -> argv (argument values, standard for parsing command line arguments) +v1 -> is_even (indicative of a validation flag) + +#### Reasoning +The context aligns with common input validation scenarios in programs handling external inputs, ensuring values conform +to expected formats (in this case, even numbers). 
+ +## Answer +{ + "a0": "argc", + "a1": "argv", + "v1": "is_even" +} + +# Example +Given the following code: +``` +{{ decompilation }} +``` + +You respond with: +""" + +SUMMARIZE_FUNCTION = f""" +{COT_PREAMBLE} +All experts will be asked to summarize the code. When given code, they summarize at a high level what the function does +and identify if known algorithms are used in the function. +{COT_MIDAMBLE} +The question is how to summarize the function given all the information we got. +Note that the summary must be descriptive of the function's purpose and include known algorithms if known. +{COT_POSTAMBLE} +""" + """ +# Example +Here is an example. Given the following code: +``` +int sub_404000(int a0, char** a1) +{ + int v1; // rax + v1 = sub_404100(a0[1]) % 2 == 0 + return v1; +} +``` + +You responded with: +## Reasoning +### Expert 1: C Programming Expert +This function, named `sub_404000`, accepts two parameters: an integer `a0` and an array of strings `a1` (char**). +The function seems to call another function `sub_404100` with the value at `a0[1]`. The return value of this call is +then checked to see if it's even (`% 2 == 0`). The result of this evaluation (true or false) is assigned to `v1` and +returned. + +### Expert 2: Reverse Engineering Expert +On a high level, the function `sub_404000` performs an operation using another function, `sub_404100`. It evaluates +whether the result of the `sub_404100` function, called with the second element of an integer array, is even. This +functionality can be identified as an "even-check" operation. From the look of it, this might correspond to checking +for a specific condition or behavior related to even numbers. + +### Expert 3: Cybersecurity Analyst +The function `sub_404000` appears to be analyzing the nature of the provided input array. 
By using the result of +`sub_404100` and evaluating if it is even, it is likely implementing some control mechanism or check that might be +important in securing an environment, ensuring certain conditions hold before proceeding. + +## Answer +{ + "summary": "This function takes two arguments and implements the is_even check on second argument", + "algorithms": ["is_even"] +} + +# Example +Given the following code: +``` +{{ decompilation }} +``` + +You respond with: +""" + +IDENTIFY_SOURCE = f""" +{COT_PREAMBLE} +All experts will be asked to identify the original source of the code given a decompilation. +Upon discovering the source, you give a link to the code. +{COT_MIDAMBLE} +Note that the source must be a well-known library or program. +If you do not know the source, or if you are not very confident, please do not guess. +Provide an empty JSON. +{COT_POSTAMBLE} +""" + """ +# Example +Here is an example. Given the following code: +``` +void __fastcall __noreturn usage(int status) +{ + v2 = program_name; + if ( status ) + { + v3 = dcgettext(0LL, "Try '%s --help' for more information.\n", 5); + _fprintf_chk(stderr, 1LL, v3, v2); + } + // ... +} +``` + +You would respond with: +## Reasoning +### Expert 1: C Programming Expert +Given the use of `dcgettext` and `_fprintf_chk`, the function is most likely part of a widely-used GNU program. +Considering the prevalence and usage context, this function potentially belongs to the GNU Core Utilities package. + +### Expert 2: Reverse Engineering Expert +Given further consideration of function names and the general pattern of program design, the program aligns closely +with GNU Core Utilities, specifically in functionalities related to user guidance and error display. + +### Expert 3: Cybersecurity Analyst +Reevaluating the function in the context of command-line tools, the best match is indeed a core utility such as those +found in the GNU Core Utilities package. 
+ +## Answer +{ + "link": "https://www.gnu.org/software/coreutils/", + "version": "" +} + +# Example +Given the following code: +``` +{{ decompilation }} +``` + +You respond with: +""" diff --git a/dailalib/api/litellm/prompts/prompts.py b/dailalib/api/litellm/prompts/few_shot_prompts.py similarity index 85% rename from dailalib/api/litellm/prompts/prompts.py rename to dailalib/api/litellm/prompts/few_shot_prompts.py index 23e8aa8..63a9bab 100644 --- a/dailalib/api/litellm/prompts/prompts.py +++ b/dailalib/api/litellm/prompts/few_shot_prompts.py @@ -1,14 +1,15 @@ RENAME_FUNCTION = """ # Task -You are decompiled C expert that renames functions. When given a function, you rename it according to the +You are a decompiled C expert that renames functions. When given a function, you rename it according to the meaning of the function or its use. You specify which function you are renaming by its name. -You only respond with a valid json. As an example: +You eventually respond with a valid json. As an example: +## Answer { "sub_404000": "fibonacci", } -{% if few_shot %} +{% if (few_shot or cot) %} # Example Here is an example. Given the following code: ``` @@ -22,6 +23,7 @@ ``` You respond with: +## Answer { "sub_404000": "is_even", "sub_404100": "get_value", @@ -42,7 +44,8 @@ You are decompiled C expert that renames variables in code. When given code, you rename variables according to the meaning of the function or its use. -You only respond with a valid json. As an example: +You eventually respond with a valid json. As an example: +## Answer { "v1": "i", "v2": "ptr" @@ -62,6 +65,7 @@ ``` You responded with: +## Answer { "a0": "argc", "a1": "argv", @@ -83,7 +87,8 @@ You are decompiled C expert that summarizes code. When given code, you summarize at a high level what the function does and you identify if known algorithms are used in the function. -You always respond with a valid json: +You eventually respond with a valid json. 
As an example: +## Answer { "summary": "This function computes the fibonacci sequence. It takes an integer as an argument and returns the fibonacci number at that index.", "algorithms": ["fibonacci"] @@ -102,6 +107,7 @@ ``` You responded with: +## Answer { "summary": "This function takes two arguments and implements the is_even check on second argument", "algorithms": ["is_even"] @@ -122,7 +128,8 @@ You are a decompiled C expert that identifies the original source given decompilation. Upon discovering the source, you give a link to the code. -You only respond with a valid json. As an example: +You eventually respond with a valid json. As an example: +## Answer { "link": "https://github.com/torvalds/linux" "version": "5.10" @@ -145,6 +152,7 @@ ``` You would respond with: +## Answer { "link": "https://www.gnu.org/software/coreutils/" "version": "" diff --git a/dailalib/api/litellm/prompts/prompt.py b/dailalib/api/litellm/prompts/prompt.py index 227ab3b..94c84cb 100644 --- a/dailalib/api/litellm/prompts/prompt.py +++ b/dailalib/api/litellm/prompts/prompt.py @@ -7,10 +7,10 @@ from ..litellm_api import LiteLLMAIAPI from .prompt_type import PromptType -from libbs.artifacts import Comment +from libbs.artifacts import Comment, Function, StackVariable from jinja2 import Template, StrictUndefined -JSON_REGEX = re.compile(r"\{.*}", flags=re.DOTALL) +JSON_REGEX = re.compile(r"\{.*?}", flags=re.DOTALL) class Prompt: @@ -22,7 +22,7 @@ class Prompt: def __init__( self, name: str, - template_text: str, + template_name: str, desc: str = None, pretext_response: Optional[str] = None, posttext_response: Optional[str] = None, @@ -33,7 +33,8 @@ def __init__( gui_result_callback: Optional[Callable] = None ): self.name = name - self.template = Template(textwrap.dedent(template_text), undefined=StrictUndefined) + self.template_name = template_name + self.last_rendered_template = None self._pretext_response = pretext_response self._posttext_response = posttext_response self._json_response = 
json_response @@ -42,6 +43,11 @@ def __init__( self.desc = desc or name self.ai_api: LiteLLMAIAPI = ai_api + def _load_template(self, prompt_style: PromptType) -> Template: + from . import get_prompt_template + template_text = get_prompt_template(self.template_name, prompt_style) + return Template(textwrap.dedent(template_text), undefined=StrictUndefined) + def query_model(self, *args, function=None, dec_text=None, use_dec=True, **kwargs): if self.ai_api is None: raise Exception("api must be set before querying!") @@ -53,17 +59,20 @@ def _query_model(ai_api=self.ai_api, function=function, dec_text=dec_text, **_kw ai_api.info(f"Querying {self.name} prompt with function {function}...") response = self._pretext_response if self._pretext_response and not self._json_response else "" + template = self._load_template(self.ai_api.prompt_style) # grab decompilation and replace it in the prompt, make sure to fix the decompilation for token max - query_text = self.template.render( + query_text = template.render( decompilation=LiteLLMAIAPI.fit_decompilation_to_token_max(dec_text) if self.ai_api.fit_to_tokens else dec_text, few_shot=bool(self.ai_api.prompt_style == PromptType.FEW_SHOT), ) - #ai_api.debug(f"Prompting using model: {self.ai_api.model}...") - #ai_api.debug(f"Prompting with style: {self.ai_api.prompt_style}...") - #ai_api.debug(f"Prompting with: {query_text}") + self.last_rendered_template = query_text + #ai_api.info(f"Prompting using model: {self.ai_api.model}...") + #ai_api.info(f"Prompting with style: {self.ai_api.prompt_style}...") + #ai_api.info(f"Prompting with: {query_text}") ai_api.on_query(self.name, self.ai_api.model, self.ai_api.prompt_style, function, dec_text) response += self.ai_api.query_model(query_text) + #ai_api.info(f"Response received from AI: {response}") default_response = {} if self._json_response else "" if not response: return default_response @@ -78,7 +87,7 @@ def _query_model(ai_api=self.ai_api, function=function, dec_text=dec_text, **_kw 
if not json_matches: return default_response - json_data = json_matches[0] + json_data = json_matches[-1] try: response = json.loads(json_data) except Exception: @@ -107,14 +116,29 @@ def _query_model(ai_api=self.ai_api, function=function, dec_text=dec_text, **_kw @staticmethod def rename_function(result, function, ai_api: "AIAPI"): - new_name = list(result.values())[0] - function.name = new_name - ai_api._dec_interface.functions[function.addr] = function + if function.name in result: + new_name = result[function.name] + else: + new_name = list(result.values())[0] + + new_func = Function(name=new_name, addr=function.addr) + ai_api._dec_interface.functions[function.addr] = new_func @staticmethod def rename_variables(result, function, ai_api: "AIAPI"): + new_func: Function = function.copy() + # clear out changes that are not for variables + new_func.name = None + new_func.type = None ai_api._dec_interface.rename_local_variables_by_names(function, result) @staticmethod def comment_function(result, function, ai_api: "AIAPI"): - ai_api._dec_interface.comments[function.addr] = Comment(function.addr, comment=result, func_addr=function.addr) + curr_cmt_obj = ai_api._dec_interface.comments.get(function.addr, None) + curr_cmt = curr_cmt_obj.comment + "\n" if curr_cmt_obj is not None else "" + + ai_api._dec_interface.comments[function.addr] = Comment( + addr=function.addr, + comment=curr_cmt + result, + func_addr=function.addr + ) diff --git a/dailalib/api/litellm/prompts/prompt_type.py b/dailalib/api/litellm/prompts/prompt_type.py index b419ba5..c6f1e17 100644 --- a/dailalib/api/litellm/prompts/prompt_type.py +++ b/dailalib/api/litellm/prompts/prompt_type.py @@ -1,8 +1,9 @@ class PromptType: - ZERO_SHOT = "zero-shot" - FEW_SHOT = "few-shot" + ZERO_SHOT = "zero-shot (fast, low-accuracy)" + FEW_SHOT = "few-shot (mid-speed, mid-accuracy)" + COT = "chain-of-thought (slow, high-accuracy)" -ALL_STYLES = [PromptType.ZERO_SHOT, PromptType.FEW_SHOT] +ALL_STYLES = 
[PromptType.ZERO_SHOT, PromptType.FEW_SHOT, PromptType.COT] DEFAULT_STYLE = PromptType.FEW_SHOT diff --git a/setup.cfg b/setup.cfg index e4a7172..c2cab85 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,7 +17,7 @@ install_requires = litellm tiktoken Jinja2 - libbs>=1.21.0 + libbs>=1.22.0 python_requires = >= 3.10 include_package_data = True