From 477dd87afab8f83439620092e2eebe57801f1856 Mon Sep 17 00:00:00 2001 From: "joxeankoret@yahoo.es" Date: Thu, 22 Feb 2024 18:15:09 +0100 Subject: [PATCH] Multiple bug fixes and little improvements CORE: Try to use `cdifflib` instead of Python's standard `difflib` when possible to get some performance gains. BUG: High addresses in operands could cause the Python's sqlite3 module to crash when inserting into the database. ML: Try to use the Ridge classifier as just another method to get a similarity ratio in `check_ratio`. ML: Simplifications of the supervised learning based experimental engine. CONFIG: Added parameter `COMMIT_AFTER_EACH_GUI_UPDATE` to force committing. CONFIG: Added parameter `EXPORTING_COMPILATION_UNITS` to enable/disable exporting them (with some huge databases it might take even hours!). CONFIG: Added parameters handling SQLite pragmas `SQLITE_JOURNAL_MODE` and `SQLITE_PRAGMA_SYNCHRONOUS`. CONFIG: Added parameter `SHOW_IMPORT_WARNINGS` to enable/disable showing warnings when some important but optional Python packages aren't found. BUG: Be sure to delete orphaned comments when importing pseudo-code comments. BUG: The workaround for "max non-trivial tinfo_t count has been reached" was wrong. Now, the Hex-Rays functions cache is cleared every 10,000 rows. GUI: Display the progress when exporting a large number of compilation units. BUG: Inserting the link between functions and compilation units was terribly-utterly-horribly wrong. VULN: Add pattern "UNC" to potentially detect vulnerabilities fixed in Windows components involving UNC paths. EXTRAS: Added independent IDA plugin `extras/diaphora_local.py` to be able to diff functions inside the current binary. BUG: Do a commit after all functions are exported so, in case IDA crashes for a reason/bug, Diaphora can properly recover from errors and have all the functions already exported there. --- diaphora.py | 100 +++++++--- diaphora_config.py | 33 +++- diaphora_ida.py | 52 +++-- extras/README.md | 13 ++ extras/diaphora_local.py | 368 ++++++++++++++++++++++++++++++++++++ jkutils/IDAMagicStrings.py | 8 +- ml/model.py | 127 +++++++------ scripts/patch_diff_vulns.py | 6 +- 8 files changed, 604 insertions(+), 103 deletions(-) create mode 100644 extras/README.md create mode 100644 extras/diaphora_local.py diff --git a/diaphora.py b/diaphora.py index 9410ba8..7e7c4a3 100755 --- a/diaphora.py +++ b/diaphora.py @@ -35,11 +35,22 @@ from io import StringIO from threading import Lock from multiprocessing import cpu_count -from difflib import SequenceMatcher, unified_diff import diaphora_config as config import diaphora_heuristics +try: + from cdifflib import CSequenceMatcher as SequenceMatcher + HAS_CDIFFLIB = True +except ImportError: + HAS_CDIFFLIB = False + if config.SHOW_IMPORT_WARNINGS: + print("WARNING: Python library 'cdifflib' not found. Installing it will significantly improve text diffing performance.") + print("INFO: Alternatively, you can silence this warning by changing the value of SHOW_IMPORT_WARNINGS in diaphora_config.py.") + from difflib import SequenceMatcher + +from difflib import unified_diff + import ml.model from ml.model import ML_ENABLED, train, predict, get_model_name @@ -709,6 +720,10 @@ def save_instructions_to_database(self, cur, bb_data, func_id): cls=CBytesEncoder, ) ) + elif isinstance(instruction_property, int): + if instruction_property > 0x8000000000000000: + instruction_property = str(instruction_property) + instruction_properties.append(instruction_property) else: instruction_properties.append(instruction_property) @@ -1104,7 +1119,7 @@ def save_function(self, props): insert_args.append([func_id, str(caller), "caller"]) for callee in callees: - insert_args.append([func_id, str(callee), "callee"]) + insert_args.append([func_id, str(callee), "callee"]) cur.executemany(sql, insert_args) # Phase 3: Insert the constants of the function @@ -1880,12 +1895,17 @@ def check_ratio(self, main_d, diff_d): self.ratios_cache[key] = 1.0 return 1.0 - r = max(v1, v2, v3, v4, v5) + v6 = 0.0 + if ML_ENABLED and self.machine_learning: + v6 = self.get_ml_ratio(main_d, diff_d) + + values_set = set([v1, v2, v3, v4, v5, v6]) + r = max(values_set) if r == 1.0 and md1 != md2: # We cannot assign a 1.0 ratio if both MD indices are different, that's an # error r = 0 - for v in [v1, v2, v3, v4, v5]: + for v in values_set: if v != 1.0 and v > r: r = v @@ -1893,7 +1913,10 @@ def check_ratio(self, main_d, diff_d): score = self.deep_ratio(main_d, diff_d, r) if r + score < 1.0: r += score + else: + r = 0.99 + debug_refresh(f"self.ratios_cache[{main_d['name']}-{diff_d['name']}] = {r}") self.ratios_cache[key] = r return r @@ -2868,6 +2891,45 @@ def add_multimatches_to_chooser(self, multi, ignore_list, dones): return ignore_list, dones + def get_ml_ratio(self, main_d, diff_d): + ea1 = int(main_d["ea"]) + ea2 = int(diff_d["ea"]) + + ml_ratio = 0.0 + + cur = self.db_cursor() + sql = "select * from {db}.functions where address = ?" + try: + cur.execute(sql.format(db="main"), (str(ea1),)) + main_row = cur.fetchone() + + cur.execute(sql.format(db="diff"), (str(ea2),)) + diff_row = cur.fetchone() + + ml_add = False + ml_ratio = 0 + if ML_ENABLED and self.machine_learning: + ml_ratio = predict(main_row, diff_row) + if ml_ratio >= config.ML_MIN_PREDICTION_RATIO: + log(f"ML ratio {ml_ratio} for {main_d['name']} - {diff_d['name']}") + ml_add = True + else: + ml_ratio = 0.0 + + if ml_add: + vfname1 = main_d["name"] + vfname2 = diff_d["name"] + nodes1 = main_d["nodes"] + nodes2 = diff_d["nodes"] + desc = f"ML {get_model_name()}" + + tmp_item = CChooser.Item(ea1, vfname1, ea2, vfname2, desc, ml_ratio, nodes1, nodes2) + self.ml_chooser.add_item(tmp_item) + finally: + cur.close() + + return ml_ratio + def deep_ratio(self, main_d, diff_d, ratio): """ Try to get a score to add to the value returned by `check_ratio()` so less @@ -2940,25 +3002,6 @@ def deep_ratio(self, main_d, diff_d, ratio): set_result = set1.intersection(set2) if len(set_result) > 0: score += len(set_result) * 0.0005 - - ml_add = False - if ML_ENABLED and self.machine_learning: - ml_ratio = predict(main_row, diff_row, ratio) - if ml_ratio > 0: - debug_refresh(f"ML ratio {ml_ratio} for {main_d['name']} - {diff_d['name']}") - score += config.ML_DEEP_RATIO_ADDED_SCORE - ml_add = True - - if ml_add: - vfname1 = main_d["name"] - vfname2 = diff_d["name"] - nodes1 = main_d["nodes"] - nodes2 = diff_d["nodes"] - desc = f"ML {get_model_name()}" - - tmp_item = CChooser.Item(ea1, vfname1, ea2, vfname2, desc, ratio, nodes1, nodes2) - self.ml_chooser.add_item(tmp_item) - finally: cur.close() @@ -3626,6 +3669,17 @@ def train_local_model(self): debug_refresh("[i] Machine learning module enabled.") train(self, self.all_matches) + def get_callers_callees(self, db_name, func_id): + cur = self.db_cursor() + rows = [] + try: + sql = "select * from {db}.callgraph where func_id = ?" + cur.execute(sql.format(db=db_name), (func_id,)) + rows = list(cur.fetchall()) + finally: + cur.close() + return rows + def diff(self, db): """ Diff the current two databases (main and diff). diff --git a/diaphora_config.py b/diaphora_config.py index 45dad2c..691eb6b 100644 --- a/diaphora_config.py +++ b/diaphora_config.py @@ -60,6 +60,12 @@ # Number of rows that must be inserted to commit the transaction EXPORTING_FUNCTIONS_TO_COMMIT = 5000 +# Every time the GUI export dialog is updated a commit is issued. This is useful +# whenever we are facing long export times with known IDA bugs that might cause +# it to fail at an unknown moment and we want to recover from errors. You might +# want to set it to False if you're finding small little performance wins. +COMMIT_AFTER_EACH_GUI_UPDATE = True + # The minimum number of functions in a database to, by default, disable running # slow queries. MIN_FUNCTIONS_TO_DISABLE_SLOW = 4001 @@ -73,13 +79,23 @@ # Block size to use to generate fuzzy hashes for pseudo-codes with DeepToad FUZZY_HASHING_BLOCK_SIZE = 512 +# Use it to disable finding compilation units. In some rare cases, there are too +# many compilation units and Diaphora might take very long to find them. +EXPORTING_COMPILATION_UNITS = True + ################################################################################ -# Default SQL related configuration options +# Default SQL and SQLite related configuration options # Diaphora won't process more than the given value of rows (per heuristic) SQL_MAX_PROCESSED_ROWS = 1000000 # SQL queries will timeout after the given number of seconds SQL_TIMEOUT_LIMIT = 60 * 5 +# Set this to DELETE, TRUNCATE, PERSIST, MEMORY, WAL, OFF, or None to use the +# default value. +SQLITE_JOURNAL_MODE = "MEMORY" +# Set this to 0/OFF, 1/NORMAL, 2/FULL, 3/EXTRA, or None to use the default +# value. +SQLITE_PRAGMA_SYNCHRONOUS = "1" ################################################################################ # Heuristics related configuration options @@ -189,12 +205,23 @@ # What is the minimum ratio required for a match to be considered for usage to # train a local model? -ML_MATCHES_MIN_RATIO = 0.5 +ML_MATCHES_MIN_RATIO = 0.6 +ML_MIN_PREDICTION_RATIO = 0.72 # What value should be added to the final similarity ratio when the specialized # classifier (trained with known good and bad results found for the current two # binaries being compared) finds what it thinks is a good match. -ML_DEEP_RATIO_ADDED_SCORE = 0.04 +ML_DEEP_RATIO_ADDED_SCORE = 0.1 # Show a chooser with all the matches that the classifier think are good ones? ML_DEBUG_SHOW_MATCHES = True + +#------------------------------------------------------------------------------- +# Some imports improve performance or add features to Diaphora but aren't 100% +# required. Diaphora will warn the reverser when these libraries failed to be +# imported. Change this directive to shutup this warning. +SHOW_IMPORT_WARNINGS = True + +#------------------------------------------------------------------------------- +# Workarounds for IDA bugs +DIAPHORA_WORKAROUND_MAX_TINFO_T = True diff --git a/diaphora_ida.py b/diaphora_ida.py index e6a5924..d380123 100644 --- a/diaphora_ida.py +++ b/diaphora_ida.py @@ -1152,9 +1152,16 @@ def recalculate_primes(self): return callgraph_primes, callgraph_all_primes def commit_and_start_transaction(self): - self.db.commit() - self.db.execute("PRAGMA synchronous = OFF") - self.db.execute("PRAGMA journal_mode = MEMORY") + try: + self.db.execute("commit") + except sqlite3.OperationalError as e: + # Ignore the "cannot commit - no transaction active" error + pass + + if config.SQLITE_PRAGMA_SYNCHRONOUS is not None: + self.db.execute(f"PRAGMA synchronous = {config.SQLITE_PRAGMA_SYNCHRONOUS}") + if config.SQLITE_JOURNAL_MODE is not None: + self.db.execute(f"PRAGMA journal_mode = {config.SQLITE_JOURNAL_MODE}") self.db.execute("BEGIN transaction") def do_export(self, crashed_before=False): @@ -1185,10 +1192,12 @@ def do_export(self, crashed_before=False): self._funcs_cache = {} for func in func_list: if user_cancelled(): - raise Exception("Canceled.") + raise Exception("Cancelled.") i += 1 if (total_funcs >= 100) and i % (int(total_funcs / 100)) == 0 or i == 1: + if config.COMMIT_AFTER_EACH_GUI_UPDATE: + self.commit_and_start_transaction() line = "Exported %d function(s) out of %d total.\nElapsed %d:%02d:%02d second(s), remaining time ~%d:%02d:%02d" elapsed = time.monotonic() - t remaining = (elapsed / i) * (total_funcs - i) @@ -1231,6 +1240,7 @@ def do_export(self, crashed_before=False): if i % (total_funcs / 10) == 0: self.commit_and_start_transaction() + self.commit_and_start_transaction() md5sum = GetInputFileMD5() self.save_callgraph( str(callgraph_primes), json.dumps(callgraph_all_primes), md5sum @@ -1240,7 +1250,9 @@ def do_export(self, crashed_before=False): self.export_til() except: log(f"Error reading type libraries: {str(sys.exc_info()[1])}") - self.save_compilation_units() + + if config.EXPORTING_COMPILATION_UNITS: + self.save_compilation_units() log_refresh("Creating indices...") self.create_indices() @@ -1898,6 +1910,7 @@ def import_instruction(self, ins_data1, ins_data2): comment = mcmt cfunc.set_user_cmt(tl, comment) + cfunc.del_orphan_cmts() cfunc.save_user_cmts() tmp_ea = None @@ -2342,8 +2355,9 @@ def decompile_and_get(self, ea): # # max non-trivial tinfo_t count has been reached # - if os.getenv("DIAPHORA_WORKAROUND_MAX_TINFO_T") is not None: - idaapi.clear_cached_cfuncs() + if config.DIAPHORA_WORKAROUND_MAX_TINFO_T: + if len(self._funcs_cache) % 10000 == 0: + idaapi.clear_cached_cfuncs() decompiler_plugin = os.getenv("DIAPHORA_DECOMPILER_PLUGIN") if decompiler_plugin is None: @@ -3225,7 +3239,9 @@ def get_modules_using_lfa(self): return new_modules def save_compilation_units(self): + log_refresh("Finding compilation units...") lfa_modules = self.get_modules_using_lfa() + log_refresh("Saving compilation units...") sql1 = """insert into compilation_units (name, start_ea, end_ea) values (?, ?, ?)""" @@ -3241,7 +3257,12 @@ def save_compilation_units(self): cur = self.db_cursor() try: dones = set() - for module in lfa_modules: + total = len(lfa_modules) + checkpoint = int(total / 10) + for i, module in enumerate(lfa_modules): + if i > 0 and checkpoint > 0 and i % checkpoint == 0: + log_refresh(f"Processing compilation unit {i} out of {total}...") + module_name = None if module["name"] != "": module_name = module["name"] @@ -3250,12 +3271,15 @@ def save_compilation_units(self): cur.execute(sql1, vals) cu_id = cur.lastrowid - for values in self._funcs_cache.values(): - func_id = values[0] - if func_id not in dones: - dones.add(func_id) - cur.execute(sql2, (cu_id, func_id)) - cur.execute(sql4, (module_name, func_id)) + for func in self._funcs_cache: + item = self._funcs_cache[func] + func = int(func) + if func >= module["start"] and func <= module["end"]: + func_id = item[0] + if func_id not in dones: + dones.add(func_id) + cur.execute(sql2, (cu_id, func_id)) + cur.execute(sql4, (module_name, func_id)) cur.execute( sql3, diff --git a/extras/README.md b/extras/README.md new file mode 100644 index 0000000..c4e1fb9 --- /dev/null +++ b/extras/README.md @@ -0,0 +1,13 @@ +# Diaphora local + +This is a pure Python IDA plugin to diff pseudo-codes and assembly for functions inside the current binary, instead of diffing functions in different binaries. + +# Installation + +Simply copy this script in the directory `$IDA_DIR/plugins`. + +# Usage + +Put the cursor in IDA in some function and press Ctrl + Shift + D, choose the function to diff against the current function and 2 choosers (windows) will open showing the differences at pseudo-code and assembly levels. + + diff --git a/extras/diaphora_local.py b/extras/diaphora_local.py new file mode 100644 index 0000000..006b447 --- /dev/null +++ b/extras/diaphora_local.py @@ -0,0 +1,368 @@ +#!/usr/bin/python3 + +""" +Diaphora, a binary diffing tool +Copyright (c) 2015-2024, Joxean Koret + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . +""" + +import re +import os +import sys +import time +import difflib + +import idc +import idaapi +import idautils + +from PyQt5 import QtWidgets +from pygments import highlight +from pygments.formatters import HtmlFormatter +from pygments.lexers import NasmLexer, CppLexer, DiffLexer + +#------------------------------------------------------------------------------- +DIFF_COLOR_ADDED = "#aaffaa" +DIFF_COLOR_CHANGED = "#ffff77" +DIFF_COLOR_SUBTRACTED = "#ffaaaa" +DIFF_COLOR_LINE_NO = "#e0e0e0" + +#------------------------------------------------------------------------------- +def log(msg): + idaapi.msg(f"[{time.asctime()}] {msg}") + +#------------------------------------------------------------------------------- +def do_decompile(f): + return idaapi.decompile(f, flags=idaapi.DECOMP_NO_WAIT) + +#------------------------------------------------------------------------------- +class CHtmlDiff: + """A replacement for difflib.HtmlDiff that tries to enforce a max width + + The main challenge is to do this given QTextBrowser's limitations. In + particular, QTextBrowser only implements a minimum of CSS. + """ + + _html_template = """ + + + + + + + %(rows)s +
+ + + """ + + _style = ( + """ + table.diff_tab { + font-family: Courier, monospace; + table-layout: fixed; + width: 100%; + } + + .diff_add { + background-color: """ + + DIFF_COLOR_ADDED + + """; + } + .diff_chg { + background-color: """ + + DIFF_COLOR_CHANGED + + """; + } + .diff_sub { + background-color: """ + + DIFF_COLOR_SUBTRACTED + + """; + } + .diff_lineno { + text-align: right; + background-color: """ + + DIFF_COLOR_LINE_NO + + """; + } + """ + ) + + _row_template = """ + + %s + %s + %s + %s + + """ + + _rexp_too_much_space = re.compile("^\t[.\\w]+ {8}") + + def make_file(self, lhs, rhs, fmt, lex): + rows = [] + for left, right, changed in difflib._mdiff(lhs, rhs): + lno, ltxt = left + rno, rtxt = right + + if not changed: + ltxt = highlight(ltxt, lex, fmt) + rtxt = highlight(rtxt, lex, fmt) + else: + ltxt = self._stop_wasting_space(ltxt) + rtxt = self._stop_wasting_space(rtxt) + + ltxt = ltxt.replace(" ", " ") + rtxt = rtxt.replace(" ", " ") + ltxt = ltxt.replace("<", "<") + ltxt = ltxt.replace(">", ">") + rtxt = rtxt.replace("<", "<") + rtxt = rtxt.replace(">", ">") + + row = self._row_template % (str(lno), ltxt, str(rno), rtxt) + rows.append(row) + + all_the_rows = "\n".join(rows) + all_the_rows = ( + all_the_rows.replace("\x00+", '') + .replace("\x00-", '') + .replace("\x00^", '') + .replace("\x01", "") + .replace("\t", 4 * " ") + ) + + res = self._html_template % {"style": self._style, "rows": all_the_rows} + return res + + def _stop_wasting_space(self, s): + """I never understood why you'd want to have 13 spaces between instruction and args'""" + m = self._rexp_too_much_space.search(s) + if m: + mlen = len(m.group(0)) + return s[: mlen - 4] + s[mlen:] + else: + return s + +#------------------------------------------------------------------------------- +class CHtmlViewer(idaapi.PluginForm): + """ + Class used to graphically show the differences. + """ + + def OnCreate(self, form): + self.parent = self.FormToPyQtWidget(form) + self.PopulateForm() + + self.browser = None + self.layout = None + return 1 + + def PopulateForm(self): + self.layout = QtWidgets.QVBoxLayout() + self.browser = QtWidgets.QTextBrowser() + self.browser.setLineWrapMode(QtWidgets.QTextEdit.FixedColumnWidth) + self.browser.setLineWrapColumnOrWidth(150) + self.browser.setHtml(self.text) + self.browser.setReadOnly(True) + self.layout.addWidget(self.browser) + self.parent.setLayout(self.layout) + + def Show(self, text, title): + self.text = text + return idaapi.PluginForm.Show(self, title) + +#------------------------------------------------------------------------------- +def decompile_and_get(ea): + decompiler_plugin = os.getenv("DIAPHORA_DECOMPILER_PLUGIN") + if decompiler_plugin is None: + decompiler_plugin = "hexrays" + if not idaapi.init_hexrays_plugin() and not ( + load_plugin(decompiler_plugin) and idaapi.init_hexrays_plugin() + ): + return False + + f = idaapi.get_func(ea) + if f is None: + return False + + cfunc = do_decompile(f) + if cfunc is None: + # Failed to decompile + return False + + sv = cfunc.get_pseudocode() + lines = [] + first_line = None + for sline in sv: + line = idaapi.tag_remove(sline.line) + if line.startswith("//"): + continue + + if first_line is None: + first_line = line + else: + lines.append(line) + + return first_line, "\n".join(lines) + +#------------------------------------------------------------------------------- +def get_disasm(ea): + mnem = idc.print_insn_mnem(ea) + op1 = idc.print_operand(ea, 0) + op2 = idc.print_operand(ea, 1) + line = f"{mnem.ljust(8)} {op1}" + if op2 != "": + line += f", {op2}" + return line + +#------------------------------------------------------------------------------- +def get_assembly(ea): + f = int(ea) + func = idaapi.get_func(f) + if not func: + log("Cannot get a function object for 0x%x" % f) + return False + + lines = [] + flow = idaapi.FlowChart(func) + for block in flow: + if block.end_ea == 0 or block.end_ea == idaapi.BADADDR: + log("0x%08x: Skipping bad basic block" % f) + continue + + if block.start_ea != func.start_ea: + lines.append("loc_%08x:" % (block.start_ea)) + for head in idautils.Heads(block.start_ea, block.end_ea): + lines.append(" %s" % (get_disasm(head))) + + return "\n".join(lines) + +#------------------------------------------------------------------------------- +class CLocalDiffer: + def __init__(self): + pass + + def get_pseudo_diff_data(self, ea1, ea2): + html_diff = CHtmlDiff() + tmp = decompile_and_get(int(ea1)) + if not tmp: + log("[i] Cannot get the pseudo-code for the current function") + return False + proto1, tmp1 = tmp + buf1 = proto1 + "\n" + tmp1 + + tmp = decompile_and_get(int(ea2)) + if not tmp: + log("Cannot get the pseudo-code for the second function") + return False + proto2, tmp2 = tmp + buf2 = proto2 + "\n" + tmp2 + + if buf1 == buf2: + warning("Both pseudo-codes are equal.") + + fmt = HtmlFormatter() + fmt.noclasses = True + fmt.linenos = False + fmt.nobackground = True + src = html_diff.make_file( + buf1.split("\n"), buf2.split("\n"), fmt, CppLexer() + ) + + name1 = idaapi.get_func_name(int(ea1)) + name2 = idaapi.get_func_name(int(ea2)) + title = f'Diff pseudo-code {name1} - {name2}' + res = (src, title) + return res + + def get_asm_diff_data(self, ea1, ea2): + html_diff = CHtmlDiff() + asm1 = get_assembly(ea1) + asm2 = get_assembly(ea2) + name1 = idaapi.get_func_name(int(ea1)) + name2 = idaapi.get_func_name(int(ea2)) + buf1 = f'{name1} proc near\n{asm1}\n{name1} endp' + buf2 = f'{name2} proc near\n{asm2}\n{name2} endp' + + fmt = HtmlFormatter() + fmt.noclasses = True + fmt.linenos = False + fmt.nobackground = True + src = html_diff.make_file( + buf1.split("\n"), buf2.split("\n"), fmt, NasmLexer() + ) + + title = f"Diff assembly {name1} - {name2}" + res = (src, title) + return res + + def diff_pseudo(self, main_ea, diff_ea): + res = self.get_pseudo_diff_data(main_ea, diff_ea) + self.show_res(res) + + def diff_assembly(self, main_ea, diff_ea): + res = self.get_asm_diff_data(main_ea, diff_ea) + self.show_res(res) + + def show_res(self, res): + if res: + (src, title) = res + cdiffer = CHtmlViewer() + cdiffer.Show(src, title) + + def diff(self, main_ea, diff_ea): + self.diff_assembly(main_ea, diff_ea) + self.diff_pseudo(main_ea, diff_ea) + +#------------------------------------------------------------------------------- +class myplugin_t(idaapi.plugin_t): + flags = idaapi.PLUGIN_UNL + comment = "Locally diff functions" + help = "Tool to diff functions inside this database" + wanted_name = "Diaphora: Diff Local Function" + wanted_hotkey = "Ctrl+Shift+D" + + def init(self): + return idaapi.PLUGIN_OK + + def run(self, arg): + main() + + def term(self): + pass + +def PLUGIN_ENTRY(): + return myplugin_t() + +#------------------------------------------------------------------------------- +def main(): + ea = idc.get_screen_ea() + func = idaapi.get_func(ea) + if func is None: + warning("Please place the cursor over a function before calling this plugin.") + return + + func_name = idaapi.get_func_name(ea) + line = f"Select the function to diff {func_name} against" + diff_ea = idc.choose_func(line) + if diff_ea == idaapi.BADADDR: + return + + log("Selected function address 0x%08x\n" % diff_ea) + differ = CLocalDiffer() + differ.diff(ea, diff_ea) + +if __name__ == "__main__": + main() diff --git a/jkutils/IDAMagicStrings.py b/jkutils/IDAMagicStrings.py index be5a766..f6605f8 100644 --- a/jkutils/IDAMagicStrings.py +++ b/jkutils/IDAMagicStrings.py @@ -27,14 +27,18 @@ except ImportError as e: print(f'{os.path.basename(__file__)} importerror {e}') +sys.path.append("..") +import diaphora_config as config + try: import nltk from nltk.tokenize import word_tokenize from nltk.tag import pos_tag has_nltk = True except ImportError as e: - print("NLTK is not installed. It's recommended to install python-nltk.") - print("It's optional, but significantly improves the results.") + if config.SHOW_IMPORT_WARNINGS: + print("WARNING: NLTK is not installed. It's recommended to install python-nltk. It's optional, but significantly improves the results.") + print("INFO: Alternatively, you can silence this warning by changing the value of SHOW_IMPORT_WARNINGS in diaphora_config.py.") has_nltk = False #------------------------------------------------------------------------------- diff --git a/ml/model.py b/ml/model.py index b8082bf..9e43ec7 100644 --- a/ml/model.py +++ b/ml/model.py @@ -6,7 +6,10 @@ import json import random -from difflib import SequenceMatcher +try: + from cdifflib import CSequenceMatcher as SequenceMatcher +except ImportError: + from difflib import SequenceMatcher #------------------------------------------------------------------------------- try: @@ -53,14 +56,52 @@ def quick_ratio(buf1 : str, buf2 : str) -> float: """ if buf1 is None or buf2 is None or buf1 == "" or buf1 == "": return 0 - seq = SequenceMatcher(None, buf1.split("\n"), buf2.split("\n")) - return seq.quick_ratio() + + if buf1 == buf2: + return 1.0 + + s1 = buf1.lower().split('\n') + s2 = buf2.lower().split('\n') + seq = SequenceMatcher(None, s1, s2) + return seq.ratio() + +#------------------------------------------------------------------------------- +def int_compare_ratio(value1 : int, value2 : int) -> float: + """ + Get a similarity ratio for two integers. + """ + if value1 + value2 == 0: + val = 1.0 + else: + val = 1 - ( abs(value1 - value2) / max(value1, value2) ) + return val + +#------------------------------------------------------------------------------- +def count_callers_callees(db_name : str, func_id : int): + """ + Count the callers and the callees for the given @func_id in @db_name database. + """ + global ml_model + calls = ml_model.diaphora.get_callers_callees(db_name, func_id) + callees = 0 + callers = 0 + for call in calls: + call_type = call["type"] + if call_type == 'callee': + callees += 1 + elif call_type == 'caller': + callers += 1 + return callers, callees #------------------------------------------------------------------------------- def compare_rows(row1 : list, row2 : list) -> list[float]: + """ + Compare two function rows and calculate a similarity ratio for it. + """ scores = [] keys = list(row1.keys()) IGNORE = ["id", "db_name", "export_time"] + for key in keys: if key in IGNORE: continue @@ -73,10 +114,7 @@ def compare_rows(row1 : list, row2 : list) -> list[float]: continue if type(value1) is int: - if value1 + value2 == 0: - val = 1.0 - else: - val = 1 - ( abs(value1 - value2) / max(value1, value2) ) + val = int_compare_ratio(value1, value2) scores.append(val) elif type(value1) is str: if value1.startswith('["') and value2.startswith('["'): @@ -94,6 +132,13 @@ def compare_rows(row1 : list, row2 : list) -> list[float]: scores.append(val) else: scores.append(value1 == value2) + + + main_callers, main_callees = count_callers_callees("main", row1["id"]) + diff_callers, diff_callees = count_callers_callees("diff", row2["id"]) + scores.append(int_compare_ratio(main_callers, diff_callees)) + scores.append(int_compare_ratio(diff_callers, diff_callers)) + return scores #------------------------------------------------------------------------------- @@ -102,13 +147,14 @@ def __init__(self, diaphora_obj : object): self.diaphora = diaphora_obj self.clf = RidgeClassifier() self.matches = [] - self.primary = {} - self.secondary = {} self.fitted = False self.model = None def find_matches(self, matches : list): + """ + Find appropriate good matches to build a dataset. + """ for group in matches: if group in ["best", "partial"]: for match in matches[group]: @@ -120,48 +166,17 @@ def find_matches(self, matches : list): self.matches = np.array(self.matches) def get_features(self, row : dict) -> list: + """ + Convert the function's row dict to a list. + """ l = [] for col in COLUMNS: l.append(row[col]) return l - def db_query_values(self): - cur = self.diaphora.db_cursor() - try: - sql = "create temporary table model_functions(db_name, name)" - cur.execute(sql) - - vals = [ ["main", 0], ["diff", 1] ] - for db_name, idx in vals: - sql = "insert into model_functions values ('{db}', ?)" - query = sql.format(db=db_name) - for name in self.matches[:,idx]: - cur.execute(query, (name,)) - - sql = """ select distinct * - from {db}.functions f - where name in (select name - from model_functions - where db_name = ?) """ - for db_name, idx in vals: - query = sql.format(db=db_name) - if db_name == "main": - d = self.primary - else: - d = self.secondary - - cur.execute(query, (db_name,)) - while 1: - row = cur.fetchone() - if not row: - break - - features = self.get_features(row) - d[row["name"]] = features - finally: - cur.close() - - def train_local_model(self) -> float: + def train_local_model(self) -> bool: + max_size = len(self.matches) + self.diaphora.log(f"Building dataset for a maximum of {max_size} x {max_size} ({max_size*max_size})") X = [] Y = [] total_round = 0 @@ -189,6 +204,7 @@ def train_local_model(self) -> float: ratio = self.diaphora.compare_function_rows(row1, row2) features1 = self.get_features(row1) features2 = self.get_features(row2) + comparisons = compare_rows(row1, row2) final = features1 + features2 + comparisons final = convert2numbers(final) @@ -203,7 +219,7 @@ def train_local_model(self) -> float: else: ratio = 0.0 - y = [ round(ratio), ] + y = [ ratio, ] X.append(x) Y.append(y) @@ -211,10 +227,9 @@ def train_local_model(self) -> float: X = np.array(X) Y = np.array(Y) + self.diaphora.log("Done building dataset") if found_some_good: - self.clf.fit(X, Y) - calibrator = CalibratedClassifierCV(self.clf, cv='prefit') - self.model = calibrator.fit(X, Y) + self.model = self.clf.fit(X, Y) self.diaphora.log(f"ML model score {self.clf.score(X, Y)}") else: self.diaphora.log(f"The ML model did not find any good enough match to use for training") @@ -224,7 +239,6 @@ def train_local_model(self) -> float: def train(self, matches : list): self.find_matches(matches) if len(self.matches) > 0: - self.db_query_values() self.diaphora.log_refresh("Training local model...") self.fitted = self.train_local_model() self.diaphora.log_refresh("Done training local model...") @@ -232,13 +246,8 @@ def train(self, matches : list): def predict(self, row : dict) -> float: ret = 0.0 if self.fitted: - if 'predict_proba' in dir(self.clf): - ret = self.clf.predict(row) - if ret[0] == 1: - ret = self.model.predict_proba(row) - return round(ret[0][1], 3) - else: - ret = self.clf.predict(row) + d = self.clf.decision_function(row)[0] + ret = np.exp(d) / (1 + np.exp(d)) return ret #------------------------------------------------------------------------------- @@ -248,7 +257,7 @@ def train(diaphora_obj : object, matches : list): ml_model.train(matches) #------------------------------------------------------------------------------- -def predict(main_row : dict, diff_row : dict, ratio : float) -> float: +def predict(main_row : dict, diff_row : dict) -> float: global ml_model ratio = 0.0 if ml_model is not None: diff --git a/scripts/patch_diff_vulns.py b/scripts/patch_diff_vulns.py index 1474758..f351615 100644 --- a/scripts/patch_diff_vulns.py +++ b/scripts/patch_diff_vulns.py @@ -24,7 +24,9 @@ # Windows 'unsafe' APIs "ShellExecute", "WinExec", "LoadLibrary", "CreateProcess", # Functions that may be interesting for Windows kernel drivers - "ProbeForWrite", "ProbeForRead" + "ProbeForWrite", "ProbeForRead", + # UNC paths related pattern + "UNC" ] COMPARISONS = [" < ", " > ", " <= ", " >= "] @@ -220,7 +222,7 @@ def on_match(self, func1, func2, description, ratio): if results.found: # Report matches while it's still finding for exciting reversers msg = f"Potentially interesting patch found (pattern {repr(results.description)}): {name1} - {name2}" - log(msg) + self.diaphora.log_refresh(msg) log(f"> {results.line}") # And finally add the item in the chooser we created