Skip to content

Commit

Permalink
text2text, summarization and translation
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 committed Sep 19, 2024
1 parent 4979344 commit ceb990d
Showing 1 changed file with 26 additions and 21 deletions.
47 changes: 26 additions & 21 deletions src/transformers/pipelines/text2text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,16 @@ def _sanitize_parameters(
forward_params = generate_kwargs

postprocess_params = {}
if return_text is not None:
warnings.warn(
"The `return_text` argument is deprecated and will be removed in version 5 of Transformers. ",
FutureWarning,
)
if return_tensors is not None and return_type is None:
warnings.warn(
"The `return_tensors` argument is deprecated and will be removed in version 5 of Transformers. ",
FutureWarning,
)
return_type = ReturnType.TENSORS if return_tensors else ReturnType.TEXT
if return_type is not None:
postprocess_params["return_type"] = return_type
Expand Down Expand Up @@ -135,17 +144,13 @@ def _parse_and_tokenize(self, *args, truncation):
del inputs["token_type_ids"]
return inputs

def __call__(self, *args, **kwargs):
def __call__(self, *inputs, **kwargs):
r"""
Generate the output text(s) using text(s) given as inputs.
Args:
args (`str` or `List[str]`):
inputs (`str` or `List[str]`):
Input text for the encoder.
return_tensors (`bool`, *optional*, defaults to `False`):
Whether or not to include the tensors of predictions (as token indices) in the outputs.
return_text (`bool`, *optional*, defaults to `True`):
Whether or not to include the decoded texts in the outputs.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up the potential extra spaces in the text output.
truncation (`TruncationStrategy`, *optional*, defaults to `TruncationStrategy.DO_NOT_TRUNCATE`):
Expand All @@ -164,10 +169,10 @@ def __call__(self, *args, **kwargs):
ids of the generated text.
"""

result = super().__call__(*args, **kwargs)
result = super().__call__(*inputs, **kwargs)
if (
isinstance(args[0], list)
and all(isinstance(el, str) for el in args[0])
isinstance(inputs[0], list)
and all(isinstance(el, str) for el in inputs[0])
and all(len(res) == 1 for res in result)
):
return [res[0] for res in result]
Expand Down Expand Up @@ -252,14 +257,14 @@ def __call__(self, *args, **kwargs):
Summarize the text(s) given as inputs.
Args:
documents (*str* or `List[str]`):
inputs (`str` or `List[str]`):
One or several articles (or one list of articles) to summarize.
return_text (`bool`, *optional*, defaults to `True`):
Whether or not to include the decoded texts in the outputs.
return_tensors (`bool`, *optional*, defaults to `False`):
Whether or not to include the tensors of predictions (as token indices) in the outputs.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up the potential extra spaces in the text output.
truncation (`TruncationStrategy`, *optional*, defaults to `TruncationStrategy.DO_NOT_TRUNCATE`):
The truncation strategy for the tokenization within the pipeline. `TruncationStrategy.DO_NOT_TRUNCATE`
(default) will never truncate, but it is sometimes desirable to truncate the input to fit the model's
max_length instead of throwing an error down the line.
generate_kwargs:
Additional keyword arguments to pass along to the generate method of the model (see the generate method
corresponding to your framework [here](./text_generation)).
Expand Down Expand Up @@ -343,19 +348,19 @@ def _sanitize_parameters(self, src_lang=None, tgt_lang=None, **kwargs):
preprocess_params["tgt_lang"] = items[3]
return preprocess_params, forward_params, postprocess_params

def __call__(self, *args, **kwargs):
def __call__(self, *inputs, **kwargs):
r"""
Translate the text(s) given as inputs.
Args:
args (`str` or `List[str]`):
inputs (`str` or `List[str]`):
Texts to be translated.
return_tensors (`bool`, *optional*, defaults to `False`):
Whether or not to include the tensors of predictions (as token indices) in the outputs.
return_text (`bool`, *optional*, defaults to `True`):
Whether or not to include the decoded texts in the outputs.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up the potential extra spaces in the text output.
truncation (`TruncationStrategy`, *optional*, defaults to `TruncationStrategy.DO_NOT_TRUNCATE`):
The truncation strategy for the tokenization within the pipeline. `TruncationStrategy.DO_NOT_TRUNCATE`
(default) will never truncate, but it is sometimes desirable to truncate the input to fit the model's
max_length instead of throwing an error down the line.
src_lang (`str`, *optional*):
The language of the input. Might be required for multilingual models. Will not have any effect for
single-pair translation models.
Expand All @@ -373,4 +378,4 @@ def __call__(self, *args, **kwargs):
- **translation_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The
token ids of the translation.
"""
return super().__call__(*args, **kwargs)
return super().__call__(*inputs, **kwargs)

0 comments on commit ceb990d

Please sign in to comment.