diff --git a/fast_bm25.py b/fast_bm25.py
index f571806..95a7714 100644
--- a/fast_bm25.py
+++ b/fast_bm25.py
@@ -1,10 +1,11 @@
-import json
+import collections
+import heapq
 import math
+import pickle
 import sys
-import heapq
 
 PARAM_K1 = 1.5
 PARAM_B = 0.75
-IDF_CUTOFF = 3
+IDF_CUTOFF = 4
 
 
 class BM25:
@@ -21,7 +22,6 @@ class BM25:
     avgdl : float
         Average length of document in `corpus`.
     """
-
     def __init__(self, corpus, k1=PARAM_K1, b=PARAM_B, alpha=IDF_CUTOFF):
         """
         Parameters
@@ -42,7 +42,6 @@ def __init__(self, corpus, k1=PARAM_K1, b=PARAM_B, alpha=IDF_CUTOFF):
             IDF cutoff, terms with a lower idf score than alpha will be dropped.
             A higher alpha will lower the accuracy of BM25 but increase performance
         """
-
         self.k1 = k1
         self.b = b
         self.alpha = alpha
@@ -93,31 +92,6 @@ def _initialize(self, corpus):
                 file=sys.stderr
             )
 
-    def get_score(self, query, index):
-        """Computes BM25 score of given `document` in relation to item of corpus selected by `index`.
-
-        Parameters
-        ----------
-        query : list of str
-            The tokenized query to score.
-        index : int
-            Index of document in corpus selected to score with `query`.
-
-        Returns
-        -------
-        float
-            BM25 score.
-        """
-        score = 0.0
-        numerator_constant = self.k1 + 1
-        denominator_constant = self.k1 * (1 - self.b + self.b * self.doc_len[index] / self.avgdl)
-        for token in query:
-            if token in self.t2d and index in self.t2d[token]:
-                df = self.t2d[token][index]
-                idf = self.idf[token]
-                score += (idf * df * numerator_constant) / (df + denominator_constant)
-        return score
-
     def get_top_n(self, query, documents, n=5):
         """
         Retrieve the top n documents for the query.
@@ -137,35 +111,20 @@ def get_top_n(self, query, documents, n=5):
             The top n documents
         """
         assert self.corpus_size == len(documents), "The documents given don't match the index corpus!"
-        indexes = set(
-            i
-            for token in query
-            if token in self.t2d
-            for i in self.t2d[token].keys()
-        )
-        return [documents[i] for i in heapq.nlargest(n, indexes, key=lambda idx: self.get_score(query, idx))]
+        scores = collections.defaultdict(float)
+        for token in query:
+            if token in self.t2d:
+                for index, freq in self.t2d[token].items():
+                    denom_cst = self.k1 * (1 - self.b + self.b * self.doc_len[index] / self.avgdl)
+                    scores[index] += self.idf[token]*freq*(self.k1 + 1)/(freq + denom_cst)
+
+        return [documents[i] for i in heapq.nlargest(n, scores.keys(), key=scores.__getitem__)]
 
     def save(self, filename):
-        json_object = {
-            "k1": self.k1, "b": self.b, "alpha": self.alpha, "avgdl": self.avgdl,
-            "t2d": self.t2d, "idf": self.idf, "doc_len": self.doc_len
-        }
-        with open(f"{filename}.json", "w") as fsave:
-            json.dump(json_object, fsave)
+        with open(f"{filename}.pkl", "wb") as fsave:
+            pickle.dump(self, fsave, protocol=pickle.HIGHEST_PROTOCOL)
 
     @staticmethod
     def load(filename):
-        with open(f"{filename}.json", "r") as fsave:
-            json_object = json.load(fsave)
-        # we have to do this terribleness because json does not have int as keys
-        for tok in json_object["t2d"]:
-            json_object["t2d"][tok] = {int(i): f for i, f in json_object["t2d"][tok].items()}
-        bm25 = BM25([])
-        bm25.k1 = json_object["k1"]
-        bm25.b = json_object["b"]
-        bm25.alpha = json_object["alpha"]
-        bm25.avgdl = json_object["avgdl"]
-        bm25.t2d = json_object["t2d"]
-        bm25.idf = json_object["idf"]
-        bm25.doc_len = json_object["doc_len"]
-        return bm25
+        with open(f"{filename}.pkl", "rb") as fsave:
+            return pickle.load(fsave)
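
In summary: scoring is folded into a single accumulation pass inside `get_top_n` (the standalone `get_score` is gone), and persistence switches from a hand-rolled JSON round-trip to pickling the whole object, which keeps the integer document ids in `t2d` intact. Below is a minimal usage sketch against the revised API; the toy corpus and the `bm25_index` filename are invented for illustration, and `alpha=0` is passed explicitly because on a three-document corpus no term's idf can plausibly reach the raised default cutoff of 4:

```python
from fast_bm25 import BM25

# Invented toy corpus; BM25 expects pre-tokenized documents.
corpus = [
    "the quick brown fox jumps over the lazy dog",
    "a lazy dog sleeps all day",
    "the fox hunts at night",
]
tokenized = [doc.split() for doc in corpus]

# alpha=0 effectively disables the IDF cutoff: with only three documents,
# every term would otherwise be pruned by the new default IDF_CUTOFF = 4.
bm25 = BM25(tokenized, alpha=0)

# Scoring and ranking now happen in one pass inside get_top_n.
print(bm25.get_top_n("lazy fox".split(), corpus, n=2))

# save/load now pickle the whole object, writing "bm25_index.pkl"
# where the old code wrote "bm25_index.json".
bm25.save("bm25_index")
restored = BM25.load("bm25_index")
```

The pickle switch also drops the int-key workaround the old `load` needed (JSON stringifies dict keys), at the usual pickle cost: `.pkl` files should only be loaded from trusted sources, and they stay coupled to the `BM25` class definition, whereas the old JSON dump was format-stable.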