From fa94b14272955c9ac34bcce6e53016d3179a3120 Mon Sep 17 00:00:00 2001
From: liukuikun <641417025@qq.com>
Date: Wed, 15 Feb 2023 17:36:28 +0800
Subject: [PATCH] [Feature] Adapt MMEval for CharRecallPrecision WordAccuracy
 OneMinusNED

---
 .../evaluator/multi_datasets_evaluator.py |  33 +-
 mmocr/evaluation/metrics/recog_metric.py  | 337 +++++-------------
 .../test_metrics/test_recog_metric.py     |  24 +-
 3 files changed, 117 insertions(+), 277 deletions(-)

diff --git a/mmocr/evaluation/evaluator/multi_datasets_evaluator.py b/mmocr/evaluation/evaluator/multi_datasets_evaluator.py
index f01aa70f6..647d2b846 100644
--- a/mmocr/evaluation/evaluator/multi_datasets_evaluator.py
+++ b/mmocr/evaluation/evaluator/multi_datasets_evaluator.py
@@ -3,10 +3,8 @@
 from collections import OrderedDict
 from typing import Sequence, Union
 
-from mmengine.dist import (broadcast_object_list, collect_results,
-                           is_main_process)
+from mmengine.dist import broadcast_object_list, is_main_process
 from mmengine.evaluator import BaseMetric, Evaluator
-from mmengine.evaluator.metric import _to_cpu
 
 from mmocr.registry import EVALUATOR
 from mmocr.utils.typing_utils import ConfigType
@@ -52,31 +50,32 @@ def evaluate(self, size: int) -> dict:
         dataset_slices = self.dataset_meta.get('cumulative_sizes', [size])
         assert len(dataset_slices) == len(self.dataset_prefixes)
         for metric in self.metrics:
-            if len(metric.results) == 0:
-                warnings.warn(
-                    f'{metric.__class__.__name__} got empty `self.results`.'
-                    'Please ensure that the processed results are properly '
-                    'added into `self.results` in `process` method.')
-            results = collect_results(metric.results, size,
-                                      metric.collect_device)
-
+            if len(metric._results) == 0:
+                warnings.warn(
+                    f'{metric.__class__.__name__} got empty `self._results`. '
+                    'Please ensure that the processed results are properly '
+                    'added into `self._results` in the `add` method.')
+            global_results = metric.dist_comm.all_gather_object(
+                metric._results)
+            if metric.dist_collect_mode == 'cat':
+                # use `sum` to concatenate list
+                # e.g. sum([[1, 3], [2, 4]], []) = [1, 3, 2, 4]
+                collected_results = sum(global_results, [])
+            else:
+                collected_results = []
+                for partial_result in zip(*global_results):
+                    collected_results.extend(list(partial_result))
             if is_main_process():
-                # cast all tensors in results list to cpu
-                results = _to_cpu(results)
                 for start, end, dataset_prefix in zip([0] +
                                                       dataset_slices[:-1],
                                                       dataset_slices,
                                                       self.dataset_prefixes):
-                    metric_results = metric.compute_metrics(
-                        results[start:end])  # type: ignore
+                    metric_results = metric.compute_metric(
+                        collected_results[start:end])  # type: ignore
                     # Add prefix to metric names
-                    if metric.prefix:
-                        final_prefix = '/'.join(
-                            (dataset_prefix, metric.prefix))
-                    else:
-                        final_prefix = dataset_prefix
+                    final_prefix = '/'.join((dataset_prefix, metric.name))
                     metric_results = {
                         '/'.join((final_prefix, k)): v
                         for k, v in metric_results.items()
@@ -90,7 +89,7 @@ def evaluate(self, size: int) -> dict:
                             f'the same metric name {name}. Please make '
                             'sure all metrics have different prefixes.')
                     metrics_results.update(metric_results)
-            metric.results.clear()
+            metric.reset()
         if is_main_process():
             metrics_results = [metrics_results]
         else:
diff --git a/mmocr/evaluation/metrics/recog_metric.py b/mmocr/evaluation/metrics/recog_metric.py
index a04695121..9d1d9e386 100644
--- a/mmocr/evaluation/metrics/recog_metric.py
+++ b/mmocr/evaluation/metrics/recog_metric.py
@@ -1,18 +1,56 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import re
-from difflib import SequenceMatcher
+import warnings
 from typing import Dict, Optional, Sequence, Union
 
-import mmengine
-from mmengine.evaluator import BaseMetric
-from rapidfuzz.distance import Levenshtein
+from mmeval import CharRecallPrecision as _CharRecallPrecision
+from mmeval import OneMinusNormEditDistance as _OneMinusNormEditDistance
+from mmeval import WordAccuracy as _WordAccuracy
 
 from mmocr.registry import METRICS
 
 
+class TextRecogMixin:
+    """Mixin that adapts an MMEval metric to MMOCR's evaluation interface.
+
+    It feeds MMOCR data samples to the underlying metric via ``self.add``
+    and prefixes the returned metric names with the metric class name.
+    """
+
+    def __init__(self,
+                 prefix: Optional[str] = None,
+                 collect_device: Optional[str] = None) -> None:
+        if prefix is not None:
+            warnings.warn('DeprecationWarning: The `prefix` parameter of'
+                          f' `{self.name}` is deprecated.')
+        if collect_device is not None:
+            warnings.warn(
+                'DeprecationWarning: The `collect_device` parameter of '
+                f'`{self.name}` is deprecated, use `dist_backend` instead.')
+
+    def process(self, data_batch: Sequence[Dict],
+                predictions: Sequence[Dict]) -> None:
+        """Process one batch of predictions. The parsed prediction and
+        ground-truth texts are passed to ``self.add``, which accumulates
+        them until the metrics are computed.
+
+        Args:
+            data_batch (Sequence[Dict]): A batch of gts.
+            predictions (Sequence[Dict]): A batch of outputs from the model.
+        """
+        preds, labels = list(), list()
+        for data_sample in predictions:
+            preds.append(data_sample.get('pred_text').get('item'))
+            labels.append(data_sample.get('gt_text').get('item'))
+        self.add(preds, labels)
+
+    def evaluate(self, size: int) -> Dict:
+        """Compute the metrics from all accumulated results, prefix each
+        metric name with the metric class name, and reset the state."""
+        metric_results = self.compute(size)
+        metric_results = {
+            f'{self.name}/{k}': v
+            for k, v in metric_results.items()
+        }
+        self.reset()
+        return metric_results
+
+
 @METRICS.register_module()
-class WordMetric(BaseMetric):
-    """Word metrics for text recognition task.
+class WordMetric(_WordAccuracy, TextRecogMixin):
+    """Calculate the word level accuracy.
 
     Args:
         mode (str or list[str]): Options are:
@@ -22,271 +60,72 @@ class WordMetric(BaseMetric):
         - 'ignore_case_symbol': Accuracy at word level, ignoring letter
           case and symbol. (Default metric for academic evaluation)
         If mode is a list, then metrics in mode will be calculated
-        separately. Defaults to 'ignore_case_symbol'
+        separately. Defaults to 'ignore_case_symbol'.
         valid_symbol (str): Valid characters. Defaults to
-            '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'
-        collect_device (str): Device name used for collecting results from
-            different ranks during distributed training. Must be 'cpu' or
-            'gpu'. Defaults to 'cpu'.
-        prefix (str, optional): The prefix that will be added in the metric
-            names to disambiguate homonymous metrics of different evaluators.
-            If prefix is not provided in the argument, self.default_prefix
-            will be used instead. Defaults to None.
+            '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'.
     """
-    default_prefix: Optional[str] = 'recog'
 
     def __init__(self,
                  mode: Union[str, Sequence[str]] = 'ignore_case_symbol',
                  valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
-                 collect_device: str = 'cpu',
-                 prefix: Optional[str] = None) -> None:
-        super().__init__(collect_device, prefix)
-        self.valid_symbol = re.compile(valid_symbol)
-        if isinstance(mode, str):
-            mode = [mode]
-        assert mmengine.is_seq_of(mode, str)
-        assert set(mode).issubset(
-            {'exact', 'ignore_case', 'ignore_case_symbol'})
-        self.mode = set(mode)
-
-    def process(self, data_batch: Sequence[Dict],
-                data_samples: Sequence[Dict]) -> None:
-        """Process one batch of data_samples. 
The processed results should be
-        stored in ``self.results``, which will be used to compute the metrics
-        when all batches have been processed.
-
-        Args:
-            data_batch (Sequence[Dict]): A batch of gts.
-            data_samples (Sequence[Dict]): A batch of outputs from the model.
-        """
-        for data_sample in data_samples:
-            match_num = 0
-            match_ignore_case_num = 0
-            match_ignore_case_symbol_num = 0
-            pred_text = data_sample.get('pred_text').get('item')
-            gt_text = data_sample.get('gt_text').get('item')
-            if 'ignore_case' in self.mode or 'ignore_case_symbol' in self.mode:
-                pred_text_lower = pred_text.lower()
-                gt_text_lower = gt_text.lower()
-            if 'ignore_case_symbol' in self.mode:
-                gt_text_lower_ignore = self.valid_symbol.sub('', gt_text_lower)
-                pred_text_lower_ignore = self.valid_symbol.sub(
-                    '', pred_text_lower)
-                match_ignore_case_symbol_num =\
-                    gt_text_lower_ignore == pred_text_lower_ignore
-            if 'ignore_case' in self.mode:
-                match_ignore_case_num = pred_text_lower == gt_text_lower
-            if 'exact' in self.mode:
-                match_num = pred_text == gt_text
-            result = dict(
-                match_num=match_num,
-                match_ignore_case_num=match_ignore_case_num,
-                match_ignore_case_symbol_num=match_ignore_case_symbol_num)
-            self.results.append(result)
-
-    def compute_metrics(self, results: Sequence[Dict]) -> Dict:
-        """Compute the metrics from processed results.
-
-        Args:
-            results (list[Dict]): The processed results of each batch.
-
-        Returns:
-            Dict: The computed metrics. The keys are the names of the metrics,
-            and the values are corresponding results.
-        """
-
-        eps = 1e-8
-        eval_res = {}
-        gt_word_num = len(results)
-        if 'exact' in self.mode:
-            match_nums = [result['match_num'] for result in results]
-            match_nums = sum(match_nums)
-            eval_res['word_acc'] = 1.0 * match_nums / (eps + gt_word_num)
-        if 'ignore_case' in self.mode:
-            match_ignore_case_num = [
-                result['match_ignore_case_num'] for result in results
-            ]
-            match_ignore_case_num = sum(match_ignore_case_num)
-            eval_res['word_acc_ignore_case'] = 1.0 *\
-                match_ignore_case_num / (eps + gt_word_num)
-        if 'ignore_case_symbol' in self.mode:
-            match_ignore_case_symbol_num = [
-                result['match_ignore_case_symbol_num'] for result in results
-            ]
-            match_ignore_case_symbol_num = sum(match_ignore_case_symbol_num)
-            eval_res['word_acc_ignore_case_symbol'] = 1.0 *\
-                match_ignore_case_symbol_num / (eps + gt_word_num)
-
-        for key, value in eval_res.items():
-            eval_res[key] = float(f'{value:.4f}')
-        return eval_res
+                 **kwargs) -> None:
+        collect_device = kwargs.pop('collect_device', None)
+        prefix = kwargs.pop('prefix', None)
+        super().__init__(mode=mode, valid_symbol=valid_symbol, **kwargs)
+        TextRecogMixin.__init__(self, prefix, collect_device)
 
 
 @METRICS.register_module()
-class CharMetric(BaseMetric):
-    """Character metrics for text recognition task.
+class CharMetric(_CharRecallPrecision, TextRecogMixin):
+    """Calculate the char level recall & precision.
 
     Args:
-        valid_symbol (str): Valid characters.
-            Defaults to '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'
-        collect_device (str): Device name used for collecting results from
-            different ranks during distributed training. Must be 'cpu' or
-            'gpu'. Defaults to 'cpu'.
-        prefix (str, optional): The prefix that will be added in the metric
-            names to disambiguate homonymous metrics of different evaluators.
-            If prefix is not provided in the argument, self.default_prefix
-            will be used instead. Defaults to None.
+        letter_case (str): There are three options to alter the letter cases:
+            - unchanged: Do not change prediction texts and labels.
+            - upper: Convert prediction texts and labels into uppercase
+              characters.
+            - lower: Convert prediction texts and labels into lowercase
+              characters.
+            Usually, it only works for English characters. Defaults to
+            'lower'.
+        valid_symbol (str): Valid characters. Defaults to
+            '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'.
     """
-    default_prefix: Optional[str] = 'recog'
 
     def __init__(self,
+                 letter_case: str = 'lower',
                  valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
-                 collect_device: str = 'cpu',
-                 prefix: Optional[str] = None) -> None:
-        super().__init__(collect_device, prefix)
-        self.valid_symbol = re.compile(valid_symbol)
-
-    def process(self, data_batch: Sequence[Dict],
-                data_samples: Sequence[Dict]) -> None:
-        """Process one batch of data_samples. The processed results should be
-        stored in ``self.results``, which will be used to compute the metrics
-        when all batches have been processed.
-
-        Args:
-            data_batch (Sequence[Dict]): A batch of gts.
-            data_samples (Sequence[Dict]): A batch of outputs from the model.
-        """
-        for data_sample in data_samples:
-            pred_text = data_sample.get('pred_text').get('item')
-            gt_text = data_sample.get('gt_text').get('item')
-            gt_text_lower = gt_text.lower()
-            pred_text_lower = pred_text.lower()
-            gt_text_lower_ignore = self.valid_symbol.sub('', gt_text_lower)
-            pred_text_lower_ignore = self.valid_symbol.sub('', pred_text_lower)
-            # number to calculate char level recall & precision
-            result = dict(
-                gt_char_num=len(gt_text_lower_ignore),
-                pred_char_num=len(pred_text_lower_ignore),
-                true_positive_char_num=self._cal_true_positive_char(
-                    pred_text_lower_ignore, gt_text_lower_ignore))
-            self.results.append(result)
-
-    def compute_metrics(self, results: Sequence[Dict]) -> Dict:
-        """Compute the metrics from processed results.
-
-        Args:
-            results (list[Dict]): The processed results of each batch.
-
-        Returns:
-            Dict: The computed metrics. The keys are the names of the
-            metrics, and the values are corresponding results.
-        """
-        gt_char_num = [result['gt_char_num'] for result in results]
-        pred_char_num = [result['pred_char_num'] for result in results]
-        true_positive_char_num = [
-            result['true_positive_char_num'] for result in results
-        ]
-        gt_char_num = sum(gt_char_num)
-        pred_char_num = sum(pred_char_num)
-        true_positive_char_num = sum(true_positive_char_num)
-
-        eps = 1e-8
-        char_recall = 1.0 * true_positive_char_num / (eps + gt_char_num)
-        char_precision = 1.0 * true_positive_char_num / (eps + pred_char_num)
-        eval_res = {}
-        eval_res['char_recall'] = char_recall
-        eval_res['char_precision'] = char_precision
-
-        for key, value in eval_res.items():
-            eval_res[key] = float(f'{value:.4f}')
-        return eval_res
-
-    def _cal_true_positive_char(self, pred: str, gt: str) -> int:
-        """Calculate correct character number in prediction.
-
-        Args:
-            pred (str): Prediction text.
-            gt (str): Ground truth text.
-
-        Returns:
-            true_positive_char_num (int): The true positive number.
- """ - - all_opt = SequenceMatcher(None, pred, gt) - true_positive_char_num = 0 - for opt, _, _, s2, e2 in all_opt.get_opcodes(): - if opt == 'equal': - true_positive_char_num += (e2 - s2) - else: - pass - return true_positive_char_num + **kwargs) -> None: + collect_device = kwargs.pop('collect_device', None) + prefix = kwargs.pop('prefix', None) + super().__init__( + letter_case=letter_case, valid_symbol=valid_symbol, **kwargs) + TextRecogMixin.__init__(collect_device, prefix) @METRICS.register_module() -class OneMinusNEDMetric(BaseMetric): +class OneMinusNEDMetric(_OneMinusNormEditDistance, TextRecogMixin): """One minus NED metric for text recognition task. Args: + letter_case (str): There are three options to alter the letter cases + - unchanged: Do not change prediction texts and labels. + - upper: Convert prediction texts and labels into uppercase + characters. + - lower: Convert prediction texts and labels into lowercase + characters. + Usually, it only works for English characters. Defaults to + 'unchanged'. valid_symbol (str): Valid characters. Defaults to - '[^A-Z^a-z^0-9^\u4e00-\u9fa5]' - collect_device (str): Device name used for collecting results from - different ranks during distributed training. Must be 'cpu' or - 'gpu'. Defaults to 'cpu'. - prefix (str, optional): The prefix that will be added in the metric - names to disambiguate homonymous metrics of different evaluators. - If prefix is not provided in the argument, self.default_prefix - will be used instead. Defaults to None + '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'. """ - default_prefix: Optional[str] = 'recog' def __init__(self, + letter_case: str = 'lower', valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]', - collect_device: str = 'cpu', - prefix: Optional[str] = None) -> None: - super().__init__(collect_device, prefix) - self.valid_symbol = re.compile(valid_symbol) - - def process(self, data_batch: Sequence[Dict], - data_samples: Sequence[Dict]) -> None: - """Process one batch of data_samples. The processed results should be - stored in ``self.results``, which will be used to compute the metrics - when all batches have been processed. - - Args: - data_batch (Sequence[Dict]): A batch of gts. - data_samples (Sequence[Dict]): A batch of outputs from the model. - """ - for data_sample in data_samples: - pred_text = data_sample.get('pred_text').get('item') - gt_text = data_sample.get('gt_text').get('item') - gt_text_lower = gt_text.lower() - pred_text_lower = pred_text.lower() - gt_text_lower_ignore = self.valid_symbol.sub('', gt_text_lower) - pred_text_lower_ignore = self.valid_symbol.sub('', pred_text_lower) - norm_ed = Levenshtein.normalized_distance(pred_text_lower_ignore, - gt_text_lower_ignore) - result = dict(norm_ed=norm_ed) - self.results.append(result) - - def compute_metrics(self, results: Sequence[Dict]) -> Dict: - """Compute the metrics from processed results. - - Args: - results (list[Dict]): The processed results of each batch. - - Returns: - Dict: The computed metrics. The keys are the names of the - metrics, and the values are corresponding results. 
- """ - - gt_word_num = len(results) - norm_ed = [result['norm_ed'] for result in results] - norm_ed_sum = sum(norm_ed) - normalized_edit_distance = norm_ed_sum / max(1, gt_word_num) - eval_res = {} - eval_res['1-N.E.D'] = 1.0 - normalized_edit_distance - for key, value in eval_res.items(): - eval_res[key] = float(f'{value:.4f}') - return eval_res + **kwargs) -> None: + collect_device = kwargs.pop('collect_device', None) + prefix = kwargs.pop('prefix', None) + super().__init__( + letter_case=letter_case, valid_symbol=valid_symbol, **kwargs) + TextRecogMixin.__init__(collect_device, prefix) diff --git a/tests/test_evaluation/test_metrics/test_recog_metric.py b/tests/test_evaluation/test_metrics/test_recog_metric.py index cd982b160..d9947f456 100644 --- a/tests/test_evaluation/test_metrics/test_recog_metric.py +++ b/tests/test_evaluation/test_metrics/test_recog_metric.py @@ -41,30 +41,32 @@ def test_word_acc_metric(self): metric = WordMetric(mode='exact') metric.process(None, self.pred) eval_res = metric.evaluate(size=3) - self.assertAlmostEqual(eval_res['recog/word_acc'], 1. / 3, 4) + self.assertAlmostEqual(eval_res['WordMetric/accuracy'], 1. / 3, 4) def test_word_acc_ignore_case_metric(self): metric = WordMetric(mode='ignore_case') metric.process(None, self.pred) eval_res = metric.evaluate(size=3) - self.assertAlmostEqual(eval_res['recog/word_acc_ignore_case'], 2. / 3, - 4) + self.assertAlmostEqual(eval_res['WordMetric/ignore_case_accuracy'], + 2. / 3, 4) def test_word_acc_ignore_case_symbol_metric(self): metric = WordMetric(mode='ignore_case_symbol') metric.process(None, self.pred) eval_res = metric.evaluate(size=3) - self.assertEqual(eval_res['recog/word_acc_ignore_case_symbol'], 1.0) + self.assertEqual(eval_res['WordMetric/ignore_case_symbol_accuracy'], + 1.0) def test_all_metric(self): metric = WordMetric( mode=['exact', 'ignore_case', 'ignore_case_symbol']) metric.process(None, self.pred) eval_res = metric.evaluate(size=3) - self.assertAlmostEqual(eval_res['recog/word_acc'], 1. / 3, 4) - self.assertAlmostEqual(eval_res['recog/word_acc_ignore_case'], 2. / 3, - 4) - self.assertEqual(eval_res['recog/word_acc_ignore_case_symbol'], 1.0) + self.assertAlmostEqual(eval_res['WordMetric/accuracy'], 1. / 3, 4) + self.assertAlmostEqual(eval_res['WordMetric/ignore_case_accuracy'], + 2. / 3, 4) + self.assertEqual(eval_res['WordMetric/ignore_case_symbol_accuracy'], + 1.0) class TestCharMetric(unittest.TestCase): @@ -92,8 +94,8 @@ def test_char_recall_precision_metric(self): metric = CharMetric() metric.process(None, self.pred) eval_res = metric.evaluate(size=2) - self.assertEqual(eval_res['recog/char_recall'], 0.7) - self.assertEqual(eval_res['recog/char_precision'], 1) + self.assertEqual(eval_res['CharMetric/recall'], 0.7) + self.assertEqual(eval_res['CharMetric/precision'], 1) class TestOneMinusNED(unittest.TestCase): @@ -121,4 +123,4 @@ def test_one_minus_ned_metric(self): metric = OneMinusNEDMetric() metric.process(None, self.pred) eval_res = metric.evaluate(size=2) - self.assertEqual(eval_res['recog/1-N.E.D'], 0.4875) + self.assertEqual(eval_res['OneMinusNEDMetric/1-N.E.D'], 0.4875)