[chore] Fix BERTScore length issue & resolve unnatural relation triples
karter-liner committed Jan 1, 2024
1 parent 3bb0ed3 commit 665a20e
Showing 5 changed files with 66 additions and 53 deletions.
104 changes: 60 additions & 44 deletions factsumm/factsumm.py
@@ -47,8 +47,7 @@ def __init__(
         self.rel = rel_model if rel_model is not None else self.config.REL_MODEL
         self.qg = qg_model if qg_model is not None else self.config.QG_MODEL
         self.qa = qa_model if qa_model is not None else self.config.QA_MODEL
-        self.bert_score = bert_score_model if bert_score_model is not None else self.config.BERT_SCORE_MODEL
         self.ie = None
+        self.bert_score = bert_score_model

     def build_perm(
         self,
@@ -99,8 +98,32 @@ def get_facts(self, lines: List[str], entities: List[List[Dict]]) -> Set:
         perms = self.build_perm(lines, entities)
         triples = []

-        for perm in perms:
-            triples.extend(self.rel(perm))
+        for perm, entity in zip(perms, entities):
+            entity_key = {ent["word"].replace("▁", ""): ent["entity"] for ent in entity}
+            facts = self.rel(perm)
+            filtered_facts = []
+
+            for fact in facts:
+                head, relation, tail = fact
+
+                head = head.strip()
+                tail = tail.strip()
+
+                head_entity_type = entity_key.get(head, None)
+                tail_entity_type = entity_key.get(tail, None)
+
+                if head_entity_type is not None and head_entity_type == "PERSON" and not relation.startswith("per:"):
+                    continue
+
+                if head_entity_type is not None and head_entity_type != "PERSON" and relation.startswith("per:"):
+                    continue
+
+                if tail_entity_type is not None and tail_entity_type != "PERSON" and "members" in relation:
+                    continue
+
+                filtered_facts.append(tuple([head, relation, tail]))
+
+            triples.extend(filtered_facts)

         return set(triples)
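The new loop keeps a triple only when the predicted relation label is consistent with the NER type of its arguments: a PERSON head must carry a per:* relation, a non-PERSON head must not, and a non-PERSON tail rejects membership relations. A minimal standalone sketch of the head-type rules, using hypothetical sample triples and entity labels:

    # Sketch of the relation/entity-type consistency filter above; the
    # sample triples and entity labels are hypothetical illustrations.
    entity_key = {"Barack Obama": "PERSON", "Hawaii": "GPE"}

    facts = [
        ("Barack Obama", "per:place_of_birth", "Hawaii"),  # kept
        ("Barack Obama", "org:founded_by", "Hawaii"),      # dropped: PERSON head, non-per:* relation
        ("Hawaii", "per:employee_of", "Barack Obama"),     # dropped: non-PERSON head, per:* relation
    ]

    filtered_facts = []
    for head, relation, tail in facts:
        head_type = entity_key.get(head.strip())
        if head_type == "PERSON" and not relation.startswith("per:"):
            continue
        if head_type is not None and head_type != "PERSON" and relation.startswith("per:"):
            continue
        filtered_facts.append((head.strip(), relation, tail.strip()))

    print(filtered_facts)  # [('Barack Obama', 'per:place_of_birth', 'Hawaii')]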

@@ -301,19 +324,12 @@ def extract_qas(

         return qa_score

-    def _print_triples(self, mode: str, triples: Set):
-        logging.info("%s Triples", mode.capitalize())
-        for triple in triples:
-            logging.info(triple)
-
     def calculate_bert_score(
         self,
         source: str,
         summary: str,
         device: str = "cpu",
-    ) -> List[float]:
+    ) -> Dict[str, float]:
         """
         Calculate BERTScore
@@ -325,27 +341,30 @@ def calculate_bert_score(
             device (str): device info
         Returns:
-            List: (Precision, Recall, F1) BERTScore list
+            Dict: (Precision, Recall, F1) BERTScore dictionary
         """
-        if isinstance(self.bert_score, str):
-            self.bert_score = load_bert_score(self.bert_score, device)
+        if self.bert_score is None:
+            self.bert_score = load_bert_score(device)

-        # BUG: When len(source_lines) == 1, bmm error raises
         source_lines = self._segment_sentence(source)
-        summary_lines = [summary, "dummy"]
+        summary_lines = self._segment_sentence(summary)

-        scores = self.bert_score(summary_lines, source_lines)
-        filtered_scores = []
+        scores = {
+            "precision": 0.0,
+            "recall": 0.0,
+            "f1": 0.0,
+        }

-        for score in scores:
-            score = score.tolist()
-            score.pop(-1)
-            filtered_scores.append(sum(score) / len(score))
+        for summary_line in summary_lines:
+            precision, recall, f1 = self.bert_score([summary_line], [source_lines])
+            scores["precision"] += precision.item()
+            scores["recall"] += recall.item()
+            scores["f1"] += f1.item()

-        logging.info("<BERTScore Score>\nPrecision: %s\nRecall: %s\nF1: %s", filtered_scores[0], filtered_scores[1], filtered_scores[1])
+        logging.info("<BERTScore Score>\nPrecision: %s\nRecall: %s\nF1: %s", scores["precision"], scores["recall"], scores["f1"])

-        return filtered_scores
+        return scores

     def __call__(
         self,
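Segmenting the summary and scoring one sentence at a time against the full list of source sentences replaces the removed workaround of padding summary_lines with a "dummy" entry, which existed only because a single-sentence batch raised a bmm shape error. A minimal sketch of the per-sentence, multi-reference scoring pattern; the sentences are hypothetical, and bert_score downloads its default English model on first use:

    from bert_score import BERTScorer

    # Minimal sketch of the per-sentence scoring pattern above;
    # the example sentences are hypothetical.
    scorer = BERTScorer(lang="en", rescale_with_baseline=True, device="cpu")

    source_lines = ["Obama was born in Hawaii.", "He served two terms."]
    summary_lines = ["Obama is from Hawaii."]

    scores = {"precision": 0.0, "recall": 0.0, "f1": 0.0}
    for summary_line in summary_lines:
        # One candidate scored against all source sentences as multiple
        # references, so a one-sentence source needs no "dummy" padding.
        precision, recall, f1 = scorer.score([summary_line], [source_lines])
        scores["precision"] += precision.item()
        scores["recall"] += recall.item()
        scores["f1"] += f1.item()

    print(scores)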
@@ -365,8 +384,16 @@ def __call__(

         fact_scores = 0
         qags_scores = 0
-        rouges = [0, 0, 0]
-        bert_scores = [0, 0, 0]
+        rouge_scores = {
+            "rouge-1": 0.0,
+            "rouge-2": 0.0,
+            "rouge-l": 0.0,
+        }
+        bert_scores = {
+            "precision": 0.0,
+            "recall": 0.0,
+            "f1": 0.0,
+        }

         for source, summary in zip(sources, summaries):
             source_ents, summary_ents, fact_score = self.extract_facts(
@@ -388,26 +415,15 @@ def __call__(
             qags_scores += qags_score

             rouge_1, rouge_2, rouge_l = self.calculate_rouge(source, summary)
-            rouges[0] += rouge_1
-            rouges[1] += rouge_2
-            rouges[2] += rouge_l
+            rouge_scores["rouge-1"] += rouge_1
+            rouge_scores["rouge-2"] += rouge_2
+            rouge_scores["rouge-l"] += rouge_l

-            bert_score = self.calculate_bert_score(source, summary, device)
-            bert_scores[0] += bert_score[0]
-            bert_scores[1] += bert_score[1]
-            bert_scores[2] += bert_score[2]
+            bert_scores = self.calculate_bert_score(source, summary, device)

         return {
             "fact_score": fact_scores / num_pairs,
             "qa_score": qags_scores / num_pairs,
-            "rouge": (
-                rouges[0] / num_pairs,
-                rouges[1] / num_pairs,
-                rouges[2] / num_pairs,
-            ),
-            "bert_score": {
-                "precision": bert_scores[0],
-                "recall": bert_scores[1],
-                "f1": bert_scores[2],
-            },
+            "rouge": rouge_scores,
+            "bert_score": bert_scores,
         }
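Callers now receive named scores instead of positional tuples. A usage sketch, assuming single-string inputs are accepted as in the project README; the example texts are illustrative:

    from factsumm import FactSumm

    # Usage sketch for the new nested return schema; the first call
    # downloads the underlying NER, RE, QG, and QA models.
    factsumm = FactSumm()
    result = factsumm(
        "Obama was born in Hawaii. He served two terms.",
        "Obama is from Hawaii.",
    )

    # Scores are addressed by name rather than by position.
    print(result["rouge"]["rouge-1"])
    print(result["bert_score"]["f1"])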
4 changes: 2 additions & 2 deletions factsumm/utils/module_entity.py
@@ -19,7 +19,7 @@ def load_ner(model: str, device: str) -> object:
         object: Pipeline-based Named Entity Recognition model
     """
-    logging.info("Loading Named Entity Recognition Pipeline...")
+    logging.debug("Loading Named Entity Recognition Pipeline...")

     try:
         ner = pipeline(
@@ -61,7 +61,7 @@ def load_rel(model: str, device: str):
         function: LUKE-based Relation Extraction function
     """
-    logging.info("Loading Relation Extraction Pipeline...")
+    logging.debug("Loading Relation Extraction Pipeline...")

     try:
         # yapf:disable
4 changes: 2 additions & 2 deletions factsumm/utils/module_question.py
@@ -17,7 +17,7 @@ def load_qg(model: str, device: str):
         function: question generation function
     """
-    logging.info("Loading Question Generation Pipeline...")
+    logging.debug("Loading Question Generation Pipeline...")

     try:
         tokenizer = AutoTokenizer.from_pretrained(model)
@@ -82,7 +82,7 @@ def load_qa(model: str, device: str):
         function: question answering function
     """
-    logging.info("Loading Question Answering Pipeline...")
+    logging.debug("Loading Question Answering Pipeline...")

     try:
         qa = pipeline(
6 changes: 2 additions & 4 deletions factsumm/utils/module_sentence.py
@@ -3,7 +3,7 @@
 from bert_score import BERTScorer


-def load_bert_score(model: str, device: str):
+def load_bert_score(device: str):
     """
     Load BERTScore model from HuggingFace hub
@@ -15,13 +15,11 @@ def load_bert_score(model: str, device: str):
         function: BERTScore score function
     """
-    logging.info("Loading BERTScore Pipeline...")
+    logging.debug("Loading BERTScore Pipeline...")

     try:
         scorer = BERTScorer(
-            model_type=model,
             lang="en",
             rescale_with_baseline=True,
             device=device,
         )
         return scorer.score
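With model_type dropped, BERTScorer falls back to its built-in per-language default (roberta-large for lang="en"), the model family the packaged rescale_with_baseline files target. A minimal sketch of the fallback, assuming the bert-score package's default behavior:

    from bert_score import BERTScorer

    # With model_type omitted, bert_score picks its built-in default for
    # the language; for lang="en" this is expected to be roberta-large.
    scorer = BERTScorer(lang="en", rescale_with_baseline=True, device="cpu")
    print(scorer.model_type)  # expected: "roberta-large"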
1 change: 0 additions & 1 deletion factsumm/utils/utils.py
@@ -12,7 +12,6 @@ class Config:
     QG_MODEL: str = "mrm8488/t5-base-finetuned-question-generation-ap"
     QA_MODEL: str = "deepset/roberta-base-squad2"
     SUMM_MODEL: str = "sshleifer/distilbart-cnn-12-6"
-    BERT_SCORE_MODEL: str = "microsoft/deberta-base-mnli"


 def grouped_entities(entities: List[Dict]) -> List:
