-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Updated function to get similar papers
1. Set the similarity threshold to 0.64 (keep all papers with similarity scores > 0.64 ) 2. Updated comparison strategy: On level 1, compare the paper with the seed paper. If similarity score is higher than the threshold, than compare its references on level 2 with a concatenation of the seed paper + this paper on level 1.
- Loading branch information
1 parent
56bde2e
commit c48dde0
Showing
1 changed file
with
202 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "74c09945-0977-4124-834c-6fbf5c31c86d", | ||
"metadata": {}, | ||
"source": [ | ||
"### This file will extract a list of paper DOIs that belong to similar papers to a given seed paper.\n", | ||
"- Similarity is measured using sentence embeddings. Similarity threshold can be decided manually\n", | ||
"- Make sure to execute every cell sequentially. get_similar_papers() is the function that returns a list of relevant papers' DOIs. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "a73d16fc-1f75-418e-ae54-c87cee4d6bee", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/home/bowenyi/.local/lib/python3.11/site-packages/sentence_transformers/cross_encoder/CrossEncoder.py:13: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", | ||
" from tqdm.autonotebook import tqdm, trange\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import torch\n", | ||
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", | ||
"\n", | ||
"import json\n", | ||
"from sentence_transformers import SentenceTransformer, util\n", | ||
"import numpy as np\n", | ||
"import random\n", | ||
"random.seed(42)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "99cc20d3-8586-44b2-bdc0-d6e1e5496c75", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/home/bowenyi/.local/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", | ||
" warnings.warn(\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"model = SentenceTransformer('all-mpnet-base-v2').to(device)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 63, | ||
"id": "19993d38-f81b-4bbd-afdd-83b866aba918", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"with open(\"seed_paper_10.json\") as file: # Take seed paper 10 as example\n", | ||
" refs_10 = json.load(file) \n", | ||
" " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 82, | ||
"id": "6408e564-5e9c-44de-bbef-f131c99416cb", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def is_similar(target, ref, threshold): # Use sentence embeddings to measure the similarity and return a similarity score\n", | ||
" target_embedding = model.encode(target, convert_to_tensor=True, device=device)\n", | ||
" ref_embeddings = model.encode(ref, convert_to_tensor=True, device=device)\n", | ||
" similarity_scores = model.similarity(target_embedding, ref_embeddings)\n", | ||
" return similarity_scores[0] >= threshold, similarity_scores[0]\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 39, | ||
"id": "49e060f8-8aee-4b6c-bb75-60427e74e9fd", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def concatenate_title_abs(title, abs): # Concatenate title with abstract in account of type mismatch\n", | ||
" if type(title) != str or type(abs) != str:\n", | ||
" if type(title) == list:\n", | ||
" str_title = \"\"\n", | ||
" for text in title:\n", | ||
" str_title += text + \" \"\n", | ||
" title = str_title\n", | ||
" \n", | ||
" if type(abs) == list:\n", | ||
" str_abs = \"\"\n", | ||
" for text in abs:\n", | ||
" str_abs += text + \" \"\n", | ||
" abs = str_abs\n", | ||
" \n", | ||
" \n", | ||
" return str(title) + \": \" + str(abs)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 88, | ||
"id": "2f0850f0-fdb3-41c6-9138-05b0a2198f22", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Return (1) a list of DOIs from similar papers to a seed paper and (2) total number of papers\n", | ||
"# Sample usage: similar_papers_8, paper_count_8 = get_similar_papers(seed_paper_refs=refs_8)\n", | ||
"\n", | ||
"def get_similar_papers(seed_paper_refs, similarity_threshold=0.64): \n", | ||
" similar_papers = []\n", | ||
" paper_count = 1\n", | ||
" \n", | ||
" seed_title = seed_paper_refs[0]['metadata']['title']\n", | ||
" seed_abs = seed_paper_refs[0]['metadata']['abstract']\n", | ||
" seed_title_abs = concatenate_title_abs(seed_title, seed_abs) \n", | ||
" \n", | ||
" for lev_1_ref in seed_paper_refs[0]['references']:\n", | ||
" lev_1_title = lev_1_ref['metadata']['title']\n", | ||
" lev_1_abs = lev_1_ref['metadata']['abstract']\n", | ||
" lev_1_title_abs = concatenate_title_abs(lev_1_title, lev_1_abs) \n", | ||
" similar_lev_1, score_lev_1 = is_similar(seed_title_abs, lev_1_title_abs, similarity_threshold)\n", | ||
" \n", | ||
" paper_count += 1 + len(lev_1_ref['references'])\n", | ||
" \n", | ||
" \n", | ||
" if similar_lev_1 and len(lev_1_ref['metadata']['doi']) != 0:\n", | ||
" similar_papers.append(lev_1_ref['metadata']['doi'])\n", | ||
" # similar_papers.append((lev_1_ref['metadata']['doi'], score_lev_1))\n", | ||
"\n", | ||
" for lev_2_ref in lev_1_ref['references']:\n", | ||
" lev_2_title = lev_2_ref['metadata']['title']\n", | ||
" lev_2_abs = lev_2_ref['metadata']['abstract']\n", | ||
" lev_2_title_abs = concatenate_title_abs(lev_2_title, lev_2_abs)\n", | ||
" \n", | ||
" target_lev_2 = seed_title_abs + \" \" + lev_1_title_abs\n", | ||
" similar_lev_2, score_lev_2 = is_similar(target_lev_2, lev_2_title_abs, similarity_threshold)\n", | ||
" \n", | ||
" \n", | ||
" if similar_lev_2 and len(lev_2_ref['metadata']['doi']) != 0:\n", | ||
" similar_papers.append(lev_2_ref['metadata']['doi'])\n", | ||
" # similar_papers.append((lev_2_ref['metadata']['doi'], score_lev_2))\n", | ||
" \n", | ||
"\n", | ||
" return similar_papers, paper_count\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 92, | ||
"id": "267ea957-37d3-4e99-9778-4aac55650f37", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Sample usage:\n", | ||
"similar_papers_10, paper_count_10 = get_similar_papers(seed_paper_refs=refs_10)\n", | ||
"print(len(similar_papers_10))\n", | ||
"print(paper_count_10)\n", | ||
"print(similar_papers_10[0:10])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "487d1fc9-e506-4ec7-a3d5-26622ba7fa44", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.7" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |