From 551086edfcbd71a82aaef536b8ca9d0e257c6e7a Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Mon, 29 May 2023 10:47:21 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96ser=5Fpostprocessor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ser/layoutlmv3_1k_xfund_zh_1xbs8.py | 5 +- .../transforms/layoutlmv3_transforms.py | 4 +- .../LayoutLMv3/models/ser_postprocessor.py | 110 ++++++++++++------ .../visualization/ser_visualizer.py | 51 +++----- 4 files changed, 96 insertions(+), 74 deletions(-) diff --git a/projects/LayoutLMv3/configs/ser/layoutlmv3_1k_xfund_zh_1xbs8.py b/projects/LayoutLMv3/configs/ser/layoutlmv3_1k_xfund_zh_1xbs8.py index 76cacde9b..304156810 100644 --- a/projects/LayoutLMv3/configs/ser/layoutlmv3_1k_xfund_zh_1xbs8.py +++ b/projects/LayoutLMv3/configs/ser/layoutlmv3_1k_xfund_zh_1xbs8.py @@ -108,7 +108,10 @@ pretrained_model_name_or_path=hf_pretrained_model, num_labels=len(class_name) * 2 - 1), loss_processor=dict(type='ComputeLossAfterLabelSmooth'), - postprocessor=dict(type='SERPostprocessor', classes=class_name)) + postprocessor=dict( + type='SERPostprocessor', + classes=class_name, + only_label_first_subword=only_label_first_subword)) # ==================================================================== # ========================= Evaluation =============================== val_evaluator = dict(type='SeqevalMetric', prefix=dataset_name) diff --git a/projects/LayoutLMv3/datasets/transforms/layoutlmv3_transforms.py b/projects/LayoutLMv3/datasets/transforms/layoutlmv3_transforms.py index 2bed95708..02684a2b0 100644 --- a/projects/LayoutLMv3/datasets/transforms/layoutlmv3_transforms.py +++ b/projects/LayoutLMv3/datasets/transforms/layoutlmv3_transforms.py @@ -220,10 +220,12 @@ class ConvertBIOLabelForSER(BaseTransform): def __init__(self, classes: Union[tuple, list], - only_label_first_subword: bool = False) -> None: + only_label_first_subword: bool = True) -> None: super().__init__() self.other_label_name = find_other_label_name_of_biolabel(classes) self.biolabel2id = self._generate_biolabel2id_map(classes) + assert only_label_first_subword is True, \ + 'Only support `only_label_first_subword=True` now.' self.only_label_first_subword = only_label_first_subword def _generate_biolabel2id_map(self, classes: Union[tuple, list]) -> Dict: diff --git a/projects/LayoutLMv3/models/ser_postprocessor.py b/projects/LayoutLMv3/models/ser_postprocessor.py index a70c2ae82..94f2073db 100644 --- a/projects/LayoutLMv3/models/ser_postprocessor.py +++ b/projects/LayoutLMv3/models/ser_postprocessor.py @@ -16,10 +16,15 @@ class SERPostprocessor(nn.Module): """PostProcessor for SER.""" - def __init__(self, classes: Union[tuple, list]) -> None: + def __init__(self, + classes: Union[tuple, list], + only_label_first_subword: bool = True) -> None: super().__init__() self.other_label_name = find_other_label_name_of_biolabel(classes) self.id2biolabel = self._generate_id2biolabel_map(classes) + assert only_label_first_subword is True, \ + 'Only support `only_label_first_subword=True` now.' + self.only_label_first_subword = only_label_first_subword self.softmax = nn.Softmax(dim=-1) def _generate_id2biolabel_map(self, classes: Union[tuple, list]) -> Dict: @@ -40,62 +45,95 @@ def _generate_id2biolabel_map(self, classes: Union[tuple, list]) -> Dict: def __call__(self, outputs: torch.Tensor, data_samples: Sequence[SERDataSample] ) -> Sequence[SERDataSample]: - # merge several truncation data_sample to one data_sample assert all('truncation_word_ids' in d for d in data_samples), \ 'The key `truncation_word_ids` should be specified' \ 'in PackSERInputs.' - truncation_word_ids = [] - for data_sample in data_samples: - truncation_word_ids.append(data_sample.pop('truncation_word_ids')) - merged_data_sample = copy.deepcopy(data_samples[0]) - merged_data_sample.set_metainfo( - dict(truncation_word_ids=truncation_word_ids)) - flattened_word_ids = [ - word_id for word_ids in truncation_word_ids for word_id in word_ids + truncation_word_ids = [ + data_sample.pop('truncation_word_ids') + for data_sample in data_samples + ] + word_ids = [ + word_id for word_ids in truncation_word_ids + for word_id in word_ids[1:-1] ] + # merge several truncation data_sample to one data_sample + merged_data_sample = copy.deepcopy(data_samples[0]) + # convert outputs dim from (truncation_num, max_length, label_num) # to (truncation_num * max_length, label_num) outputs = outputs.cpu().detach() - outputs = torch.reshape(outputs, (-1, outputs.size(-1))) + outputs = torch.reshape(outputs[:, 1:-1, :], (-1, outputs.size(-1))) # get pred label ids/scores from outputs probs = self.softmax(outputs) max_value, max_idx = torch.max(probs, -1) pred_label_ids = max_idx.numpy() pred_label_scores = max_value.numpy() + # inference process do not have item in gt_label, + # so select valid token with word_ids rather than + # with gt_label_ids like official code. + pred_words_biolabels = [] + word_biolabels = [] + pre_word_id = None + for idx, cur_word_id in enumerate(word_ids): + if cur_word_id is not None: + if cur_word_id != pre_word_id: + if word_biolabels: + pred_words_biolabels.append(word_biolabels) + word_biolabels = [] + word_biolabels.append((self.id2biolabel[pred_label_ids[idx]], + pred_label_scores[idx])) + else: + pred_words_biolabels.append(word_biolabels) + break + pre_word_id = cur_word_id + # record pred_label + if self.only_label_first_subword: + pred_label = LabelData() + pred_label.item = [ + pred_word_biolabels[0][0] + for pred_word_biolabels in pred_words_biolabels + ] + pred_label.score = [ + pred_word_biolabels[0][1] + for pred_word_biolabels in pred_words_biolabels + ] + merged_data_sample.pred_label = pred_label + else: + raise NotImplementedError( + 'The `only_label_first_subword=False` is not support yet.') + # determine whether it is an inference process if 'item' in data_samples[0].gt_label: # merge gt label ids from data_samples gt_label_ids = [ - data_sample.gt_label.item for data_sample in data_samples + data_sample.gt_label.item[1:-1] for data_sample in data_samples ] gt_label_ids = torch.cat( gt_label_ids, dim=0).cpu().detach().numpy() - gt_biolabels = [ - self.id2biolabel[g] - for (w, g) in zip(flattened_word_ids, gt_label_ids) - if w is not None - ] + gt_words_biolabels = [] + word_biolabels = [] + pre_word_id = None + for idx, cur_word_id in enumerate(word_ids): + if cur_word_id is not None: + if cur_word_id != pre_word_id: + if word_biolabels: + gt_words_biolabels.append(word_biolabels) + word_biolabels = [] + word_biolabels.append(self.id2biolabel[gt_label_ids[idx]]) + else: + gt_words_biolabels.append(word_biolabels) + break + pre_word_id = cur_word_id # update merged gt_label - merged_data_sample.gt_label.item = gt_biolabels - - # inference process do not have item in gt_label, - # so select valid token with flattened_word_ids - # rather than with gt_label_ids like official code. - pred_biolabels = [ - self.id2biolabel[p] - for (w, p) in zip(flattened_word_ids, pred_label_ids) - if w is not None - ] - pred_biolabel_scores = [ - s for (w, s) in zip(flattened_word_ids, pred_label_scores) - if w is not None - ] - # record pred_label - pred_label = LabelData() - pred_label.item = pred_biolabels - pred_label.score = pred_biolabel_scores - merged_data_sample.pred_label = pred_label + if self.only_label_first_subword: + merged_data_sample.gt_label.item = [ + gt_word_biolabels[0] + for gt_word_biolabels in gt_words_biolabels + ] + else: + raise NotImplementedError( + 'The `only_label_first_subword=False` is not support yet.') return [merged_data_sample] diff --git a/projects/LayoutLMv3/visualization/ser_visualizer.py b/projects/LayoutLMv3/visualization/ser_visualizer.py index f0cdc3707..0df89db0b 100644 --- a/projects/LayoutLMv3/visualization/ser_visualizer.py +++ b/projects/LayoutLMv3/visualization/ser_visualizer.py @@ -65,11 +65,11 @@ def __init__(self, self.line_width = line_width self.alpha = alpha - def _draw_instances(self, image: np.ndarray, bboxes: Union[np.ndarray, - torch.Tensor], - word_ids: Optional[List[int]], - gt_labels: Optional[LabelData], - pred_labels: Optional[LabelData]) -> np.ndarray: + def _draw_instances(self, + image: np.ndarray, + bboxes: Union[np.ndarray, torch.Tensor], + gt_labels: Optional[LabelData] = None, + pred_labels: Optional[LabelData] = None) -> np.ndarray: """Draw bboxes and polygons on image. Args: @@ -97,33 +97,19 @@ def _draw_instances(self, image: np.ndarray, bboxes: Union[np.ndarray, if gt_labels is not None: gt_tokens_biolabel = gt_labels.item - gt_words_label = [] - - pre_word_id = None - for idx, cur_word_id in enumerate(word_ids): - if cur_word_id is not None: - if cur_word_id != pre_word_id: - gt_words_label_name = gt_tokens_biolabel[idx][2:] \ - if gt_tokens_biolabel[idx] != 'O' else 'other' - gt_words_label.append(gt_words_label_name) - pre_word_id = cur_word_id + gt_words_label = [ + token_biolabel[2:] if token_biolabel != 'O' else 'other' + for token_biolabel in gt_tokens_biolabel + ] assert len(gt_words_label) == len(bboxes) + if pred_labels is not None: pred_tokens_biolabel = pred_labels.item - pred_words_label = [] - pred_tokens_biolabel_score = pred_labels.score - pred_words_label_score = [] - - pre_word_id = None - for idx, cur_word_id in enumerate(word_ids): - if cur_word_id is not None: - if cur_word_id != pre_word_id: - pred_words_label_name = pred_tokens_biolabel[idx][2:] \ - if pred_tokens_biolabel[idx] != 'O' else 'other' - pred_words_label.append(pred_words_label_name) - pred_words_label_score.append( - pred_tokens_biolabel_score[idx]) - pre_word_id = cur_word_id + pred_words_label = [ + token_biolabel[2:] if token_biolabel != 'O' else 'other' + for token_biolabel in pred_tokens_biolabel + ] + pred_words_label_score = pred_labels.score assert len(pred_words_label) == len(bboxes) # draw gt or pred labels @@ -205,11 +191,6 @@ def add_datasample(self, cat_images = [] if data_sample is not None: bboxes = np.array(data_sample.instances.get('boxes', None)) - # here need to flatten truncation_word_ids - word_ids = [ - word_id for word_ids in data_sample.truncation_word_ids - for word_id in word_ids[1:-1] - ] gt_label = data_sample.gt_label if \ draw_gt and 'gt_label' in data_sample else None pred_label = data_sample.pred_label if \ @@ -218,7 +199,6 @@ def add_datasample(self, orig_img_with_bboxes = self._draw_instances( image=image.copy(), bboxes=bboxes, - word_ids=None, gt_labels=None, pred_labels=None) cat_images.append(orig_img_with_bboxes) @@ -226,7 +206,6 @@ def add_datasample(self, empty_img_with_label = self._draw_instances( image=empty_img, bboxes=bboxes, - word_ids=word_ids, gt_labels=gt_label, pred_labels=pred_label) cat_images.append(empty_img_with_label)