Skip to content

Commit

Permalink
fix comment
Browse files Browse the repository at this point in the history
  • Loading branch information
Harold-lkk committed Mar 13, 2023
1 parent a866d9c commit bace3bc
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions mmeval/metrics/char_recall_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ class CharRecallPrecision(BaseMetric):
- unchanged: Do not change prediction texts and labels.
- upper: Convert prediction texts and labels into uppercase
characters.
characters.
- lower: Convert prediction texts and labels into lowercase
characters.
characters.
Usually, it only works for English characters. Defaults to
'unchanged'.
invalid_symbol (str): A regular expression to filter out invalid or
not cared characters. Defaults to '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'.
not cared characters. Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]'.
**kwargs: Keyword parameters passed to :class:`BaseMetric`.
Examples:
Expand All @@ -36,21 +36,21 @@ class CharRecallPrecision(BaseMetric):

def __init__(self,
letter_case: str = 'unchanged',
invalid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
invalid_symbol: str = '[^A-Za-z0-9\u4e00-\u9fa5]',
**kwargs):
super().__init__(**kwargs)
assert letter_case in ['unchanged', 'upper', 'lower']
self.letter_case = letter_case
self.invalid_symbol = re.compile(invalid_symbol)

def add(self, predictions: Sequence[str], labels: Sequence[str]) -> None: # type: ignore # yapf: disable # noqa: E501
def add(self, predictions: Sequence[str], groundtruths: Sequence[str]) -> None: # type: ignore # yapf: disable # noqa: E501
"""Process one batch of data and predictions.
Args:
predictions (list[str]): The prediction texts.
labels (list[str]): The ground truth texts.
groundtruths (list[str]): The ground truth texts.
"""
for pred, label in zip(predictions, labels):
for pred, label in zip(predictions, groundtruths):
if self.letter_case in ['upper', 'lower']:
pred = getattr(pred, self.letter_case)()
label = getattr(label, self.letter_case)()
Expand Down Expand Up @@ -79,10 +79,10 @@ def compute_metric(self, results: Sequence[Tuple[int, int, int]]) -> Dict:
true_positive_sum += true_positive
char_recall = true_positive_sum / max(gt_sum, 1.0)
char_precision = true_positive_sum / max(pred_sum, 1.0)
eval_res = {}
eval_res['recall'] = char_recall
eval_res['precision'] = char_precision
return eval_res
metric_results = {}
metric_results['recall'] = char_recall
metric_results['precision'] = char_precision
return metric_results

def _cal_true_positive_char(self, pred: str, gt: str) -> int:
"""Calculate correct character number in prediction.
Expand Down

0 comments on commit bace3bc

Please sign in to comment.