diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index c32b449f7dc6b..fc93bf62a67fa 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -7,6 +7,10 @@ import os from glob import glob import sqlite3 +import json +import csv +from typing import Optional, Union +from collections.abc import Iterator, Sequence try: import git @@ -17,6 +21,28 @@ logger = logging.getLogger("compare-llama-bench") +# All llama-bench SQL fields +DB_FIELDS = [ + "build_commit", "build_number", "cpu_info", "gpu_info", "backends", "model_filename", + "model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads", + "cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers", + "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides", + "defrag_thold", + "use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth", + "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", +] + +DB_TYPES = [ + "TEXT", "INTEGER", "TEXT", "TEXT", "TEXT", "TEXT", + "TEXT", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", + "TEXT", "INTEGER", "INTEGER", "TEXT", "TEXT", "INTEGER", + "TEXT", "INTEGER", "INTEGER", "INTEGER", "TEXT", "TEXT", + "REAL", + "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", + "TEXT", "INTEGER", "INTEGER", "REAL", "REAL", +] +assert len(DB_FIELDS) == len(DB_TYPES) + # Properties by which to differentiate results per commit: KEY_PROPERTIES = [ "cpu_info", "gpu_info", "backends", "n_gpu_layers", "tensor_buft_overrides", "model_filename", "model_type", @@ -42,7 +68,7 @@ GPU_NAME_STRIP = ["NVIDIA GeForce ", "Tesla ", "AMD Radeon "] # Strip prefixes for smaller tables. MODEL_SUFFIX_REPLACE = {" - Small": "_S", " - Medium": "_M", " - Large": "_L"} -DESCRIPTION = """Creates tables from llama-bench data written to an SQLite database. Example usage (Linux): +DESCRIPTION = """Creates tables from llama-bench data written to multiple JSON/CSV files, a single JSONL file or SQLite database. Example usage (Linux): $ git checkout master $ make clean && make llama-bench @@ -70,12 +96,13 @@ ) parser.add_argument("-c", "--compare", help=help_c) help_i = ( - "Input SQLite file for comparing commits. " + "JSON/JSONL/SQLite/CSV files for comparing commits. " + "Specify multiple times to use multiple input files (JSON/CSV only). " "Defaults to 'llama-bench.sqlite' in the current working directory. " "If no such file is found and there is exactly one .sqlite file in the current directory, " "that file is instead used as input." ) -parser.add_argument("-i", "--input", help=help_i) +parser.add_argument("-i", "--input", action="append", help=help_i) help_o = ( "Output format for the table. " "Defaults to 'pipe' (GitHub compatible). " @@ -110,119 +137,321 @@ sys.exit(1) input_file = known_args.input -if input_file is None and os.path.exists("./llama-bench.sqlite"): - input_file = "llama-bench.sqlite" -if input_file is None: +if not input_file and os.path.exists("./llama-bench.sqlite"): + input_file = ["llama-bench.sqlite"] +if not input_file: sqlite_files = glob("*.sqlite") if len(sqlite_files) == 1: - input_file = sqlite_files[0] + input_file = sqlite_files -if input_file is None: +if not input_file: logger.error("Cannot find a suitable input file, please provide one.\n") parser.print_help() sys.exit(1) -connection = sqlite3.connect(input_file) -cursor = connection.cursor() -build_len_min: int = cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0] -build_len_max: int = cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0] +class LlamaBenchData: + repo: Optional[git.Repo] + build_len_min: int + build_len_max: int + build_len: int = 8 + builds: list[str] = [] + check_keys = set(KEY_PROPERTIES + ["build_commit", "test_time", "avg_ts"]) -if build_len_min != build_len_max: - logger.warning(f"{input_file} contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. " - "Try purging the the database of old commits.") - cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {build_len_min});") + def __init__(self): + try: + self.repo = git.Repo(".", search_parent_directories=True) + except git.InvalidGitRepositoryError: + self.repo = None -build_len: int = build_len_min + def _builds_init(self): + self.build_len = self.build_len_min -builds = cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall() -builds = list(map(lambda b: b[0], builds)) # list[tuple[str]] -> list[str] + def _check_keys(self, keys: set) -> Optional[set]: + """Private helper method that checks against required data keys and returns missing ones.""" + if not keys >= self.check_keys: + return self.check_keys - keys + return None -if not builds: - raise RuntimeError(f"{input_file} does not contain any builds.") + def find_parent_in_data(self, commit: git.Commit) -> Optional[str]: + """Helper method to find the most recent parent measured in number of commits for which there is data.""" + heap: list[tuple[int, git.Commit]] = [(0, commit)] + seen_hexsha8 = set() + while heap: + depth, current_commit = heapq.heappop(heap) + current_hexsha8 = commit.hexsha[:self.build_len] + if current_hexsha8 in self.builds: + return current_hexsha8 + for parent in commit.parents: + parent_hexsha8 = parent.hexsha[:self.build_len] + if parent_hexsha8 not in seen_hexsha8: + seen_hexsha8.add(parent_hexsha8) + heapq.heappush(heap, (depth + 1, parent)) + return None -try: - repo = git.Repo(".", search_parent_directories=True) -except git.InvalidGitRepositoryError: - repo = None - - -def find_parent_in_data(commit: git.Commit): - """Helper function to find the most recent parent measured in number of commits for which there is data.""" - heap: list[tuple[int, git.Commit]] = [(0, commit)] - seen_hexsha8 = set() - while heap: - depth, current_commit = heapq.heappop(heap) - current_hexsha8 = commit.hexsha[:build_len] - if current_hexsha8 in builds: - return current_hexsha8 - for parent in commit.parents: - parent_hexsha8 = parent.hexsha[:build_len] - if parent_hexsha8 not in seen_hexsha8: - seen_hexsha8.add(parent_hexsha8) - heapq.heappush(heap, (depth + 1, parent)) - return None - - -def get_all_parent_hexsha8s(commit: git.Commit): - """Helper function to recursively get hexsha8 values for all parents of a commit.""" - unvisited = [commit] - visited = [] - - while unvisited: - current_commit = unvisited.pop(0) - visited.append(current_commit.hexsha[:build_len]) - for parent in current_commit.parents: - if parent.hexsha[:build_len] not in visited: - unvisited.append(parent) - - return visited - - -def get_commit_name(hexsha8: str): - """Helper function to find a human-readable name for a commit if possible.""" - if repo is None: + def get_all_parent_hexsha8s(self, commit: git.Commit) -> Sequence[str]: + """Helper method to recursively get hexsha8 values for all parents of a commit.""" + unvisited = [commit] + visited = [] + + while unvisited: + current_commit = unvisited.pop(0) + visited.append(current_commit.hexsha[:self.build_len]) + for parent in current_commit.parents: + if parent.hexsha[:self.build_len] not in visited: + unvisited.append(parent) + + return visited + + def get_commit_name(self, hexsha8: str) -> str: + """Helper method to find a human-readable name for a commit if possible.""" + if self.repo is None: + return hexsha8 + for h in self.repo.heads: + if h.commit.hexsha[:self.build_len] == hexsha8: + return h.name + for t in self.repo.tags: + if t.commit.hexsha[:self.build_len] == hexsha8: + return t.name return hexsha8 - for h in repo.heads: - if h.commit.hexsha[:build_len] == hexsha8: - return h.name - for t in repo.tags: - if t.commit.hexsha[:build_len] == hexsha8: - return t.name - return hexsha8 - - -def get_commit_hexsha8(name: str): - """Helper function to search for a commit given a human-readable name.""" - if repo is None: + + def get_commit_hexsha8(self, name: str) -> Optional[str]: + """Helper method to search for a commit given a human-readable name.""" + if self.repo is None: + return None + for h in self.repo.heads: + if h.name == name: + return h.commit.hexsha[:self.build_len] + for t in self.repo.tags: + if t.name == name: + return t.commit.hexsha[:self.build_len] + for c in self.repo.iter_commits("--all"): + if c.hexsha[:self.build_len] == name[:self.build_len]: + return c.hexsha[:self.build_len] return None - for h in repo.heads: - if h.name == name: - return h.commit.hexsha[:build_len] - for t in repo.tags: - if t.name == name: - return t.commit.hexsha[:build_len] - for c in repo.iter_commits("--all"): - if c.hexsha[:build_len] == name[:build_len]: - return c.hexsha[:build_len] - return None + + def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]: + """Helper method that gets rows of (build_commit, test_time) sorted by the latter.""" + return [] + + def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]: + """ + Helper method that gets table rows for some list of properties. + Rows are created by combining those where all provided properties are equal. + The resulting rows are then grouped by the provided properties and the t/s values are averaged. + The returned rows are unique in terms of property combinations. + """ + return [] + + +class LlamaBenchDataSQLite3(LlamaBenchData): + connection: sqlite3.Connection + cursor: sqlite3.Cursor + + def __init__(self): + super().__init__() + self.connection = sqlite3.connect(":memory:") + self.cursor = self.connection.cursor() + self.cursor.execute(f"CREATE TABLE test({', '.join(' '.join(x) for x in zip(DB_FIELDS, DB_TYPES))});") + + def _builds_init(self): + if self.connection: + self.build_len_min = self.cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0] + self.build_len_max = self.cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0] + + if self.build_len_min != self.build_len_max: + logger.warning("Data contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. " + "Try purging the the database of old commits.") + self.cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {self.build_len_min});") + + builds = self.cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall() + self.builds = list(map(lambda b: b[0], builds)) # list[tuple[str]] -> list[str] + super()._builds_init() + + def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]: + data = self.cursor.execute( + "SELECT build_commit, test_time FROM test ORDER BY test_time;").fetchall() + return reversed(data) if reverse else data + + def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]: + select_string = ", ".join( + [f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"]) + equal_string = " AND ".join( + [f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + [ + f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"] + ) + group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"]) + query = (f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} " + f"GROUP BY {group_order_string} ORDER BY {group_order_string};") + return self.cursor.execute(query).fetchall() + + +class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3): + def __init__(self, data_file: str): + super().__init__() + + self.connection.close() + self.connection = sqlite3.connect(data_file) + self.cursor = self.connection.cursor() + self._builds_init() + + @staticmethod + def valid_format(data_file: str) -> bool: + connection = sqlite3.connect(data_file) + cursor = connection.cursor() + + try: + if cursor.execute("PRAGMA schema_version;").fetchone()[0] == 0: + raise sqlite3.DatabaseError("The provided input file does not exist or is empty.") + except sqlite3.DatabaseError as e: + logger.debug(f'"{data_file}" is not a valid SQLite3 file.', exc_info=e) + cursor = None + + connection.close() + return True if cursor else False + + +class LlamaBenchDataJSONL(LlamaBenchDataSQLite3): + def __init__(self, data_file: str): + super().__init__() + + with open(data_file, "r", encoding="utf-8") as fp: + for i, line in enumerate(fp): + parsed = json.loads(line) + + for k in parsed.keys() - set(DB_FIELDS): + del parsed[k] + + if (missing_keys := self._check_keys(parsed.keys())): + raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}") + + self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values())) + + self._builds_init() + + @staticmethod + def valid_format(data_file: str) -> bool: + try: + with open(data_file, "r", encoding="utf-8") as fp: + for line in fp: + json.loads(line) + break + except Exception as e: + logger.debug(f'"{data_file}" is not a valid JSONL file.', exc_info=e) + return False + + return True + + +class LlamaBenchDataJSON(LlamaBenchDataSQLite3): + def __init__(self, data_files: list[str]): + super().__init__() + + for data_file in data_files: + with open(data_file, "r", encoding="utf-8") as fp: + parsed = json.load(fp) + + for i, entry in enumerate(parsed): + for k in entry.keys() - set(DB_FIELDS): + del entry[k] + + if (missing_keys := self._check_keys(entry.keys())): + raise RuntimeError(f"Missing required data key(s) at entry {i + 1}: {', '.join(missing_keys)}") + + self.cursor.execute(f"INSERT INTO test({', '.join(entry.keys())}) VALUES({', '.join('?' * len(entry))});", tuple(entry.values())) + + self._builds_init() + + @staticmethod + def valid_format(data_files: list[str]) -> bool: + if not data_files: + return False + + for data_file in data_files: + try: + with open(data_file, "r", encoding="utf-8") as fp: + json.load(fp) + except Exception as e: + logger.debug(f'"{data_file}" is not a valid JSON file.', exc_info=e) + return False + + return True + + +class LlamaBenchDataCSV(LlamaBenchDataSQLite3): + def __init__(self, data_files: list[str]): + super().__init__() + + for data_file in data_files: + with open(data_file, "r", encoding="utf-8") as fp: + for i, parsed in enumerate(csv.DictReader(fp)): + keys = set(parsed.keys()) + + for k in keys - set(DB_FIELDS): + del parsed[k] + + if (missing_keys := self._check_keys(keys)): + raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}") + + self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values())) + + self._builds_init() + + @staticmethod + def valid_format(data_files: list[str]) -> bool: + if not data_files: + return False + + for data_file in data_files: + try: + with open(data_file, "r", encoding="utf-8") as fp: + for parsed in csv.DictReader(fp): + break + except Exception as e: + logger.debug(f'"{data_file}" is not a valid CSV file.', exc_info=e) + return False + + return True + + +bench_data = None +if len(input_file) == 1: + if LlamaBenchDataSQLite3File.valid_format(input_file[0]): + bench_data = LlamaBenchDataSQLite3File(input_file[0]) + elif LlamaBenchDataJSON.valid_format(input_file): + bench_data = LlamaBenchDataJSON(input_file) + elif LlamaBenchDataJSONL.valid_format(input_file[0]): + bench_data = LlamaBenchDataJSONL(input_file[0]) + elif LlamaBenchDataCSV.valid_format(input_file): + bench_data = LlamaBenchDataCSV(input_file) +else: + if LlamaBenchDataJSON.valid_format(input_file): + bench_data = LlamaBenchDataJSON(input_file) + elif LlamaBenchDataCSV.valid_format(input_file): + bench_data = LlamaBenchDataCSV(input_file) + +if not bench_data: + raise RuntimeError("No valid (or some invalid) input files found.") + +if not bench_data.builds: + raise RuntimeError(f"{input_file} does not contain any builds.") hexsha8_baseline = name_baseline = None # If the user specified a baseline, try to find a commit for it: if known_args.baseline is not None: - if known_args.baseline in builds: + if known_args.baseline in bench_data.builds: hexsha8_baseline = known_args.baseline if hexsha8_baseline is None: - hexsha8_baseline = get_commit_hexsha8(known_args.baseline) + hexsha8_baseline = bench_data.get_commit_hexsha8(known_args.baseline) name_baseline = known_args.baseline if hexsha8_baseline is None: logger.error(f"cannot find data for baseline={known_args.baseline}.") sys.exit(1) # Otherwise, search for the most recent parent of master for which there is data: -elif repo is not None: - hexsha8_baseline = find_parent_in_data(repo.heads.master.commit) +elif bench_data.repo is not None: + hexsha8_baseline = bench_data.find_parent_in_data(bench_data.repo.heads.master.commit) if hexsha8_baseline is None: logger.error("No baseline was provided and did not find data for any master branch commits.\n") @@ -235,27 +464,25 @@ def get_commit_hexsha8(name: str): sys.exit(1) -name_baseline = get_commit_name(hexsha8_baseline) +name_baseline = bench_data.get_commit_name(hexsha8_baseline) hexsha8_compare = name_compare = None # If the user has specified a compare value, try to find a corresponding commit: if known_args.compare is not None: - if known_args.compare in builds: + if known_args.compare in bench_data.builds: hexsha8_compare = known_args.compare if hexsha8_compare is None: - hexsha8_compare = get_commit_hexsha8(known_args.compare) + hexsha8_compare = bench_data.get_commit_hexsha8(known_args.compare) name_compare = known_args.compare if hexsha8_compare is None: logger.error(f"cannot find data for compare={known_args.compare}.") sys.exit(1) # Otherwise, search for the commit for llama-bench was most recently run # and that is not a parent of master: -elif repo is not None: - hexsha8s_master = get_all_parent_hexsha8s(repo.heads.master.commit) - builds_timestamp = cursor.execute( - "SELECT build_commit, test_time FROM test ORDER BY test_time;").fetchall() - for (hexsha8, _) in reversed(builds_timestamp): +elif bench_data.repo is not None: + hexsha8s_master = bench_data.get_all_parent_hexsha8s(bench_data.repo.heads.master.commit) + for (hexsha8, _) in bench_data.builds_timestamp(reverse=True): if hexsha8 not in hexsha8s_master: hexsha8_compare = hexsha8 break @@ -270,26 +497,7 @@ def get_commit_hexsha8(name: str): parser.print_help() sys.exit(1) -name_compare = get_commit_name(hexsha8_compare) - - -def get_rows(properties): - """ - Helper function that gets table rows for some list of properties. - Rows are created by combining those where all provided properties are equal. - The resulting rows are then grouped by the provided properties and the t/s values are averaged. - The returned rows are unique in terms of property combinations. - """ - select_string = ", ".join( - [f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"]) - equal_string = " AND ".join( - [f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + [ - f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"] - ) - group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"]) - query = (f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} " - f"GROUP BY {group_order_string} ORDER BY {group_order_string};") - return cursor.execute(query).fetchall() +name_compare = bench_data.get_commit_name(hexsha8_compare) # If the user provided columns to group the results by, use them: @@ -303,10 +511,10 @@ def get_rows(properties): logger.error(f"Unknown values for --show: {', '.join(unknown_cols)}") parser.print_usage() sys.exit(1) - rows_show = get_rows(show) + rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare) # Otherwise, select those columns where the values are not all the same: else: - rows_full = get_rows(KEY_PROPERTIES) + rows_full = bench_data.get_rows(KEY_PROPERTIES, hexsha8_baseline, hexsha8_compare) properties_different = [] for i, kp_i in enumerate(KEY_PROPERTIES): if kp_i in DEFAULT_SHOW or kp_i in ["n_prompt", "n_gen", "n_depth"]: @@ -336,7 +544,7 @@ def get_rows(properties): show.remove(prop) except ValueError: pass - rows_show = get_rows(show) + rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare) if not rows_show: logger.error(f"No comparable data was found between {name_baseline} and {name_compare}.\n")