Skip to content

Commit

Permalink
Improve processor, add integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
NielsRogge committed Sep 15, 2022
1 parent 88618dd commit 47c172c
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 23 deletions.
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@
"MARKUPLM_PRETRAINED_CONFIG_ARCHIVE_MAP",
"MarkupLMConfig",
"MarkupLMFeatureExtractor",
"MarkupLMProcessor",
"MarkupLMTokenizer",
],
"models.maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"],
Expand Down Expand Up @@ -3161,6 +3162,7 @@
MARKUPLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
MarkupLMConfig,
MarkupLMFeatureExtractor,
MarkupLMProcessor,
MarkupLMTokenizer,
)
from .models.maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig
Expand Down
22 changes: 18 additions & 4 deletions src/transformers/models/markuplm/processing_markuplm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
"""
Processor class for MarkupLM.
"""
from typing import Optional, Union
from typing import List, Optional, Union

from ...file_utils import TensorType
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TruncationStrategy
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy


class MarkupLMProcessor(ProcessorMixin):
Expand All @@ -29,7 +29,7 @@ class MarkupLMProcessor(ProcessorMixin):
[`MarkupLMProcessor`] offers all the functionalities you need to prepare data for the model.
It first uses [`MarkupLMFeatureExtractor`] to get nodes and corresponding xpaths from one or more HTML strings.
It first uses [`MarkupLMFeatureExtractor`] to extract nodes and corresponding xpaths from one or more HTML strings.
Next, these are provided to [`MarkupLMTokenizer`] or [`MarkupLMTokenizerFast`], which turns them into token-level
`input_ids`, `attention_mask`, `token_type_ids`, `xpath_tags_seq` and `xpath_subs_seq`.
Expand All @@ -45,7 +45,7 @@ class MarkupLMProcessor(ProcessorMixin):
def __call__(
self,
html_strings,
text,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = False,
Expand Down Expand Up @@ -97,3 +97,17 @@ def __call__(
)

return encoded_inputs

def batch_decode(self, *args, **kwargs):
    """
    This method forwards all its arguments to the wrapped tokenizer's [`~PreTrainedTokenizer.batch_decode`] (the
    tokenizer is a [`MarkupLMTokenizer`] or [`MarkupLMTokenizerFast`]). Please refer to the docstring of that method
    for more information.

    Returns:
        Whatever `self.tokenizer.batch_decode` returns for the given arguments.
    """
    return self.tokenizer.batch_decode(*args, **kwargs)

def decode(self, *args, **kwargs):
    """
    This method forwards all its arguments to the wrapped tokenizer's [`~PreTrainedTokenizer.decode`] (the tokenizer
    is a [`MarkupLMTokenizer`] or [`MarkupLMTokenizerFast`]). Please refer to the docstring of that method for more
    information.

    Returns:
        Whatever `self.tokenizer.decode` returns for the given arguments.
    """
    return self.tokenizer.decode(*args, **kwargs)
37 changes: 21 additions & 16 deletions src/transformers/models/markuplm/test_feature_extractor.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
from transformers import MarkupLMFeatureExtractor
from transformers import MarkupLMFeatureExtractor, MarkupLMProcessor, MarkupLMTokenizer


feature_extractor = MarkupLMFeatureExtractor()
tokenizer = MarkupLMTokenizer.from_pretrained("microsoft/markuplm-base")

html_string = """<HTML>
processor = MarkupLMProcessor(feature_extractor, tokenizer)

<HEAD>
<TITLE>sample document</TITLE>
</HEAD>

<BODY BGCOLOR="FFFFFF">
<HR> <a href="http://google.com">Goog</a> <H1>This is one header</H1> <H2>This is a another Header</H2> <P>Travel
from
<P>
<B>SFO to JFK</B> <BR> <B><I>on May 2, 2015 at 2:00 pm. For details go to confirm.com </I></B> <HR> <div
style="color:#0000FF">
<h3>Traveler <b> name </b> is <p> John Doe </p>
</div>"""
def prepare_html_string():
html_string = """
<!DOCTYPE html> <html> <head> <title>Page Title</title> </head> <body>
encoding = feature_extractor(html_string)
for k, v in encoding.items():
print(k, v)
<h1>This is a Heading</h1> <p>This is a paragraph.</p>
</body> </html>
"""

return html_string


encoding = processor(prepare_html_string())
# for k, v in encoding.items():
# print(k, v)

print(encoding.input_ids)

print(processor.decode(encoding.input_ids[0]))
54 changes: 53 additions & 1 deletion tests/models/markuplm/test_modeling_markuplm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,25 @@

from transformers import MarkupLMConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.utils import cached_property

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor


if is_torch_available():
import torch

from transformers import (
MarkupLMForQuestionAnswering,
MarkupLMForSequenceClassification,
MarkupLMForTokenClassification,
MarkupLMModel,
)

# TODO check dependencies
from transformers import MarkupLMFeatureExtractor, MarkupLMProcessor, MarkupLMTokenizer


class MarkupLMModelTester:
"""You can also import this e.g from .test_modeling_markuplm import MarkupLMModelTester"""
Expand Down Expand Up @@ -305,8 +311,54 @@ def test_for_question_answering(self):
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)


def prepare_html_string():
    """Return the small, fixed HTML page (title, one heading, one paragraph) used as model input."""
    return """
<!DOCTYPE html>
<html>
<head>
<title>Page Title</title>
</head>
<body>
<h1>This is a Heading</h1>
<p>This is a paragraph.</p>
</body>
</html>
"""


@require_torch
class MarkupLMModelIntegrationTest(unittest.TestCase):
    """Integration test that runs the `microsoft/markuplm-base` checkpoint on a small HTML document."""

    @cached_property
    def default_processor(self):
        """Build a `MarkupLMProcessor` from a default feature extractor and the pretrained tokenizer."""
        # TODO use from_pretrained here
        feature_extractor = MarkupLMFeatureExtractor()
        tokenizer = MarkupLMTokenizer.from_pretrained("microsoft/markuplm-base")

        return MarkupLMProcessor(feature_extractor, tokenizer)

    @slow
    def test_forward_pass_no_head(self):
        """Check the shape and a 3x3 slice of `last_hidden_state` produced by the base model."""
        # NOTE: the former placeholder `raise NotImplementedError("To do")` was removed here;
        # it made every statement below unreachable.
        model = MarkupLMModel.from_pretrained("microsoft/markuplm-base").to(torch_device)

        processor = self.default_processor

        inputs = processor(prepare_html_string(), return_tensors="pt")
        inputs = inputs.to(torch_device)

        # forward pass (no gradients are needed for inference)
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the last hidden states
        expected_shape = torch.Size([1, 14, 768])
        self.assertEqual(outputs.last_hidden_state.shape, expected_shape)

        expected_slice = torch.tensor(
            [[0.0267, -0.1289, 0.4930], [-0.2376, -0.0342, 0.2381], [-0.0329, -0.3785, 0.0263]]
        ).to(torch_device)

        # compare only the top-left 3x3 corner of the first sequence, within an absolute tolerance
        self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
3 changes: 1 addition & 2 deletions tests/models/markuplm/test_tokenization_markuplm.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,7 @@ def test_number_of_added_tokens(self):
)

def test_padding_to_max_length(self):
"""We keep this test for backward compatibility but it should be removed when `pad_to_max_length` will be deprecated
"""
"""We keep this test for backward compatibility but it should be removed when `pad_to_max_length` will be deprecated"""
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
Expand Down

0 comments on commit 47c172c

Please sign in to comment.