-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add CharRecallPrecision for OCR Task
- Loading branch information
1 parent
5a3647c
commit 753a303
Showing
4 changed files
with
130 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import re | ||
from difflib import SequenceMatcher | ||
from typing import Dict, Sequence, Tuple | ||
|
||
from mmeval.core import BaseMetric | ||
|
||
|
||
class CharRecallPrecision(BaseMetric):
    """Calculate the char level recall & precision for OCR text recognition.

    Args:
        letter_case (str): There are three options to alter the letter cases
            - unchanged: Do not change prediction texts and labels.
            - upper: Convert prediction texts and labels into uppercase
              characters.
            - lower: Convert prediction texts and labels into lowercase
              characters.
            Usually, it only works for English characters. Defaults to
            'unchanged'.
        valid_symbol (str): Regex pattern matching the characters that are
            stripped from both texts before comparison. Defaults to
            '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'.

    Examples:
        >>> from mmeval import CharRecallPrecision
        >>> metric = CharRecallPrecision()
        >>> metric(['helL', 'HEL'], ['hello', 'HELLO'])
        {'recall': 0.6, 'precision': 0.8571428571428571}
        >>> metric = CharRecallPrecision(letter_case='upper')
        >>> metric(['helL', 'HEL'], ['hello', 'HELLO'])
        {'recall': 0.7, 'precision': 1.0}
    """

    def __init__(self,
                 letter_case: str = 'unchanged',
                 valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
                 **kwargs):
        super().__init__(**kwargs)
        assert letter_case in ['unchanged', 'upper', 'lower']
        self.letter_case = letter_case
        # Compiled once here so `add` can reuse it per sample.
        self.valid_symbol = re.compile(valid_symbol)

    def add(self, predictions: Sequence[str], labels: Sequence[str]) -> None:  # type: ignore # yapf: disable # noqa: E501
        """Process one batch of data and predictions.

        Args:
            predictions (list[str]): The prediction texts.
            labels (list[str]): The ground truth texts.
        """
        for pred, label in zip(predictions, labels):
            if self.letter_case in ['upper', 'lower']:
                # str.upper / str.lower selected by name.
                pred = getattr(pred, self.letter_case)()
                label = getattr(label, self.letter_case)()
            # Strip characters matched by the valid_symbol pattern before
            # computing char-level statistics.
            label_ignore = self.valid_symbol.sub('', label)
            pred_ignore = self.valid_symbol.sub('', pred)
            # Per-sample counts used later for recall & precision:
            # (num gt chars, num pred chars, num true-positive chars).
            true_positive_char_num = self._cal_true_positive_char(
                pred_ignore, label_ignore)
            self._results.append(
                (len(label_ignore), len(pred_ignore), true_positive_char_num))

    def compute_metric(self, results: Sequence[Tuple[int, int, int]]) -> Dict:
        """Compute the metrics from processed results.

        Args:
            results (list[tuple]): The processed results of each batch, each
                a ``(gt_char_num, pred_char_num, true_positive_num)`` tuple.

        Returns:
            Dict: The computed metrics with keys ``recall`` and
            ``precision``.
        """
        gt_sum = float(sum(r[0] for r in results))
        pred_sum = float(sum(r[1] for r in results))
        true_positive_sum = float(sum(r[2] for r in results))
        # max(..., 1.0) guards against division by zero on empty input.
        return {
            'recall': true_positive_sum / max(gt_sum, 1.0),
            'precision': true_positive_sum / max(pred_sum, 1.0),
        }

    def _cal_true_positive_char(self, pred: str, gt: str) -> int:
        """Calculate correct character number in prediction.

        Args:
            pred (str): Prediction text.
            gt (str): Ground truth text.

        Returns:
            int: The number of true positive characters, i.e. the total
            length of the matching segments found by ``SequenceMatcher``.
        """
        all_opt = SequenceMatcher(None, pred, gt)
        true_positive_char_num = 0
        for opt, _, _, s2, e2 in all_opt.get_opcodes():
            if opt == 'equal':
                true_positive_char_num += (e2 - s2)
        return true_positive_char_num
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
# NOTE(review): do not list 'difflib' here — it is part of the Python standard
# library and must not appear in a pip requirements file.
opencv-python!=4.5.5.62,!=4.5.5.64 | ||
pycocotools | ||
scipy | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import pytest | ||
|
||
from mmeval import CharRecallPrecision | ||
|
||
|
||
def test_init():
    """An unsupported ``letter_case`` must be rejected at construction time."""
    with pytest.raises(AssertionError):
        CharRecallPrecision(letter_case='fake')
|
||
|
||
def test_char_recall_precision_metric():
    """CharRecallPrecision yields the expected recall/precision per case."""
    predictions = ['helL', 'HEL']
    groundtruths = ['hello', 'HELLO']
    # (letter_case, expected recall, expected precision)
    expectations = [
        ('lower', 0.7, 1.0),
        ('upper', 0.7, 1.0),
        ('unchanged', 0.6, 6.0 / 7),
    ]
    for case, want_recall, want_precision in expectations:
        metric = CharRecallPrecision(letter_case=case)
        res = metric(predictions, groundtruths)
        assert abs(res['recall'] - want_recall) < 1e-7
        assert abs(res['precision'] - want_precision) < 1e-7