diff --git a/mmocr/datasets/preparers/packers/re_packer.py b/mmocr/datasets/preparers/packers/re_packer.py index 5f3c12fe7..54edce73d 100644 --- a/mmocr/datasets/preparers/packers/re_packer.py +++ b/mmocr/datasets/preparers/packers/re_packer.py @@ -1,15 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp -from typing import Dict, List, Tuple +import warnings +from typing import Dict, Tuple import mmcv from mmocr.registry import DATA_PACKERS -from .base import BasePacker +from .ser_packer import SERPacker @DATA_PACKERS.register_module() -class REPacker(BasePacker): +class REPacker(SERPacker): """Relation Extraction packer. It is used to pack the parsed annotation info to. @@ -18,8 +19,6 @@ class REPacker(BasePacker): { "metainfo": { - "dataset_type": "REDataset", - "task_name": "re", "labels": ['answer', 'header', 'other', 'question'], "id2label": { "0": "O", @@ -49,8 +48,8 @@ class REPacker(BasePacker): "instances": { "texts": ["绩效目标申报表(一级项目)", "项目名称", ...], - "bboxes": [[906,195,1478,259], - [357,325,467,357], ...], + "boxes": [[906,195,1478,259], + [357,325,467,357], ...], "labels": ["header", "question", ...], "linkings": [[0, 1], [2, 3], ...], "ids": [0, 1, ...], @@ -104,75 +103,44 @@ def pack_instance(self, sample: Tuple) -> Dict: h, w = img.shape[:2] texts_per_doc = [] - bboxes_per_doc = [] + boxes_per_doc = [] labels_per_doc = [] - words_per_doc = [] linking_per_doc = [] id_per_doc = [] + has_words = all(['words' in ins for ins in instances]) + if has_words: + words_per_doc = [] + else: + warnings.warn( + 'Not all instance has `words` key,' + 'so final MMOCR format SER instance will not have `words` key') + for instance in instances: text = instance.get('text', None) box = instance.get('box', None) label = instance.get('label', None) linking = instance.get('linking', None) ins_id = instance.get('id', None) - words = instance.get('words', None) - assert text or box or label + assert text or box or label or linking or ins_id texts_per_doc.append(text) - bboxes_per_doc.append(box) + boxes_per_doc.append(box) labels_per_doc.append(label) - words_per_doc.append(words) linking_per_doc.append(linking) id_per_doc.append(ins_id) + if has_words: + words = instance.get('words', None) + words_per_doc.append(words) packed_instances = dict( instances=dict( texts=texts_per_doc, - bboxes=bboxes_per_doc, + boxes=boxes_per_doc, labels=labels_per_doc, linkings=linking_per_doc, - ids=id_per_doc, - words=words_per_doc), + ids=id_per_doc), img_path=osp.relpath(img_path, self.data_root), height=h, width=w) + if has_words: + packed_instances['instances'].update({'words': words_per_doc}) return packed_instances - - def add_meta(self, sample: List) -> Dict: - """Add meta information to the sample. - - Args: - sample (List): A list of samples of the dataset. - - Returns: - Dict: A dict contains the meta information and samples. - """ - - def get_BIO_label_list(labels): - bio_label_list = [] - for label in labels: - if label == 'other': - bio_label_list.insert(0, 'O') - else: - bio_label_list.append(f'B-{label.upper()}') - bio_label_list.append(f'I-{label.upper()}') - return bio_label_list - - labels = [] - for s in sample: - labels += s['instances']['labels'] - org_label_list = list(set(labels)) - bio_label_list = get_BIO_label_list(org_label_list) - - meta = { - 'metainfo': { - 'dataset_type': 'REDataset', - 'task_name': 're', - 'labels': org_label_list, - 'id2label': {k: v - for k, v in enumerate(bio_label_list)}, - 'label2id': {v: k - for k, v in enumerate(bio_label_list)} - }, - 'data_list': sample - } - return meta diff --git a/mmocr/datasets/preparers/packers/ser_packer.py b/mmocr/datasets/preparers/packers/ser_packer.py index 1b1d8528f..3db633bfa 100644 --- a/mmocr/datasets/preparers/packers/ser_packer.py +++ b/mmocr/datasets/preparers/packers/ser_packer.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp +import warnings from typing import Dict, List, Tuple import mmcv @@ -18,8 +19,6 @@ class SERPacker(BasePacker): { "metainfo": { - "dataset_type": "SERDataset", - "task_name": "ser", "labels": ['answer', 'header', 'other', 'question'], "id2label": { "0": "O", @@ -49,8 +48,8 @@ class SERPacker(BasePacker): "instances": { "texts": ["绩效目标申报表(一级项目)", "项目名称", ...], - "bboxes": [[906,195,1478,259], - [357,325,467,357], ...], + "boxes": [[906,195,1478,259], + [357,325,467,357], ...], "labels": ["header", "question", ...], "words": [[{ "box": [ @@ -100,28 +99,37 @@ def pack_instance(self, sample: Tuple) -> Dict: h, w = img.shape[:2] texts_per_doc = [] - bboxes_per_doc = [] + boxes_per_doc = [] labels_per_doc = [] - words_per_doc = [] + has_words = all(['words' in ins for ins in instances]) + if has_words: + words_per_doc = [] + else: + warnings.warn( + 'Not all instance has `words` key,' + 'so final MMOCR format SER instance will not have `words` key') + for instance in instances: text = instance.get('text', None) box = instance.get('box', None) label = instance.get('label', None) - words = instance.get('words', None) assert text or box or label texts_per_doc.append(text) - bboxes_per_doc.append(box) + boxes_per_doc.append(box) labels_per_doc.append(label) - words_per_doc.append(words) + if has_words: + words = instance.get('words', None) + words_per_doc.append(words) packed_instances = dict( instances=dict( texts=texts_per_doc, - bboxes=bboxes_per_doc, - labels=labels_per_doc, - words=words_per_doc), + boxes=boxes_per_doc, + labels=labels_per_doc), img_path=osp.relpath(img_path, self.data_root), height=h, width=w) + if has_words: + packed_instances['instances'].update({'words': words_per_doc}) return packed_instances @@ -135,7 +143,7 @@ def add_meta(self, sample: List) -> Dict: Dict: A dict contains the meta information and samples. """ - def get_BIO_label_list(labels): + def get_bio_label_list(labels): bio_label_list = [] for label in labels: if label == 'other': @@ -149,12 +157,10 @@ def get_BIO_label_list(labels): for s in sample: labels += s['instances']['labels'] org_label_list = list(set(labels)) - bio_label_list = get_BIO_label_list(org_label_list) + bio_label_list = get_bio_label_list(org_label_list) meta = { 'metainfo': { - 'dataset_type': 'SERDataset', - 'task_name': 'ser', 'labels': org_label_list, 'id2label': {k: v for k, v in enumerate(bio_label_list)},