Skip to content

Commit

Permalink
优化ser/re packer,根据words关键字是否存在觉得是否加入
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinNuNu committed Mar 30, 2023
1 parent 1d0c5e3 commit deb96cc
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 72 deletions.
80 changes: 24 additions & 56 deletions mmocr/datasets/preparers/packers/re_packer.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, List, Tuple
import warnings
from typing import Dict, Tuple

import mmcv

from mmocr.registry import DATA_PACKERS
from .base import BasePacker
from .ser_packer import SERPacker


@DATA_PACKERS.register_module()
class REPacker(BasePacker):
class REPacker(SERPacker):
"""Relation Extraction packer. It is used to pack the parsed annotation
info to.
Expand All @@ -18,8 +19,6 @@ class REPacker(BasePacker):
{
"metainfo":
{
"dataset_type": "REDataset",
"task_name": "re",
"labels": ['answer', 'header', 'other', 'question'],
"id2label": {
"0": "O",
Expand Down Expand Up @@ -49,8 +48,8 @@ class REPacker(BasePacker):
"instances":
{
"texts": ["绩效目标申报表(一级项目)", "项目名称", ...],
"bboxes": [[906,195,1478,259],
[357,325,467,357], ...],
"boxes": [[906,195,1478,259],
[357,325,467,357], ...],
"labels": ["header", "question", ...],
"linkings": [[0, 1], [2, 3], ...],
"ids": [0, 1, ...],
Expand Down Expand Up @@ -104,75 +103,44 @@ def pack_instance(self, sample: Tuple) -> Dict:
h, w = img.shape[:2]

texts_per_doc = []
bboxes_per_doc = []
boxes_per_doc = []
labels_per_doc = []
words_per_doc = []
linking_per_doc = []
id_per_doc = []
has_words = all(['words' in ins for ins in instances])
if has_words:
words_per_doc = []
else:
warnings.warn(
'Not all instance has `words` key,'
'so final MMOCR format SER instance will not have `words` key')

for instance in instances:
text = instance.get('text', None)
box = instance.get('box', None)
label = instance.get('label', None)
linking = instance.get('linking', None)
ins_id = instance.get('id', None)
words = instance.get('words', None)
assert text or box or label
assert text or box or label or linking or ins_id
texts_per_doc.append(text)
bboxes_per_doc.append(box)
boxes_per_doc.append(box)
labels_per_doc.append(label)
words_per_doc.append(words)
linking_per_doc.append(linking)
id_per_doc.append(ins_id)
if has_words:
words = instance.get('words', None)
words_per_doc.append(words)
packed_instances = dict(
instances=dict(
texts=texts_per_doc,
bboxes=bboxes_per_doc,
boxes=boxes_per_doc,
labels=labels_per_doc,
linkings=linking_per_doc,
ids=id_per_doc,
words=words_per_doc),
ids=id_per_doc),
img_path=osp.relpath(img_path, self.data_root),
height=h,
width=w)
if has_words:
packed_instances['instances'].update({'words': words_per_doc})

return packed_instances

def add_meta(self, sample: List) -> Dict:
"""Add meta information to the sample.
Args:
sample (List): A list of samples of the dataset.
Returns:
Dict: A dict contains the meta information and samples.
"""

def get_BIO_label_list(labels):
bio_label_list = []
for label in labels:
if label == 'other':
bio_label_list.insert(0, 'O')
else:
bio_label_list.append(f'B-{label.upper()}')
bio_label_list.append(f'I-{label.upper()}')
return bio_label_list

labels = []
for s in sample:
labels += s['instances']['labels']
org_label_list = list(set(labels))
bio_label_list = get_BIO_label_list(org_label_list)

meta = {
'metainfo': {
'dataset_type': 'REDataset',
'task_name': 're',
'labels': org_label_list,
'id2label': {k: v
for k, v in enumerate(bio_label_list)},
'label2id': {v: k
for k, v in enumerate(bio_label_list)}
},
'data_list': sample
}
return meta
38 changes: 22 additions & 16 deletions mmocr/datasets/preparers/packers/ser_packer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Dict, List, Tuple

import mmcv
Expand All @@ -18,8 +19,6 @@ class SERPacker(BasePacker):
{
"metainfo":
{
"dataset_type": "SERDataset",
"task_name": "ser",
"labels": ['answer', 'header', 'other', 'question'],
"id2label": {
"0": "O",
Expand Down Expand Up @@ -49,8 +48,8 @@ class SERPacker(BasePacker):
"instances":
{
"texts": ["绩效目标申报表(一级项目)", "项目名称", ...],
"bboxes": [[906,195,1478,259],
[357,325,467,357], ...],
"boxes": [[906,195,1478,259],
[357,325,467,357], ...],
"labels": ["header", "question", ...],
"words": [[{
"box": [
Expand Down Expand Up @@ -100,28 +99,37 @@ def pack_instance(self, sample: Tuple) -> Dict:
h, w = img.shape[:2]

texts_per_doc = []
bboxes_per_doc = []
boxes_per_doc = []
labels_per_doc = []
words_per_doc = []
has_words = all(['words' in ins for ins in instances])
if has_words:
words_per_doc = []
else:
warnings.warn(
'Not all instance has `words` key,'
'so final MMOCR format SER instance will not have `words` key')

for instance in instances:
text = instance.get('text', None)
box = instance.get('box', None)
label = instance.get('label', None)
words = instance.get('words', None)
assert text or box or label
texts_per_doc.append(text)
bboxes_per_doc.append(box)
boxes_per_doc.append(box)
labels_per_doc.append(label)
words_per_doc.append(words)
if has_words:
words = instance.get('words', None)
words_per_doc.append(words)
packed_instances = dict(
instances=dict(
texts=texts_per_doc,
bboxes=bboxes_per_doc,
labels=labels_per_doc,
words=words_per_doc),
boxes=boxes_per_doc,
labels=labels_per_doc),
img_path=osp.relpath(img_path, self.data_root),
height=h,
width=w)
if has_words:
packed_instances['instances'].update({'words': words_per_doc})

return packed_instances

Expand All @@ -135,7 +143,7 @@ def add_meta(self, sample: List) -> Dict:
Dict: A dict contains the meta information and samples.
"""

def get_BIO_label_list(labels):
def get_bio_label_list(labels):
bio_label_list = []
for label in labels:
if label == 'other':
Expand All @@ -149,12 +157,10 @@ def get_BIO_label_list(labels):
for s in sample:
labels += s['instances']['labels']
org_label_list = list(set(labels))
bio_label_list = get_BIO_label_list(org_label_list)
bio_label_list = get_bio_label_list(org_label_list)

meta = {
'metainfo': {
'dataset_type': 'SERDataset',
'task_name': 'ser',
'labels': org_label_list,
'id2label': {k: v
for k, v in enumerate(bio_label_list)},
Expand Down

0 comments on commit deb96cc

Please sign in to comment.