Skip to content

Commit

Permalink
Update fast_bm25.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Inspirateur authored Oct 20, 2021
1 parent 649b459 commit 37b8349
Showing 1 changed file with 16 additions and 57 deletions.
73 changes: 16 additions & 57 deletions fast_bm25.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import json
import collections
import heapq
import math
import pickle
import sys
import heapq
PARAM_K1 = 1.5
PARAM_B = 0.75
IDF_CUTOFF = 3
IDF_CUTOFF = 4


class BM25:
Expand All @@ -21,7 +22,6 @@ class BM25:
avgdl : float
Average length of document in `corpus`.
"""

def __init__(self, corpus, k1=PARAM_K1, b=PARAM_B, alpha=IDF_CUTOFF):
"""
Parameters
Expand All @@ -42,7 +42,6 @@ def __init__(self, corpus, k1=PARAM_K1, b=PARAM_B, alpha=IDF_CUTOFF):
IDF cutoff, terms with a lower idf score than alpha will be dropped. A higher alpha will lower the accuracy
of BM25 but increase performance
"""

self.k1 = k1
self.b = b
self.alpha = alpha
Expand Down Expand Up @@ -93,31 +92,6 @@ def _initialize(self, corpus):
file=sys.stderr
)

def get_score(self, query, index):
"""Computes BM25 score of given `document` in relation to item of corpus selected by `index`.
Parameters
----------
query : list of str
The tokenized query to score.
index : int
Index of document in corpus selected to score with `query`.
Returns
-------
float
BM25 score.
"""
score = 0.0
numerator_constant = self.k1 + 1
denominator_constant = self.k1 * (1 - self.b + self.b * self.doc_len[index] / self.avgdl)
for token in query:
if token in self.t2d and index in self.t2d[token]:
df = self.t2d[token][index]
idf = self.idf[token]
score += (idf * df * numerator_constant) / (df + denominator_constant)
return score

def get_top_n(self, query, documents, n=5):
"""
Retrieve the top n documents for the query.
Expand All @@ -137,35 +111,20 @@ def get_top_n(self, query, documents, n=5):
The top n documents
"""
assert self.corpus_size == len(documents), "The documents given don't match the index corpus!"
indexes = set(
i
for token in query
if token in self.t2d
for i in self.t2d[token].keys()
)
return [documents[i] for i in heapq.nlargest(n, indexes, key=lambda idx: self.get_score(query, idx))]
scores = collections.defaultdict(float)
for token in query:
if token in self.t2d:
for index, freq in self.t2d[token].items():
denom_cst = self.k1 * (1 - self.b + self.b * self.doc_len[index] / self.avgdl)
scores[index] += self.idf[token]*freq*(self.k1 + 1)/(freq + denom_cst)

return [documents[i] for i in heapq.nlargest(n, scores.keys(), key=scores.__getitem__)]

def save(self, filename):
json_object = {
"k1": self.k1, "b": self.b, "alpha": self.alpha, "avgdl": self.avgdl,
"t2d": self.t2d, "idf": self.idf, "doc_len": self.doc_len
}
with open(f"{filename}.json", "w") as fsave:
json.dump(json_object, fsave)
with open(f"{filename}.pkl", "wb") as fsave:
pickle.dump(self, fsave, protocol=pickle.HIGHEST_PROTOCOL)

@staticmethod
def load(filename):
with open(f"{filename}.json", "r") as fsave:
json_object = json.load(fsave)
# we have to do this terribleness because json does not have int as keys
for tok in json_object["t2d"]:
json_object["t2d"][tok] = {int(i): f for i, f in json_object["t2d"][tok].items()}
bm25 = BM25([])
bm25.k1 = json_object["k1"]
bm25.b = json_object["b"]
bm25.alpha = json_object["alpha"]
bm25.avgdl = json_object["avgdl"]
bm25.t2d = json_object["t2d"]
bm25.idf = json_object["idf"]
bm25.doc_len = json_object["doc_len"]
return bm25
with open(f"{filename}.pkl", "rb") as fsave:
return pickle.load(fsave)

0 comments on commit 37b8349

Please sign in to comment.