Skip to content

Commit

Permalink
Updated function to get similar papers
Browse files Browse the repository at this point in the history
1. Set the similarity threshold to 0.64 (keep all papers with similarity scores > 0.64 )
2. Updated comparison strategy: On level 1, compare the paper with the seed paper. If similarity score is higher than the threshold, than compare its references on level 2 with a concatenation of the seed paper + this paper on level 1.
  • Loading branch information
bowenyi-pierre authored Oct 19, 2024
1 parent 56bde2e commit c48dde0
Showing 1 changed file with 202 additions and 0 deletions.
202 changes: 202 additions & 0 deletions get_similar_papers.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "74c09945-0977-4124-834c-6fbf5c31c86d",
"metadata": {},
"source": [
"### This file will extract a list of paper DOIs that belong to similar papers to a given seed paper.\n",
"- Similarity is measured using sentence embeddings. Similarity threshold can be decided manually\n",
"- Make sure to execute every cell sequentially. get_similar_papers() is the function that returns a list of relevant papers' DOIs. "
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a73d16fc-1f75-418e-ae54-c87cee4d6bee",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/bowenyi/.local/lib/python3.11/site-packages/sentence_transformers/cross_encoder/CrossEncoder.py:13: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
" from tqdm.autonotebook import tqdm, trange\n"
]
}
],
"source": [
"import torch\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"import json\n",
"from sentence_transformers import SentenceTransformer, util\n",
"import numpy as np\n",
"import random\n",
"random.seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "99cc20d3-8586-44b2-bdc0-d6e1e5496c75",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/bowenyi/.local/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n"
]
}
],
"source": [
"model = SentenceTransformer('all-mpnet-base-v2').to(device)\n"
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "19993d38-f81b-4bbd-afdd-83b866aba918",
"metadata": {},
"outputs": [],
"source": [
"with open(\"seed_paper_10.json\") as file: # Take seed paper 10 as example\n",
" refs_10 = json.load(file) \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 82,
"id": "6408e564-5e9c-44de-bbef-f131c99416cb",
"metadata": {},
"outputs": [],
"source": [
"def is_similar(target, ref, threshold): # Use sentence embeddings to measure the similarity and return a similarity score\n",
" target_embedding = model.encode(target, convert_to_tensor=True, device=device)\n",
" ref_embeddings = model.encode(ref, convert_to_tensor=True, device=device)\n",
" similarity_scores = model.similarity(target_embedding, ref_embeddings)\n",
" return similarity_scores[0] >= threshold, similarity_scores[0]\n"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "49e060f8-8aee-4b6c-bb75-60427e74e9fd",
"metadata": {},
"outputs": [],
"source": [
"def concatenate_title_abs(title, abs): # Concatenate title with abstract in account of type mismatch\n",
" if type(title) != str or type(abs) != str:\n",
" if type(title) == list:\n",
" str_title = \"\"\n",
" for text in title:\n",
" str_title += text + \" \"\n",
" title = str_title\n",
" \n",
" if type(abs) == list:\n",
" str_abs = \"\"\n",
" for text in abs:\n",
" str_abs += text + \" \"\n",
" abs = str_abs\n",
" \n",
" \n",
" return str(title) + \": \" + str(abs)"
]
},
{
"cell_type": "code",
"execution_count": 88,
"id": "2f0850f0-fdb3-41c6-9138-05b0a2198f22",
"metadata": {},
"outputs": [],
"source": [
"# Return (1) a list of DOIs from similar papers to a seed paper and (2) total number of papers\n",
"# Sample usage: similar_papers_8, paper_count_8 = get_similar_papers(seed_paper_refs=refs_8)\n",
"\n",
"def get_similar_papers(seed_paper_refs, similarity_threshold=0.64): \n",
" similar_papers = []\n",
" paper_count = 1\n",
" \n",
" seed_title = seed_paper_refs[0]['metadata']['title']\n",
" seed_abs = seed_paper_refs[0]['metadata']['abstract']\n",
" seed_title_abs = concatenate_title_abs(seed_title, seed_abs) \n",
" \n",
" for lev_1_ref in seed_paper_refs[0]['references']:\n",
" lev_1_title = lev_1_ref['metadata']['title']\n",
" lev_1_abs = lev_1_ref['metadata']['abstract']\n",
" lev_1_title_abs = concatenate_title_abs(lev_1_title, lev_1_abs) \n",
" similar_lev_1, score_lev_1 = is_similar(seed_title_abs, lev_1_title_abs, similarity_threshold)\n",
" \n",
" paper_count += 1 + len(lev_1_ref['references'])\n",
" \n",
" \n",
" if similar_lev_1 and len(lev_1_ref['metadata']['doi']) != 0:\n",
" similar_papers.append(lev_1_ref['metadata']['doi'])\n",
" # similar_papers.append((lev_1_ref['metadata']['doi'], score_lev_1))\n",
"\n",
" for lev_2_ref in lev_1_ref['references']:\n",
" lev_2_title = lev_2_ref['metadata']['title']\n",
" lev_2_abs = lev_2_ref['metadata']['abstract']\n",
" lev_2_title_abs = concatenate_title_abs(lev_2_title, lev_2_abs)\n",
" \n",
" target_lev_2 = seed_title_abs + \" \" + lev_1_title_abs\n",
" similar_lev_2, score_lev_2 = is_similar(target_lev_2, lev_2_title_abs, similarity_threshold)\n",
" \n",
" \n",
" if similar_lev_2 and len(lev_2_ref['metadata']['doi']) != 0:\n",
" similar_papers.append(lev_2_ref['metadata']['doi'])\n",
" # similar_papers.append((lev_2_ref['metadata']['doi'], score_lev_2))\n",
" \n",
"\n",
" return similar_papers, paper_count\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 92,
"id": "267ea957-37d3-4e99-9778-4aac55650f37",
"metadata": {},
"outputs": [],
"source": [
"# Sample usage:\n",
"similar_papers_10, paper_count_10 = get_similar_papers(seed_paper_refs=refs_10)\n",
"print(len(similar_papers_10))\n",
"print(paper_count_10)\n",
"print(similar_papers_10[0:10])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "487d1fc9-e506-4ec7-a3d5-26622ba7fa44",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit c48dde0

Please sign in to comment.