-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add OneMinusNormEditDistance for OCR Task
- Loading branch information
1 parent
5a3647c
commit eba0ab8
Showing
4 changed files
with
108 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import re | ||
from typing import TYPE_CHECKING, Dict, List, Sequence | ||
|
||
from mmeval.core import BaseMetric | ||
from mmeval.utils import try_import | ||
|
||
if TYPE_CHECKING: | ||
from rapidfuzz.distance import Levenshtein | ||
else: | ||
distance = try_import('rapidfuzz.distance') | ||
if distance is not None: | ||
Levenshtein = distance.Levenshtein | ||
|
||
|
||
class OneMinusNormEditDistance(BaseMetric): | ||
"""One minus NED metric for text recognition task. | ||
Args: | ||
letter_case (str): There are three options to alter the letter cases | ||
- unchanged: Do not change prediction texts and labels. | ||
- upper: Convert prediction texts and labels into uppercase | ||
characters. | ||
- lower: Convert prediction texts and labels into lowercase | ||
characters. | ||
Usually, it only works for English characters. Defaults to | ||
'unchanged'. | ||
valid_symbol (str): Valid characters. Defaults to | ||
'[^A-Z^a-z^0-9^\u4e00-\u9fa5]'. | ||
Example: | ||
>>> from mmeval import OneMinusNormEditDistance | ||
>>> metric = OneMinusNormEditDistance() | ||
>>> metric(['helL', 'HEL'], ['hello', 'HELLO']) | ||
{'1-N.E.D': 0.6} | ||
>>> metric = OneMinusNormEditDistance(letter_case='upper') | ||
>>> metric(['helL', 'HEL'], ['hello', 'HELLO']) | ||
{'1-N.E.D': 0.7} | ||
""" | ||
|
||
def __init__(self, | ||
letter_case: str = 'unchanged', | ||
valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]', | ||
**kwargs): | ||
super().__init__(**kwargs) | ||
|
||
assert letter_case in ['unchanged', 'upper', 'lower'] | ||
self.letter_case = letter_case | ||
self.valid_symbol = re.compile(valid_symbol) | ||
|
||
def add(self, predictions: Sequence[str], labels: Sequence[str]): # type: ignore # yapf: disable # noqa: E501 | ||
"""Process one batch of data and predictions. | ||
Args: | ||
predictions (list[str]): The prediction texts. | ||
labels (list[str]): The ground truth texts. | ||
""" | ||
for pred, label in zip(predictions, labels): | ||
if self.letter_case in ['upper', 'lower']: | ||
pred = getattr(pred, self.letter_case)() | ||
label = getattr(label, self.letter_case)() | ||
label = self.valid_symbol.sub('', label) | ||
pred = self.valid_symbol.sub('', pred) | ||
norm_ed = Levenshtein.normalized_distance(pred, label) | ||
self._results.append(norm_ed) | ||
|
||
def compute_metric(self, results: List[float]) -> Dict: | ||
"""Compute the metrics from processed results. | ||
Args: | ||
results (list[float]): The processed results of each batch. | ||
Returns: | ||
dict[str, float]: Nested dicts as results. | ||
- 1-N.E.D (float): One minus the normalized edit distance. | ||
""" | ||
gt_word_num = len(results) | ||
norm_ed_sum = sum(results) | ||
normalized_edit_distance = norm_ed_sum / max(1.0, gt_word_num) | ||
eval_res = {} | ||
eval_res['1-N.E.D'] = 1.0 - normalized_edit_distance | ||
return eval_res |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import pytest | ||
|
||
from mmeval import OneMinusNormEditDistance | ||
|
||
|
||
def test_init(): | ||
with pytest.raises(AssertionError): | ||
OneMinusNormEditDistance(letter_case='fake') | ||
|
||
|
||
def test_one_minus_norm_edit_distance_metric(): | ||
metric = OneMinusNormEditDistance(letter_case='lower') | ||
res = metric(['helL', 'HEL'], ['hello', 'HELLO']) | ||
assert abs(res['1-N.E.D'] - 0.7) < 1e-7 | ||
metric = OneMinusNormEditDistance(letter_case='upper') | ||
res = metric(['helL', 'HEL'], ['hello', 'HELLO']) | ||
assert abs(res['1-N.E.D'] - 0.7) < 1e-7 | ||
metric = OneMinusNormEditDistance() | ||
res = metric(['helL', 'HEL'], ['hello', 'HELLO']) | ||
assert abs(res['1-N.E.D'] - 0.6) < 1e-7 | ||
|
||
|
||
test_one_minus_norm_edit_distance_metric() |