Added retrieval aio tests #12

Open · wants to merge 3 commits into main
3 changes: 2 additions & 1 deletion .gitignore
@@ -144,4 +144,5 @@ dmypy.json
 # Pyre type checker
 .pyre/

-*.tar.gz
\ No newline at end of file
+*.tar.gz
+*.zip
28 changes: 24 additions & 4 deletions environment.yml
@@ -40,13 +43,16 @@ dependencies:
   - libapr=1.7.0=hf178f73_5
   - libapriconv=1.2.2=h7f8727e_5
   - libaprutil=1.6.1=hfefca11_5
+  - libblas=3.9.0=12_linux64_mkl
+  - libcblas=3.9.0=12_linux64_mkl
   - libdb=6.2.32=hf484d3e_0
   - libedit=3.1.20210910=h7f8727e_0
   - libffi=3.3=he6710b0_2
   - libgcc-ng=9.3.0=h5101ec6_17
   - libgomp=9.3.0=h5101ec6_17
   - libiconv=1.15=h63c8f33_5
   - libidn2=2.3.2=h7f8727e_0
+  - liblapack=3.9.0=12_linux64_mkl
   - libpng=1.6.37=hbc83047_0
   - libstdcxx-ng=9.3.0=hd4cf53a_17
   - libtasn1=4.16.0=h27cfd23_0
@@ -119,20 +122,25 @@ dependencies:
     - absl-py==1.0.0
     - astunparse==1.6.3
     - beir==0.2.3
+    - black==23.1.0
     - blis==0.7.5
     - cachetools==5.0.0
     - catalogue==2.0.6
     - charset-normalizer==2.0.10
     - click==8.0.3
+    - contourpy==1.0.7
     - crash-ipdb==0.0.3
+    - cycler==0.11.0
     - cymem==2.0.6
     - cython==0.29.26
     - datasets==1.1.3
     - dill==0.3.4
     - elasticsearch==7.16.3
+    - exceptiongroup==1.1.1
     - faiss-cpu==1.7.2
     - filelock==3.4.2
     - flatbuffers==2.0
+    - fonttools==4.38.0
     - gast==0.4.0
     - google-auth==2.5.0
     - google-auth-oauthlib==0.4.6
@@ -142,27 +150,36 @@ dependencies:
     - huggingface-hub==0.4.0
     - idna==3.3
     - importlib-metadata==4.10.1
+    - importlib-resources==5.10.2
+    - iniconfig==2.0.0
     - ipdb==0.13.9
     - jinja2==3.0.3
     - joblib==1.1.0
     - keras==2.7.0
     - keras-preprocessing==1.1.2
+    - kiwisolver==1.4.4
     - langcodes==3.3.0
     - libclang==12.0.0
     - lightgbm==3.3.2
     - markdown==3.3.6
     - markupsafe==2.0.1
+    - matplotlib==3.7.0
     - multiprocess==0.70.12.2
     - murmurhash==1.0.6
+    - mypy-extensions==1.0.0
     - nltk==3.6.7
     - nmslib==2.1.1
     - numpy==1.22.1
     - oauthlib==3.1.1
     - onnxruntime==1.10.0
     - opt-einsum==3.3.0
-    - packaging==21.3
+    - packaging==23.0
     - pandas==1.4.0
+    - pathspec==0.11.0
     - pathy==0.6.1
+    - patsy==0.5.3
+    - platformdirs==3.0.0
+    - pluggy==1.0.0
     - preshed==3.0.6
     - protobuf==3.19.3
     - psutil==5.9.0
@@ -173,7 +190,8 @@ dependencies:
     - pydantic==1.8.2
     - pyjnius==1.4.1
     - pyparsing==3.0.7
-    - pyserini==0.15.0
+    - pyserini==0.20.0
+    - pytest==7.3.0
    - python-dateutil==2.8.2
     - pytrec-eval==0.5
     - pytz==2021.3
@@ -192,6 +210,7 @@ dependencies:
     - spacy-legacy==3.0.8
     - spacy-loggers==1.0.1
     - srsly==2.4.2
+    - statsmodels==0.13.5
     - tensorboard==2.8.0
     - tensorboard-data-server==0.6.1
     - tensorboard-plugin-wit==1.8.1
@@ -205,8 +224,9 @@ dependencies:
     - threadpoolctl==3.0.0
     - tokenizers==0.10.3
     - toml==0.10.2
+    - tomli==2.0.1
     - torch-scatter==2.0.6
-    - tqdm==4.49.0
+    - tqdm==4.64.1
     - transformers==4.15.0
     - typer==0.4.0
     - urllib3==1.26.8
@@ -215,4 +235,4 @@
     - wrapt==1.13.3
     - xxhash==2.0.2
     - zipp==3.7.0
-prefix: /home/n3thakur/anaconda3/envs/sparse-retrieval
+prefix: /home/fb20user07/miniconda3/envs/sparse-retrieval
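Reviewer note: the prefix: line at the bottom is machine-specific and is ignored when an environment name is given, so the updated environment should be reproducible with the usual conda workflow (the environment name below is just an example):

conda env create -f environment.yml -n sparse-retrieval
conda activate sparse-retrieval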
58 changes: 58 additions & 0 deletions sample-data/build.py
@@ -0,0 +1,58 @@
+import json
+import shutil
+from beir import util, LoggingHandler
+from beir.retrieval import models
+from beir.datasets.data_loader import GenericDataLoader
+from beir.retrieval.evaluation import EvaluateRetrieval
+from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
+
+import logging
+import pathlib, os
+import random
+
+#### Just some code to print debug information to stdout
+logging.basicConfig(
+    format="%(asctime)s - %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    level=logging.INFO,
+    handlers=[LoggingHandler()],
+)
+#### /print debug information to stdout
+
+#### Download scifact.zip dataset and unzip the dataset
+dataset = "scifact"
+url = (
+    "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(
+        dataset
+    )
+)
+out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), ".")
+shutil.rmtree(os.path.join(out_dir, "scifact"))
+data_path = util.download_and_unzip(url, out_dir)
+
+corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
+random_state = random.Random(42)
+corpus_sampled = dict(random_state.sample(list(corpus.items()), k=10))
+qrels_sampled = dict(random_state.sample(list(qrels.items()), k=3))
+for qid, rels in qrels_sampled.items():
+    for did in rels:
+        corpus_sampled[did] = corpus[did]
+queries_sampled = {qid: queries[qid] for qid, _ in qrels_sampled.items()}
+
+
+with open(os.path.join(data_path, "corpus.jsonl"), "w") as f:
+    for id, line in corpus_sampled.items():
+        line.update({"_id": id})
+        f.write(json.dumps(line) + "\n")
+
+with open(os.path.join(data_path, "queries.jsonl"), "w") as f:
+    for qid, text in queries_sampled.items():
+        f.write(json.dumps({"_id": qid, "text": text, "metadata": {}}) + "\n")
+
+with open(os.path.join(data_path, "qrels", "test.tsv"), "w") as f:
+    f.write("query-id\tcorpus-id\tscore\n")
+    for qid, rels in qrels_sampled.items():
+        for did, rel in rels.items():
+            f.write(f"{qid}\t{did}\t{rel}\n")
+
+os.remove(os.path.join(data_path, "qrels", "train.tsv"))
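Reviewer note: build.py regenerates the checked-in sample data. It downloads and unzips the full SciFact dataset from the BEIR mirror, samples 10 corpus documents and 3 qrels with a fixed seed (random.Random(42)), adds back every document referenced by the sampled qrels, and rewrites corpus.jsonl, queries.jsonl, and qrels/test.tsv in place, removing qrels/train.tsv. Assuming the environment above is active, rebuilding the sample is a single call:

python sample-data/build.py

Because the seed is fixed, the regenerated sample is deterministic, which is what keeps the data files below stable across regenerations.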
13 changes: 13 additions & 0 deletions sample-data/scifact/corpus.jsonl

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions sample-data/scifact/qrels/test.tsv
@@ -0,0 +1,4 @@
+query-id	corpus-id	score
+1019	11603066	1
+75	4387784	1
+72	6076903	1
3 changes: 3 additions & 0 deletions sample-data/scifact/queries-test.tsv
@@ -0,0 +1,3 @@
+1019	Rapid phosphotransfer rates govern fidelity in two component systems
+75	Active H. pylori urease has a polymeric structure that compromises two subunits, UreA and UreB.
+72	Activator-inhibitor pairs are provided dorsally by Admpchordin.
3 changes: 3 additions & 0 deletions sample-data/scifact/queries.jsonl
@@ -0,0 +1,3 @@
{"_id": "1019", "text": "Rapid phosphotransfer rates govern fidelity in two component systems", "metadata": {}}
{"_id": "75", "text": "Active H. pylori urease has a polymeric structure that compromises two subunits, UreA and UreB.", "metadata": {}}
{"_id": "72", "text": "Activator-inhibitor pairs are provided dorsally by Admpchordin.", "metadata": {}}
Empty file added sprint/__init__.py
Empty file.
89 changes: 60 additions & 29 deletions sprint/inference/methods/sparta.py
@@ -8,7 +8,9 @@


 class SPARTADocumentEncoder(torch.nn.Module, DocumentEncoder):
-    def __init__(self, model_name, device): #SpanBERT/spanbert-base-cased'): #bert-base-uncased #distilbert-base-uncased #distilroberta-base
+    def __init__(
+        self, model_name, device
+    ):  # SpanBERT/spanbert-base-cased'): #bert-base-uncased #distilbert-base-uncased #distilroberta-base
         super().__init__()
         print("Model name:", model_name)
         self.bert_model = AutoModel.from_pretrained(model_name)
@@ -18,80 +20,109 @@ def __init__(self, model_name, device): #SpanBERT/spanbert-base-cased'): #bert-b
         self.device = device
         self.max_length = 300
         #####
-        self.bert_input_emb = self.bert_model.embeddings.word_embeddings(torch.tensor(list(range(0, len(self.tokenizer))), device=device)) # for building term weights
+        self.bert_input_emb = self.bert_model.embeddings.word_embeddings(
+            torch.tensor(list(range(0, len(self.tokenizer))), device=device)
+        )  # for building term weights
         self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()}
         self.special_token_embedding_to_zero = False  # used during inference

     def bert_embeddings(self, input_ids):
         return self.bert_model.embeddings.word_embeddings(input_ids)

     def query_embeddings(self, query):
-        queries_batch = self.tokenizer(query, padding=True, truncation=True, return_tensors='pt', add_special_tokens=False, max_length=self.max_length).to(self.device)
-        queries_embeddings = self.bert_embeddings(queries_batch['input_ids'])
+        queries_batch = self.tokenizer(
+            query,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
+            add_special_tokens=False,
+            max_length=self.max_length,
+        ).to(self.device)
+        queries_embeddings = self.bert_embeddings(queries_batch["input_ids"])
         return queries_embeddings

     def passage_embeddings(self, passages):
-        passage_batch = self.tokenizer(passages, padding=True, truncation=True, return_tensors='pt', max_length=self.max_length).to(self.device)
+        passage_batch = self.tokenizer(
+            passages,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
+            max_length=self.max_length,
+        ).to(self.device)
         passage_embeddings = self.bert_model(**passage_batch).last_hidden_state
         return passage_embeddings

     def compute_scores(self, query_embeddings, passage_embeddings):
         ### Eq. 4 - Term matching
         scores = []
-        for idx in range(len(query_embeddings)): #TODO: use correct pytorch function for this
-            scores.append(torch.matmul(query_embeddings[idx], passage_embeddings.transpose(1, 2)))
+        for idx in range(
+            len(query_embeddings)
+        ):  # TODO: use correct pytorch function for this
+            scores.append(
+                torch.matmul(query_embeddings[idx], passage_embeddings.transpose(1, 2))
+            )
         scores = torch.stack(scores)
-        #print("Scores:", scores.shape)
+        # print("Scores:", scores.shape)
         max_scores = torch.max(scores, dim=-1).values
-        #print("Max-Scores:", max_scores.shape)
+        # print("Max-Scores:", max_scores.shape)

         ### Eq. 5 - ReLu
-        relu_scores = torch.relu(max_scores) #torch.relu(max_scores + self.score_bias) #Bias score does not change that much?
-        #print("ReLu-Scores:", relu_scores.shape)
+        relu_scores = torch.relu(
+            max_scores
+        )  # torch.relu(max_scores + self.score_bias) #Bias score does not change that much?
+        # print("ReLu-Scores:", relu_scores.shape)

         ### Eq. 6 - Final Score
-        final_scores = torch.sum(torch.log(relu_scores + 1), dim=-1) #.unsqueeze(dim=0)
-        #print("Final scores:", final_scores.shape)
+        final_scores = torch.sum(
+            torch.log(relu_scores + 1), dim=-1
+        )  # .unsqueeze(dim=0)
+        # print("Final scores:", final_scores.shape)
         return final_scores

     def forward(self, queries, passages):
         query_embeddings = self.query_embeddings(queries)
         passage_embeddings = self.passage_embeddings(passages)
         return self.compute_scores(query_embeddings, passage_embeddings)

     ###
     def _set_special_token_embedding_to_zero(self):
         if self.bert_model.training == True:
             return

         if self.special_token_embedding_to_zero:
             return

         for special_id in self.tokenizer.all_special_ids:
             self.bert_input_emb[special_id] = 0 * self.bert_input_emb[special_id]

         self.special_token_embedding_to_zero = True

     ###
     def encode(self, texts, **kwargs):
         self._set_special_token_embedding_to_zero()  # Important for full reproduction (although it seems to have little influence on the performance)

         term_weights_batch = []
-        sparse_vec_size = kwargs.setdefault('sparse_vec_size', 2000) # TODO: Make this into the search.py cli arguments
+        sparse_vec_size = kwargs.setdefault(
+            "sparse_vec_size", 2000
+        )  # TODO: Make this into the search.py cli arguments
         assert sparse_vec_size <= len(self.tokenizer)

-        tokens = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt', max_length=500).to(self.device)
+        tokens = self.tokenizer(
+            texts, padding=True, truncation=True, return_tensors="pt", max_length=500
+        ).to(self.device)
         passage_embeddings = self.bert_model(**tokens).last_hidden_state
-        for passage_emb in passage_embeddings: # TODO: Optimize this by batch operations
+        for (
+            passage_emb
+        ) in passage_embeddings:  # TODO: Optimize this by batch operations
             scores = torch.matmul(self.bert_input_emb, passage_emb.transpose(0, 1))
             max_scores = torch.max(scores, dim=-1).values
-            relu_scores = torch.relu(max_scores) #Eq. 5
+            relu_scores = torch.relu(max_scores)  # Eq. 5
             final_scores = torch.log(relu_scores + 1)  # Eq. 6, final score

             top_results = torch.topk(final_scores, k=sparse_vec_size)
             tids = top_results[1].cpu().detach().tolist()
             scores = top_results[0].cpu().detach().tolist()

             term_weights = {}
             for tid, score in zip(tids, scores):
                 if score > 0:
@@ -100,23 +131,23 @@ def encode(self, texts, **kwargs):
                     break

             term_weights_batch.append(term_weights)

         return term_weights_batch

-class SPARTAQueryEncoder(QueryEncoder):
-
-    def __init__(self, model_name_or_path, device='cpu'):
+
+class SPARTAQueryEncoder(QueryEncoder):
+    def __init__(self, model_name_or_path, device="cpu"):
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
         self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()}

     def encode(self, text, **kwargs):
-        token_ids = self.tokenizer(text, add_special_tokens=False)['input_ids']
+        token_ids = self.tokenizer(text, add_special_tokens=False)["input_ids"]
         tokens = [self.reverse_voc[token_id] for token_id in token_ids]
         term_weights = defaultdict(int)

         # Important for reproducing the results:
         # Note that in Pyserini/Anserini, the query term weights are maintained by JHashMap,
         # which will keep only one term weight for identical terms
         for token in tokens:
             term_weights[token] += 1
-        return term_weights
\ No newline at end of file
+        return term_weights
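Reviewer note, to make the scoring path easier to review: compute_scores implements SPARTA's Eq. 4 (dot products between non-contextualized query token embeddings and contextualized passage token embeddings), Eq. 5 (max over passage tokens followed by ReLU), and Eq. 6 (log-saturated sum over query tokens). A minimal sketch of the same math on random tensors; all shapes and names here are illustrative assumptions, not part of this PR:

import torch

# Toy stand-ins for the real encoder outputs (illustrative shapes).
num_queries, query_len, emb_dim = 2, 4, 8
num_passages, passage_len = 3, 6
query_embeddings = torch.randn(num_queries, query_len, emb_dim)
passage_embeddings = torch.randn(num_passages, passage_len, emb_dim)

# Eq. 4: score every query token against every passage token.
scores = torch.stack(
    [torch.matmul(q, passage_embeddings.transpose(1, 2)) for q in query_embeddings]
)  # (num_queries, num_passages, query_len, passage_len)

# Eq. 5: best-matching passage token per query token, clipped at zero.
relu_scores = torch.relu(scores.max(dim=-1).values)

# Eq. 6: log saturation, summed over query tokens.
final_scores = torch.log(relu_scores + 1).sum(dim=-1)
print(final_scores.shape)  # torch.Size([2, 3]): one score per (query, passage) pair

The encode method above applies the same Eq. 5/6 math with the full vocabulary embedding matrix (bert_input_emb) in place of the query tokens, which is what turns each passage into a sparse term-weight vector.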
Empty file added tests/__init__.py
Empty file.
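With pytest==7.3.0 now pinned in environment.yml and tests/__init__.py making the directory a package, the new retrieval tests (the individual test files are not shown in this truncated view) should be runnable in the usual way, e.g.:

python -m pytest tests/ -v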