diff --git a/README.md b/README.md
index dd12b56..09ef427 100644
--- a/README.md
+++ b/README.md
@@ -113,6 +113,11 @@ Triple Score: 0.5
Avg. ROUGE-1: 0.4415584415584415
Avg. ROUGE-2: 0.3287671232876712
Avg. ROUGE-L: 0.4415584415584415
+
+BERTScore
+Precision: 0.9151781797409058
+Recall: 0.9141832590103149
+F1: 0.9150083661079407
```
@@ -241,6 +246,22 @@ Simple but effective word-level overlap ROUGE score
+### BERTScore Module
+
+```python
+>>> from factsumm import FactSumm
+>>> factsumm = FactSumm()
+>>> factsumm.calculate_bert_score(article, summary)
+BERTScore
+Precision: 0.9151781797409058
+Recall: 0.9141832590103149
+F1: 0.9150083661079407
+```
+
+[BERTScore](https://github.com/Tiiiger/bert_score) can be used to calculate the semantic similarity between each source sentence and the generated summary.
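+
+You can also swap the underlying scorer: the `bert_score_model` constructor argument is passed straight to `bert_score`'s `BERTScorer(model_type=...)`, so any model name that scorer understands should work. A minimal sketch using the default `microsoft/deberta-base-mnli`:
+
+```python
+>>> from factsumm import FactSumm
+>>> factsumm = FactSumm(bert_score_model="microsoft/deberta-base-mnli")
+>>> factsumm.calculate_bert_score(article, summary)
+```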
+
+
+
### Citation
If you apply this library to any project, please cite:
diff --git a/factsumm/__init__.py b/factsumm/__init__.py
index 6ba8dc3..4d046ce 100644
--- a/factsumm/__init__.py
+++ b/factsumm/__init__.py
@@ -8,7 +8,7 @@
from sumeval.metrics.rouge import RougeCalculator
from factsumm.utils.level_entity import load_ie, load_ner, load_rel
-from factsumm.utils.level_sentence import load_qa, load_qg
+from factsumm.utils.level_sentence import load_bert_score, load_qa, load_qg
from factsumm.utils.utils import Config, qags_score
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -26,6 +26,7 @@ def __init__(
rel_model: str = None,
qg_model: str = None,
qa_model: str = None,
+ bert_score_model: str = None,
):
self.config = Config()
self.segmenter = pysbd.Segmenter(language="en", clean=False)
@@ -36,6 +37,7 @@ def __init__(
self.rel = rel_model if rel_model is not None else self.config.REL_MODEL
self.qg = qg_model if qg_model is not None else self.config.QG_MODEL
self.qa = qa_model if qa_model is not None else self.config.QA_MODEL
+ self.bert_score = bert_score_model if bert_score_model is not None else self.config.BERT_SCORE_MODEL
self.ie = None
def build_perm(
@@ -321,12 +323,46 @@ def extract_triples(self, source: str, summary: str, verbose: bool = False):
return triple_score
+ def calculate_bert_score(self, source: str, summary: str):
+ """
+ Calculate BERTScore
+
+        See also https://arxiv.org/abs/1904.09675
+
+ Args:
+ source (str): original source
+ summary (str): generated summary
+
+        Returns:
+            List: BERTScore [Precision, Recall, F1] values
+
+        """
+
+ if isinstance(self.bert_score, str):
+ self.bert_score = load_bert_score(self.bert_score)
+
+ source_lines = self._segment(source)
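+        # append a placeholder candidate alongside the summary; its score is dropped below before averaging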
+ summary_lines = [summary, "dummy"]
+
+ scores = self.bert_score(summary_lines, source_lines)
+ filtered_scores = list()
+
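+        # each element of scores (Precision, Recall, F1) holds one value per candidate;
+        # drop the trailing placeholder entry before averaging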
+ for score in scores:
+ score = score.tolist()
+ score.pop(-1)
+ filtered_scores.append(sum(score) / len(score))
+
+ print(
+ f"BERTScore Score\nPrecision: {filtered_scores[0]}\nRecall: {filtered_scores[1]}\nF1: {filtered_scores[2]}"
+ )
+
+ return filtered_scores
+
def __call__(self, source: str, summary: str, verbose: bool = False):
source_ents, summary_ents, fact_score = self.extract_facts(
source,
summary,
verbose,
)
+
qags_score = self.extract_qas(
source,
summary,
@@ -334,12 +370,21 @@ def __call__(self, source: str, summary: str, verbose: bool = False):
summary_ents,
verbose,
)
+
triple_score = self.extract_triples(source, summary, verbose)
+
rouge_1, rouge_2, rouge_l = self.calculate_rouge(source, summary)
+ bert_scores = self.calculate_bert_score(source, summary)
+
return {
"fact_score": fact_score,
"qa_score": qags_score,
"triple_score": triple_score,
"rouge": (rouge_1, rouge_2, rouge_l),
+ "bert_score": {
+ "precision": bert_scores[0],
+ "recall": bert_scores[1],
+ "f1": bert_scores[2],
+ },
}
diff --git a/factsumm/utils/level_sentence.py b/factsumm/utils/level_sentence.py
index ffd676f..ded05c5 100644
--- a/factsumm/utils/level_sentence.py
+++ b/factsumm/utils/level_sentence.py
@@ -1,5 +1,6 @@
from typing import List
+from bert_score import BERTScorer
from rich import print
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
@@ -111,4 +112,18 @@ def answer_question(context: str, qa_pairs: List):
return answer_question
-# TODO: NLI, FactCC
+def load_bert_score(model: str):
+ """
+ Load BERTScore model from HuggingFace hub
+
+ Args:
+ model (str): model name to be loaded
+
+ Returns:
+ function: BERTScore score function
+
+ """
+ print("Loading BERTScore Pipeline...")
+
+ scorer = BERTScorer(model_type=model, lang="en", rescale_with_baseline=True)
+ return scorer.score
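+
+# Illustrative usage sketch (comments only): the returned callable follows
+# bert_score's `BERTScorer.score(cands, refs)` signature and yields a
+# (Precision, Recall, F1) tuple of tensors with one value per candidate, e.g.:
+#
+#   score_fn = load_bert_score("microsoft/deberta-base-mnli")
+#   precision, recall, f1 = score_fn(["generated summary"], ["source sentence"])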
diff --git a/factsumm/utils/utils.py b/factsumm/utils/utils.py
index 3e59f65..66afb39 100644
--- a/factsumm/utils/utils.py
+++ b/factsumm/utils/utils.py
@@ -15,6 +15,7 @@ class Config:
QG_MODEL: str = "mrm8488/t5-base-finetuned-question-generation-ap"
QA_MODEL: str = "deepset/roberta-base-squad2"
SUMM_MODEL: str = "sshleifer/distilbart-cnn-12-6"
+ BERT_SCORE_MODEL: str = "microsoft/deberta-base-mnli"
def grouped_entities(entities: List[Dict]):