
Commit

Optimize ser_postprocessor
KevinNuNu committed May 29, 2023
1 parent b04e126 commit 551086e
Showing 4 changed files with 96 additions and 74 deletions.
@@ -108,7 +108,10 @@
pretrained_model_name_or_path=hf_pretrained_model,
num_labels=len(class_name) * 2 - 1),
loss_processor=dict(type='ComputeLossAfterLabelSmooth'),
postprocessor=dict(type='SERPostprocessor', classes=class_name))
postprocessor=dict(
type='SERPostprocessor',
classes=class_name,
only_label_first_subword=only_label_first_subword))
# ====================================================================
# ========================= Evaluation ===============================
val_evaluator = dict(type='SeqevalMetric', prefix=dataset_name)
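
For context on the unchanged num_labels line above: under the BIO scheme, every class except the "other" class gets a B- and an I- label, while "other" collapses to the single O label, hence len(class_name) * 2 - 1 ids. A minimal sketch of that arithmetic (this class_name tuple is illustrative, not taken from the repo):

# BIO label count: "other" -> 'O'; every remaining class -> 'B-X' and 'I-X'.
class_name = ('other', 'header', 'question', 'answer')  # hypothetical classes
num_labels = len(class_name) * 2 - 1                    # 4 * 2 - 1 = 7
bio_labels = ['O'] + [f'{prefix}-{c.upper()}'
                      for c in class_name if c != 'other'
                      for prefix in ('B', 'I')]
assert len(bio_labels) == num_labels
print(bio_labels)
# ['O', 'B-HEADER', 'I-HEADER', 'B-QUESTION', 'I-QUESTION', 'B-ANSWER', 'I-ANSWER']
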
@@ -220,10 +220,12 @@ class ConvertBIOLabelForSER(BaseTransform):

def __init__(self,
classes: Union[tuple, list],
only_label_first_subword: bool = False) -> None:
only_label_first_subword: bool = True) -> None:
super().__init__()
self.other_label_name = find_other_label_name_of_biolabel(classes)
self.biolabel2id = self._generate_biolabel2id_map(classes)
assert only_label_first_subword is True, \
'Only support `only_label_first_subword=True` now.'
self.only_label_first_subword = only_label_first_subword

def _generate_biolabel2id_map(self, classes: Union[tuple, list]) -> Dict:
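
To illustrate the flag this hunk flips on: with only_label_first_subword=True, a word's BIO label is assigned only to its first subword, and trailing subwords are ignored. A hedged sketch of the idea — the token layout and the -100 ignore value are assumptions for illustration, not this transform's exact code:

# Assign each word's BIO label only to its first subword; -100 is the
# conventional ignore_index for PyTorch cross-entropy losses.
word_labels = ['B-HEADER', 'I-HEADER']   # one BIO label per word
word_ids = [None, 0, 0, 0, 1, 1, None]   # token -> word index; None = special token
token_labels = []
pre_word_id = None
for word_id in word_ids:
    if word_id is None:                   # [CLS]/[SEP]/padding
        token_labels.append(-100)
    elif word_id != pre_word_id:          # first subword of a new word
        token_labels.append(word_labels[word_id])
    else:                                 # trailing subword: ignored
        token_labels.append(-100)
    pre_word_id = word_id
print(token_labels)
# [-100, 'B-HEADER', -100, -100, 'I-HEADER', -100, -100]
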
110 changes: 74 additions & 36 deletions projects/LayoutLMv3/models/ser_postprocessor.py
@@ -16,10 +16,15 @@
class SERPostprocessor(nn.Module):
"""PostProcessor for SER."""

def __init__(self, classes: Union[tuple, list]) -> None:
def __init__(self,
classes: Union[tuple, list],
only_label_first_subword: bool = True) -> None:
super().__init__()
self.other_label_name = find_other_label_name_of_biolabel(classes)
self.id2biolabel = self._generate_id2biolabel_map(classes)
assert only_label_first_subword is True, \
'Only support `only_label_first_subword=True` now.'
self.only_label_first_subword = only_label_first_subword
self.softmax = nn.Softmax(dim=-1)
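
The softmax stored here is used later in __call__ to turn raw logits into per-token label ids and confidence scores. A standalone sketch of that decode step (shapes and values are made up):

import torch

logits = torch.tensor([[2.0, 0.5, -1.0],      # (num_tokens, num_labels)
                       [0.1, 1.5, 0.3]])
probs = torch.softmax(logits, dim=-1)         # normalize to probabilities
scores, label_ids = torch.max(probs, dim=-1)  # best label and its confidence
print(label_ids.tolist())                     # [0, 1]
print([round(s, 3) for s in scores.tolist()])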

def _generate_id2biolabel_map(self, classes: Union[tuple, list]) -> Dict:
@@ -40,62 +45,95 @@ def _generate_id2biolabel_map(self, classes: Union[tuple, list]) -> Dict:
def __call__(self, outputs: torch.Tensor,
data_samples: Sequence[SERDataSample]
) -> Sequence[SERDataSample]:
# merge several truncation data_samples into one data_sample
assert all('truncation_word_ids' in d for d in data_samples), \
    'The key `truncation_word_ids` should be specified ' \
    'in PackSERInputs.'
truncation_word_ids = []
for data_sample in data_samples:
truncation_word_ids.append(data_sample.pop('truncation_word_ids'))
merged_data_sample = copy.deepcopy(data_samples[0])
merged_data_sample.set_metainfo(
dict(truncation_word_ids=truncation_word_ids))
flattened_word_ids = [
word_id for word_ids in truncation_word_ids for word_id in word_ids
truncation_word_ids = [
data_sample.pop('truncation_word_ids')
for data_sample in data_samples
]
word_ids = [
word_id for word_ids in truncation_word_ids
for word_id in word_ids[1:-1]
]

# merge several truncation data_samples into one data_sample
merged_data_sample = copy.deepcopy(data_samples[0])

# convert outputs dim from (truncation_num, max_length, label_num)
# to (truncation_num * max_length, label_num)
outputs = outputs.cpu().detach()
outputs = torch.reshape(outputs, (-1, outputs.size(-1)))
outputs = torch.reshape(outputs[:, 1:-1, :], (-1, outputs.size(-1)))
# get pred label ids/scores from outputs
probs = self.softmax(outputs)
max_value, max_idx = torch.max(probs, -1)
pred_label_ids = max_idx.numpy()
pred_label_scores = max_value.numpy()

# the inference process does not have `item` in gt_label,
# so select valid tokens with word_ids rather than
# with gt_label_ids as in the official code.
pred_words_biolabels = []
word_biolabels = []
pre_word_id = None
for idx, cur_word_id in enumerate(word_ids):
if cur_word_id is not None:
if cur_word_id != pre_word_id:
if word_biolabels:
pred_words_biolabels.append(word_biolabels)
word_biolabels = []
word_biolabels.append((self.id2biolabel[pred_label_ids[idx]],
pred_label_scores[idx]))
else:
pred_words_biolabels.append(word_biolabels)
break
pre_word_id = cur_word_id
# record pred_label
if self.only_label_first_subword:
pred_label = LabelData()
pred_label.item = [
pred_word_biolabels[0][0]
for pred_word_biolabels in pred_words_biolabels
]
pred_label.score = [
pred_word_biolabels[0][1]
for pred_word_biolabels in pred_words_biolabels
]
merged_data_sample.pred_label = pred_label
else:
raise NotImplementedError(
    'The `only_label_first_subword=False` is not supported yet.')

# determine whether this is an inference process (no `item` in gt_label)
if 'item' in data_samples[0].gt_label:
# merge gt label ids from data_samples
gt_label_ids = [
data_sample.gt_label.item for data_sample in data_samples
data_sample.gt_label.item[1:-1] for data_sample in data_samples
]
gt_label_ids = torch.cat(
gt_label_ids, dim=0).cpu().detach().numpy()
gt_biolabels = [
self.id2biolabel[g]
for (w, g) in zip(flattened_word_ids, gt_label_ids)
if w is not None
]
gt_words_biolabels = []
word_biolabels = []
pre_word_id = None
for idx, cur_word_id in enumerate(word_ids):
if cur_word_id is not None:
if cur_word_id != pre_word_id:
if word_biolabels:
gt_words_biolabels.append(word_biolabels)
word_biolabels = []
word_biolabels.append(self.id2biolabel[gt_label_ids[idx]])
else:
gt_words_biolabels.append(word_biolabels)
break
pre_word_id = cur_word_id
# update merged gt_label
merged_data_sample.gt_label.item = gt_biolabels

# the inference process does not have `item` in gt_label,
# so select valid tokens with flattened_word_ids
# rather than with gt_label_ids as in the official code.
pred_biolabels = [
self.id2biolabel[p]
for (w, p) in zip(flattened_word_ids, pred_label_ids)
if w is not None
]
pred_biolabel_scores = [
s for (w, s) in zip(flattened_word_ids, pred_label_scores)
if w is not None
]
# record pred_label
pred_label = LabelData()
pred_label.item = pred_biolabels
pred_label.score = pred_biolabel_scores
merged_data_sample.pred_label = pred_label
if self.only_label_first_subword:
merged_data_sample.gt_label.item = [
gt_word_biolabels[0]
for gt_word_biolabels in gt_words_biolabels
]
else:
raise NotImplementedError(
    'The `only_label_first_subword=False` is not supported yet.')

return [merged_data_sample]
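
The heart of the rewritten __call__ above is the first-subword grouping: the model scores every subword token, and the loop walks word_ids to collapse those token-level results into one (label, score) pair per word, stopping at the first None (padding). A self-contained sketch of that walk, with made-up labels and scores standing in for the softmaxed model outputs:

# Collapse token-level predictions to word level, as in SERPostprocessor.
word_ids = [0, 0, 1, 2, 2, None, None]      # token -> word; None = padding
pred = [('B-HEADER', 0.98), ('I-HEADER', 0.95), ('O', 0.99),
        ('B-ANSWER', 0.97), ('I-ANSWER', 0.96), ('O', 0.50), ('O', 0.50)]

words = []                                   # one (label, score) list per word
word_biolabels = []
pre_word_id = None
for idx, cur_word_id in enumerate(word_ids):
    if cur_word_id is None:                  # reached padding: flush and stop
        words.append(word_biolabels)
        break
    if cur_word_id != pre_word_id and word_biolabels:
        words.append(word_biolabels)         # a new word begins: flush the old
        word_biolabels = []
    word_biolabels.append(pred[idx])
    pre_word_id = cur_word_id

# only_label_first_subword=True keeps each word's first subword only:
labels = [w[0][0] for w in words]            # ['B-HEADER', 'O', 'B-ANSWER']
scores = [w[0][1] for w in words]            # [0.98, 0.99, 0.97]
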
51 changes: 15 additions & 36 deletions projects/LayoutLMv3/visualization/ser_visualizer.py
@@ -65,11 +65,11 @@ def __init__(self,
self.line_width = line_width
self.alpha = alpha

def _draw_instances(self, image: np.ndarray, bboxes: Union[np.ndarray,
torch.Tensor],
word_ids: Optional[List[int]],
gt_labels: Optional[LabelData],
pred_labels: Optional[LabelData]) -> np.ndarray:
def _draw_instances(self,
image: np.ndarray,
bboxes: Union[np.ndarray, torch.Tensor],
gt_labels: Optional[LabelData] = None,
pred_labels: Optional[LabelData] = None) -> np.ndarray:
"""Draw bboxes and polygons on image.
Args:
@@ -97,33 +97,19 @@ def _draw_instances(self, image: np.ndarray, bboxes: Union[np.ndarray,

if gt_labels is not None:
gt_tokens_biolabel = gt_labels.item
gt_words_label = []

pre_word_id = None
for idx, cur_word_id in enumerate(word_ids):
if cur_word_id is not None:
if cur_word_id != pre_word_id:
gt_words_label_name = gt_tokens_biolabel[idx][2:] \
if gt_tokens_biolabel[idx] != 'O' else 'other'
gt_words_label.append(gt_words_label_name)
pre_word_id = cur_word_id
gt_words_label = [
token_biolabel[2:] if token_biolabel != 'O' else 'other'
for token_biolabel in gt_tokens_biolabel
]
assert len(gt_words_label) == len(bboxes)

if pred_labels is not None:
pred_tokens_biolabel = pred_labels.item
pred_words_label = []
pred_tokens_biolabel_score = pred_labels.score
pred_words_label_score = []

pre_word_id = None
for idx, cur_word_id in enumerate(word_ids):
if cur_word_id is not None:
if cur_word_id != pre_word_id:
pred_words_label_name = pred_tokens_biolabel[idx][2:] \
if pred_tokens_biolabel[idx] != 'O' else 'other'
pred_words_label.append(pred_words_label_name)
pred_words_label_score.append(
pred_tokens_biolabel_score[idx])
pre_word_id = cur_word_id
pred_words_label = [
token_biolabel[2:] if token_biolabel != 'O' else 'other'
for token_biolabel in pred_tokens_biolabel
]
pred_words_label_score = pred_labels.score
assert len(pred_words_label) == len(bboxes)

# draw gt or pred labels
@@ -205,11 +191,6 @@ def add_datasample(self,
cat_images = []
if data_sample is not None:
bboxes = np.array(data_sample.instances.get('boxes', None))
# here need to flatten truncation_word_ids
word_ids = [
word_id for word_ids in data_sample.truncation_word_ids
for word_id in word_ids[1:-1]
]
gt_label = data_sample.gt_label if \
draw_gt and 'gt_label' in data_sample else None
pred_label = data_sample.pred_label if \
@@ -218,15 +199,13 @@
orig_img_with_bboxes = self._draw_instances(
image=image.copy(),
bboxes=bboxes,
word_ids=None,
gt_labels=None,
pred_labels=None)
cat_images.append(orig_img_with_bboxes)
empty_img = np.full_like(image, 255)
empty_img_with_label = self._draw_instances(
image=empty_img,
bboxes=bboxes,
word_ids=word_ids,
gt_labels=gt_label,
pred_labels=pred_label)
cat_images.append(empty_img_with_label)
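
Since the postprocessor now emits one label per word, the visualizer no longer needs word_ids to regroup subwords; it only strips the BIO prefix for display. A tiny sketch of that mapping (label names are illustrative):

# 'B-X'/'I-X' -> 'X' for display; the bare 'O' label is shown as 'other'.
tokens_biolabel = ['B-HEADER', 'I-ANSWER', 'O']
words_label = [lab[2:] if lab != 'O' else 'other' for lab in tokens_biolabel]
print(words_label)  # ['HEADER', 'ANSWER', 'other']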
