Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable padding_side as call time kwargs #33385

Merged
merged 11 commits into from
Sep 13, 2024
29 changes: 26 additions & 3 deletions src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ def __call__(
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
Expand Down Expand Up @@ -517,6 +518,7 @@ def _is_valid_text_input(t):
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
padding_side=padding_side,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
Expand All @@ -539,6 +541,7 @@ def _is_valid_text_input(t):
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
padding_side=padding_side,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
Expand Down Expand Up @@ -567,6 +570,7 @@ def batch_encode_plus(
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
Expand Down Expand Up @@ -598,6 +602,7 @@ def batch_encode_plus(
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
padding_side=padding_side,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
Expand Down Expand Up @@ -625,6 +630,7 @@ def _batch_encode_plus(
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
Expand Down Expand Up @@ -653,6 +659,7 @@ def _batch_encode_plus(
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
padding_side=padding_side,
return_attention_mask=return_attention_mask,
return_token_type_ids=return_token_type_ids,
return_overflowing_tokens=return_overflowing_tokens,
Expand All @@ -677,6 +684,7 @@ def _batch_prepare_for_model(
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[bool] = None,
return_tensors: Optional[str] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
Expand Down Expand Up @@ -708,6 +716,7 @@ def _batch_prepare_for_model(
max_length=max_length,
stride=stride,
pad_to_multiple_of=None, # we pad in batch afterward
padding_side=None, # we pad in batch afterward
return_attention_mask=False, # we pad in batch afterward
return_token_type_ids=return_token_type_ids,
return_overflowing_tokens=return_overflowing_tokens,
Expand All @@ -728,6 +737,7 @@ def _batch_prepare_for_model(
padding=padding_strategy.value,
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
padding_side=padding_side,
return_attention_mask=return_attention_mask,
)

Expand All @@ -748,6 +758,7 @@ def encode(
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
Expand All @@ -769,6 +780,7 @@ def encode(
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
padding_side=padding_side,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
Expand All @@ -795,6 +807,7 @@ def encode_plus(
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
Expand Down Expand Up @@ -838,6 +851,7 @@ def encode_plus(
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
padding_side=padding_side,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
Expand All @@ -861,6 +875,7 @@ def _encode_plus(
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
Expand Down Expand Up @@ -891,6 +906,7 @@ def _encode_plus(
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
padding_side=padding_side,
return_tensors=return_tensors,
prepend_batch_axis=True,
return_attention_mask=return_attention_mask,
Expand All @@ -914,6 +930,7 @@ def prepare_for_model(
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
Expand Down Expand Up @@ -1100,6 +1117,7 @@ def prepare_for_model(
max_length=max_length,
padding=padding_strategy.value,
pad_to_multiple_of=pad_to_multiple_of,
padding_side=padding_side,
return_attention_mask=return_attention_mask,
)

Expand Down Expand Up @@ -1243,6 +1261,7 @@ def _pad(
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Expand All @@ -1265,6 +1284,9 @@ def _pad(
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
`>= 7.5` (Volta).
padding_side:
The side on which the model should have padding applied. Should be selected between ['right', 'left'].
Default value is picked from the class attribute of the same name.
return_attention_mask:
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
Expand All @@ -1288,7 +1310,8 @@ def _pad(

if needs_to_be_padded:
difference = max_length - len(required_input)
if self.padding_side == "right":
padding_side = padding_side if padding_side is not None else self.padding_side
if padding_side == "right":
if return_attention_mask:
encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
if "token_type_ids" in encoded_inputs:
Expand All @@ -1302,7 +1325,7 @@ def _pad(
if "special_tokens_mask" in encoded_inputs:
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
elif self.padding_side == "left":
elif padding_side == "left":
if return_attention_mask:
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
if "token_type_ids" in encoded_inputs:
Expand All @@ -1317,7 +1340,7 @@ def _pad(
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
else:
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
raise ValueError("Invalid padding strategy:" + str(padding_side))

return encoded_inputs

Expand Down
22 changes: 19 additions & 3 deletions src/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def __call__(
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
Expand Down Expand Up @@ -268,6 +269,7 @@ def _is_valid_text_input(t):
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
padding_side=padding_side,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
Expand All @@ -290,6 +292,7 @@ def _is_valid_text_input(t):
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
padding_side=padding_side,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
Expand Down Expand Up @@ -318,6 +321,7 @@ def batch_encode_plus(
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
Expand Down Expand Up @@ -349,6 +353,7 @@ def batch_encode_plus(
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
padding_side=padding_side,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
Expand Down Expand Up @@ -381,6 +386,7 @@ def encode_plus(
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
Expand Down Expand Up @@ -424,6 +430,7 @@ def encode_plus(
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
padding_side=padding_side,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
Expand Down Expand Up @@ -451,6 +458,7 @@ def _batch_encode_plus(
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[bool] = None,
return_tensors: Optional[str] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
Expand All @@ -470,6 +478,7 @@ def _batch_encode_plus(
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
padding_side=padding_side,
)

if is_pair:
Expand Down Expand Up @@ -603,6 +612,7 @@ def _encode_plus(
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[bool] = None,
return_tensors: Optional[bool] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
Expand Down Expand Up @@ -631,6 +641,7 @@ def _encode_plus(
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
padding_side=padding_side,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
Expand Down Expand Up @@ -663,6 +674,7 @@ def _pad(
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Expand All @@ -685,6 +697,9 @@ def _pad(
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
`>= 7.5` (Volta).
padding_side:
The side on which the model should have padding applied. Should be selected between ['right', 'left'].
Default value is picked from the class attribute of the same name.
return_attention_mask:
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
Expand All @@ -708,7 +723,8 @@ def _pad(

if needs_to_be_padded:
difference = max_length - len(required_input)
if self.padding_side == "right":
padding_side = padding_side if padding_side is not None else self.padding_side
if padding_side == "right":
if return_attention_mask:
encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
if "token_type_ids" in encoded_inputs:
Expand All @@ -722,7 +738,7 @@ def _pad(
if "special_tokens_mask" in encoded_inputs:
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
elif self.padding_side == "left":
elif padding_side == "left":
if return_attention_mask:
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
if "token_type_ids" in encoded_inputs:
Expand All @@ -737,7 +753,7 @@ def _pad(
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
else:
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
raise ValueError("Invalid padding strategy:" + str(padding_side))

return encoded_inputs

Expand Down
Loading
Loading