diff --git a/examples/inference/distilsplade_max/beir_scifact/all_in_one.sh b/examples/inference/distilsplade_max/beir_scifact/all_in_one.sh
index aaf06ff..29594b5 100644
--- a/examples/inference/distilsplade_max/beir_scifact/all_in_one.sh
+++ b/examples/inference/distilsplade_max/beir_scifact/all_in_one.sh
@@ -22,42 +22,110 @@ nohup python -m sparse_retrieval.inference.aio \
 # {
 #     "nDCG": {
 #         "NDCG@1": 0.60333,
+#         "NDCG@2": 0.63895,
 #         "NDCG@3": 0.65969,
 #         "NDCG@5": 0.67204,
 #         "NDCG@10": 0.6925,
+#         "NDCG@20": 0.70403,
 #         "NDCG@100": 0.7202,
 #         "NDCG@1000": 0.72753
 #     },
 #     "MAP": {
 #         "MAP@1": 0.57217,
+#         "MAP@2": 0.61522,
 #         "MAP@3": 0.63391,
 #         "MAP@5": 0.64403,
 #         "MAP@10": 0.65444,
+#         "MAP@20": 0.65846,
 #         "MAP@100": 0.66071,
 #         "MAP@1000": 0.66096
 #     },
 #     "Recall": {
 #         "Recall@1": 0.57217,
+#         "Recall@2": 0.65078,
 #         "Recall@3": 0.70172,
 #         "Recall@5": 0.73461,
 #         "Recall@10": 0.79122,
+#         "Recall@20": 0.833,
 #         "Recall@100": 0.92033,
 #         "Recall@1000": 0.98
 #     },
 #     "Precision": {
 #         "P@1": 0.60333,
+#         "P@2": 0.35167,
 #         "P@3": 0.25444,
 #         "P@5": 0.16267,
 #         "P@10": 0.08967,
+#         "P@20": 0.0475,
 #         "P@100": 0.01043,
 #         "P@1000": 0.00111
 #     },
 #     "mrr": {
 #         "MRR@1": 0.60333,
+#         "MRR@2": 0.64167,
 #         "MRR@3": 0.65722,
 #         "MRR@5": 0.66306,
 #         "MRR@10": 0.67052,
+#         "MRR@20": 0.67304,
 #         "MRR@100": 0.67503,
 #         "MRR@1000": 0.67524
-#     }
-# }
+#     },
+#     "latency": {
+#         "latency_avg": 0.0632176259594659,
+#         "query_word_length_avg": 13.85,
+#         "binned": {
+#             "word_length_bins": [
+#                 5.0,
+#                 7.6,
+#                 10.2,
+#                 12.8,
+#                 15.4,
+#                 18.0,
+#                 20.6,
+#                 23.2,
+#                 25.8,
+#                 28.400000000000002,
+#                 31.0
+#             ],
+#             "freqs": [
+#                 21,
+#                 69,
+#                 53,
+#                 68,
+#                 24,
+#                 22,
+#                 22,
+#                 9,
+#                 6,
+#                 6
+#             ],
+#             "latencies_avg": [
+#                 0.06157973294677536,
+#                 0.062426611955935025,
+#                 0.06298324124358223,
+#                 0.0637137847172033,
+#                 0.06537958721596762,
+#                 0.06230686320131933,
+#                 0.06198018672587427,
+#                 0.06480219447926497,
+#                 0.06900245506425784,
+#                 0.06779847725027305
+#             ],
+#             "latencies_std": [
+#                 0.007268265966041692,
+#                 0.00695837791999461,
+#                 0.007156900485436917,
+#                 0.007058682842506954,
+#                 0.00628885084604788,
+#                 0.007282331879014841,
+#                 0.007265829058465847,
+#                 0.005582415388678566,
+#                 0.0017274655407518776,
+#                 3.0068899748332633e-05
+#             ]
+#         },
+#         "batch_size": 61.06666666666667,
+#         "processor": " Intel(R) Xeon(R) Platinum 8168 CPU @ 2.70GHz"
+#     },
+#     "index_size": "3.87MB"
+# }
\ No newline at end of file
diff --git a/sparse_retrieval/inference/aio.py b/sparse_retrieval/inference/aio.py
index 6baff2a..65dcd05 100644
--- a/sparse_retrieval/inference/aio.py
+++ b/sparse_retrieval/inference/aio.py
@@ -37,6 +37,7 @@ def run(
     output_format_search: str = 'trec',
 
     # evaluate
+    bins: int = 10,
     k_values: List[int] = [1,2,3,5,10,20,100,1000],
 
     # default setting
@@ -104,13 +105,15 @@ def run(
     # 5. Search the queries over the index
     # The output will be ${output_dir}-quantized/${output_format_search}-format/run.tsv
     output_path_search = os.path.join(output_dir, f'{output_format_search}-format/run.tsv')
-    if not os.path.exists(output_path_search):
+    output_path_latency = os.path.join(output_dir, f'{output_format_search}-format/latency.tsv')
+    if not all([os.path.exists(output_path_search), os.path.exists(output_path_latency)]):
         search.run(
             topics=tsv_queries_path,
             encoder_name=encoder_name,
             ckpt_name=query_ckpt,
             index=output_dir_index,
             output=output_path_search,
+            output_latency=output_path_latency,
             impact=True,
             hits=hits+1,
             batch_size=batch_size,
@@ -126,7 +129,16 @@ def run(
     qrels_path = os.path.join(eval_data_dir, 'qrels', f'{topic_split}.tsv')
     output_dir_evaluate = os.path.join(output_dir, 'evaluation')
     if not os.path.exists(output_dir_evaluate):
-        evaluate.run(output_path_search, output_format_search, qrels_path, output_dir_evaluate, k_values)
+        evaluate.run(
+            result_path=output_path_search,
+            latency_path=output_path_latency,
+            index_path=output_dir_index,
+            format=output_format_search,
+            qrels_path=qrels_path,
+            output_dir=output_dir_evaluate,
+            bins=bins,
+            k_values=k_values
+        )
     else:
         print('Escaped evaluation due to the existing output file(s)')
 
@@ -158,6 +170,7 @@ def run(
     parser.add_argument('--hits', type=int, default=1000)
     parser.add_argument('--output_format_search', type=str, default='trec',
                         choices=['msmarco', 'trec'])
+    parser.add_argument('--bins', type=int, default=10, help="Number of word-length bins used for the binned latency report.")
    parser.add_argument('--k_values', nargs='+', type=int, default=[1,2,3,5,10,20,100,1000])
     parser.add_argument('--batch_size', type=int, default=64)
 
diff --git a/sparse_retrieval/inference/evaluate.py b/sparse_retrieval/inference/evaluate.py
index edf08fb..1905b79 100644
--- a/sparse_retrieval/inference/evaluate.py
+++ b/sparse_retrieval/inference/evaluate.py
@@ -2,8 +2,10 @@
 import csv
 import json
 import os
-from typing import Dict
-from .utils import load_qrels
+from typing import Dict, List
+
+import numpy as np
+from .utils import load_qrels, get_processor_name, get_folder_size, bin_and_average, bin_and_std
 
 from beir.retrieval.evaluation import EvaluateRetrieval
 
@@ -33,7 +35,7 @@ def load_results(result_path, format) -> Dict[str, Dict[str, float]]:
     return results
 
 
-def run(result_path, format, qrels_path, output_dir, k_values=[1,3,5,10,100,1000]):
+def run(result_path, latency_path: str, index_path: str, format, qrels_path, output_dir, bins: int=10, k_values=[1,3,5,10,100,1000]):
     results = load_results(result_path, format)
     qrels = load_qrels(qrels_path)
     evaluator = EvaluateRetrieval()
@@ -47,6 +49,42 @@ def run(result_path, format, qrels_path, output_dir, k_values=[1,3,5,10,100,1000
     ndcg, _map, recall, precision = evaluator.evaluate(qrels, results, k_values)
     mrr = EvaluateRetrieval.evaluate_custom(qrels, results, k_values, metric='mrr')
 
+    # Get latency info:
+    latencies: List[float] = []
+    word_lengths: List[int] = []
+    batch_sizes: List[int] = []
+    with open(latency_path) as f:
+        for line in f:
+            qid, word_length, latency, batch_size = line.strip().split("\t")
+            latencies.append(float(latency))
+            word_lengths.append(int(word_length))
+            batch_sizes.append(int(batch_size))
+    freqs, word_length_bins = np.histogram(word_lengths, bins=bins)
+    binned_latencies_avg = bin_and_average(keys=word_lengths, values=latencies, numpy_bins=word_length_bins)
+    binned_latencies_std = bin_and_std(keys=word_lengths, values=latencies, numpy_bins=word_length_bins)
+    latency_info = {
+        "latency": {
+            "latency_avg": np.mean(latencies),
+            "latency_std": np.std(latencies),
+            "query_word_length_avg": np.mean(word_lengths),
+            "binned": {
+                "word_length_bins": word_length_bins.tolist(),
+                "freqs": freqs.tolist(),
+                "latencies_avg": binned_latencies_avg,
+                "latencies_std": binned_latencies_std
+            },
+            "batch_size": np.mean(batch_sizes),
+            "processor": get_processor_name()
+        }
+    }
+
+    # Get index info:
+    index_size = get_folder_size(index_path)
+    index_info = {
+        "index_size": index_size
+    }
+
+    # Get evaluation scores and save all results:
     os.makedirs(output_dir, exist_ok=True)
     with open(os.path.join(output_dir, 'metrics.json'), 'w') as f:
         metrics = {
@@ -56,15 +94,20 @@ def run(result_path, format, qrels_path, output_dir, k_values=[1,3,5,10,100,1000
             'Precision': precision,
             'mrr': mrr
         }
+        metrics.update(latency_info)
+        metrics.update(index_info)
         json.dump(metrics, f, indent=4)
     print(f'{__name__}: Done')
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--result_path')
+    parser.add_argument('--latency_path')
+    parser.add_argument('--index_path')
     parser.add_argument('--format', choices=['msmarco', 'trec'], help='Format of the retrieval result. The formats are from pyserini.output_writer.py')
     parser.add_argument('--qrels_path', help='Path to the BeIR-format file')
     parser.add_argument('--output_dir')
+    parser.add_argument('--bins', type=int, default=10, help="Number of word-length bins used for the binned latency report.")
     parser.add_argument('--k_values', nargs='+', type=int, default=[1,2,3,5,10,20,100,1000])
     args = parser.parse_args()
     run(**vars(args))
\ No newline at end of file
diff --git a/sparse_retrieval/inference/search.py b/sparse_retrieval/inference/search.py
index d2a31d9..f11bba1 100644
--- a/sparse_retrieval/inference/search.py
+++ b/sparse_retrieval/inference/search.py
@@ -17,12 +17,14 @@
 #
 
 import argparse
+from contextlib import contextmanager
 import inspect
 import os
+import time
 
 from tqdm import tqdm
 from transformers import AutoTokenizer
-from typing import List
+from typing import Dict, List
 
 from pyserini.analysis import JDefaultEnglishAnalyzer, JWhiteSpaceAnalyzer
 from pyserini.output_writer import OutputFormat, get_output_writer
@@ -33,6 +35,8 @@
 from pyserini.search.lucene.reranker import ClassifierType, PseudoRelevanceClassifierReranker
 from . import encoder_builders
+from nltk import word_tokenize
+
 
 
 def set_bm25_parameters(searcher, index, k1=None, b=None):
     if k1 is not None or b is not None:
@@ -69,10 +73,46 @@ def __init__(self, *args, **kwargs):
         self.__dict__ = self
 
 
+class LatencyReporter:
+
+    def __init__(self) -> None:
+        self.qids: List[str] = []
+        self.latencies: List[float] = []
+        self.word_lengths: List[int] = []
+        self.batch_sizes: List[int] = []
+
+    def record(self, qid: str, text: str, latency: float) -> None:
+        """Record one single query latency (single threaded)."""
+        self.qids.append(qid)
+        self.latencies.append(latency)
+        self.word_lengths.append(len(word_tokenize(text)))
+        self.batch_sizes.append(1)
+
+    def record_batch(self, qids: List[str], texts: List[str], latency: float) -> None:
+        """Record a batch of queries, where the latency will be averaged."""
+        self.qids.extend(qids)
+        self.latencies.extend([latency / len(qids)] * len(qids))
+        self.word_lengths.extend(map(lambda text: len(word_tokenize(text)), texts))
+        self.batch_sizes.extend([len(qids)] * len(qids))
+
+    def report(self, output_path: str) -> None:
+        """Report the latency details into the `output_path`."""
+        with open(output_path, "w") as f:
+            for qid, word_length, latency, batch_size in zip(self.qids, self.word_lengths, self.latencies, self.batch_sizes):
+                f.write(f"{qid}\t{word_length}\t{latency}\t{batch_size}\n")
+
+    @staticmethod
+    @contextmanager
+    def timer() -> float:
+        """Timer context manager. Reference: https://stackoverflow.com/a/62956469/16409125"""
+        start = time.perf_counter()
+        yield lambda: time.perf_counter() - start
+
 def run(
     topics: str,
     index: str,
     output: str,
+    output_latency: str,
     topics_format: str = TopicsFormat.DEFAULT.value,
     output_format: str = OutputFormat.TREC.value,
     max_passage: bool = False,
@@ -205,35 +245,42 @@ def run(
                                       max_passage_delimiter=args.max_passage_delimiter,
                                       max_passage_hits=args.max_passage_hits)
 
+    latency_reporter = LatencyReporter()
     with output_writer:
         batch_topics = list()
         batch_topic_ids = list()
-        for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=len(topics.keys()))):
+        for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=len(topics.keys()), desc="Doing search")):
             if (args.tokenizer != None):
                 toks = tokenizer.tokenize(text)
                 text = ' '
                 text = text.join(toks)
             if args.batch_size <= 1 and args.threads <= 1:
                 if args.impact:
-                    hits = searcher.search(text, args.hits, fields=fields)
+                    with LatencyReporter.timer() as timer:
+                        hits = searcher.search(text, args.hits, fields=fields)
                 else:
-                    hits = searcher.search(text, args.hits, query_generator=query_generator, fields=fields)
+                    with LatencyReporter.timer() as timer:
+                        hits = searcher.search(text, args.hits, query_generator=query_generator, fields=fields)
                 results = [(topic_id, hits)]
+                latency_reporter.record(qid=topic_id, text=text, latency=timer())
             else:
                 batch_topic_ids.append(str(topic_id))
                 batch_topics.append(text)
                 if (index + 1) % args.batch_size == 0 or \
                         index == len(topics.keys()) - 1:
                     if args.impact:
-                        results = searcher.batch_search(
-                            batch_topics, batch_topic_ids, args.hits, args.threads, fields=fields
-                        )
+                        with LatencyReporter.timer() as timer:
+                            results = searcher.batch_search(
+                                batch_topics, batch_topic_ids, args.hits, args.threads, fields=fields
+                            )
                     else:
-                        results = searcher.batch_search(
-                            batch_topics, batch_topic_ids, args.hits, args.threads,
-                            query_generator=query_generator, fields=fields
-                        )
+                        with LatencyReporter.timer() as timer:
+                            results = searcher.batch_search(
+                                batch_topics, batch_topic_ids, args.hits, args.threads,
+                                query_generator=query_generator, fields=fields
+                            )
                     results = [(id_, results[id_]) for id_ in batch_topic_ids]
+                    latency_reporter.record_batch(qids=batch_topic_ids, texts=batch_topics, latency=timer())
                     batch_topic_ids.clear()
                     batch_topics.clear()
                 else:
@@ -263,6 +310,7 @@ def run(
             output_writer.write(topic, hits)
         results.clear()
 
+    latency_reporter.report(args.output_latency)
     print(f'{__name__}: Done')
 
 def define_search_args(parser):
@@ -318,6 +366,8 @@ def define_search_args(parser):
                         help=f"Format of output. Available: {[x.value for x in list(OutputFormat)]}")
     parser.add_argument('--output', type=str, metavar='path',
                         help="Path to output file.")
+    parser.add_argument('--output-latency', type=str, metavar='path',
+                        help="Path to latency-output file.")
     parser.add_argument('--max-passage', action='store_true',
                         default=False, help="Select only max passage from document.")
     parser.add_argument('--max-passage-hits', type=int, metavar='num', required=False, default=100,
diff --git a/sparse_retrieval/inference/utils.py b/sparse_retrieval/inference/utils.py
index eb12397..1484569 100644
--- a/sparse_retrieval/inference/utils.py
+++ b/sparse_retrieval/inference/utils.py
@@ -1,6 +1,12 @@
 import csv
 import json
-from typing import Dict
+import os
+import platform
+import re
+import subprocess
+from typing import Dict, List
+
+import numpy as np
 
 def load_qrels(qrels_path) -> Dict[str, Dict[str, float]]:
     # adapted from BeIR:
@@ -29,4 +35,49 @@ def load_queries(queries_path) -> Dict[str, str]:
             line = json.loads(line)
             queries[line.get("_id")] = line.get("text")
 
-    return queries
\ No newline at end of file
+    return queries
+
+
+def get_processor_name() -> str:
+    """Reference: https://stackoverflow.com/a/13078519/16409125."""
+    if platform.system() == "Windows":
+        return platform.processor()
+    elif platform.system() == "Darwin":
+        os.environ['PATH'] = os.environ['PATH'] + os.pathsep + '/usr/sbin'
+        command = "sysctl -n machdep.cpu.brand_string"
+        return subprocess.check_output(command, shell=True).decode().strip()
+    elif platform.system() == "Linux":
+        command = "cat /proc/cpuinfo"
+        all_info = subprocess.check_output(command, shell=True).decode().strip()
+        for line in all_info.split("\n"):
+            if "model name" in line:
+                return re.sub(".*model name.*:", "", line, 1)
+    return "Cannot get processor info."
+
+
+def get_folder_size(start_path: str) -> str:
+    """Reference: https://stackoverflow.com/a/1392549/16409125."""
+    total_size = 0
+    for dirpath, _, filenames in os.walk(start_path):
+        for f in filenames:
+            fp = os.path.join(dirpath, f)
+            # skip if it is symbolic link
+            if not os.path.islink(fp):
+                total_size += os.path.getsize(fp)
+    return f"{round(total_size / 1024 / 1024, 2)}MB"
+
+
+def bin_and_average(keys: List[float], values: List[float], numpy_bins: List[int]) -> List[float]:
+    """Reference: https://stackoverflow.com/a/6163403/16409125."""
+    digitized = np.digitize(keys, numpy_bins)
+    values = np.array(values)
+    bin_means = [values[digitized == i].mean() for i in range(1, len(numpy_bins))]
+    return bin_means
+
+
+def bin_and_std(keys: List[float], values: List[float], numpy_bins: List[int]) -> List[float]:
+    """Reference: https://stackoverflow.com/a/6163403/16409125."""
+    digitized = np.digitize(keys, numpy_bins)
+    values = np.array(values)
+    bin_stds = [values[digitized == i].std() for i in range(1, len(numpy_bins))]
+    return bin_stds
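
For reference, the "binned" block in the metrics.json example at the top of this patch is produced in two steps: search.py's new LatencyReporter writes one line per query (qid, word length, latency, batch size) to latency.tsv, and evaluate.run() then histograms the query word lengths with np.histogram and aggregates the per-query latencies per bin via the new bin_and_average / bin_and_std helpers. Below is a minimal standalone sketch of that aggregation step, assuming the patched sparse_retrieval package is importable; the sample word lengths and latencies are invented for illustration only.

import numpy as np

from sparse_retrieval.inference.utils import bin_and_average, bin_and_std

# Hypothetical per-query measurements (in the real pipeline these come from latency.tsv).
word_lengths = [5, 8, 11, 12, 19, 26, 31]                      # words per query
latencies = [0.061, 0.062, 0.064, 0.063, 0.066, 0.068, 0.069]  # seconds per query

# Same call chain as evaluate.run(): split the word-length range into equal-width bins,
# then average (and take the std of) the latencies falling into each bin.
freqs, word_length_bins = np.histogram(word_lengths, bins=3)
latencies_avg = bin_and_average(keys=word_lengths, values=latencies, numpy_bins=word_length_bins)
latencies_std = bin_and_std(keys=word_lengths, values=latencies, numpy_bins=word_length_bins)

print(word_length_bins.tolist())  # -> "word_length_bins" in metrics.json
print(freqs.tolist())             # -> "freqs"
print(latencies_avg)              # -> "latencies_avg"
print(latencies_std)              # -> "latencies_std"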