
CI fails for slow tests: TypeError: PreTrainedTokenizerFast._batch_encode_plus() got an unexpected keyword argument 'add_generation_prompt' #4369

@albertvillanova

Description

CI fails for slow tests: https://github.com/huggingface/trl/actions/runs/18904382222/job/53959012155

TypeError: PreTrainedTokenizerFast._batch_encode_plus() got an unexpected keyword argument 'add_generation_prompt'

FAILED tests/slow/test_grpo_slow.py::TestGRPOTrainerSlow::test_training_with_transformers_paged[trl-internal-testing/tiny-LlamaForCausalLM-3.2] - TypeError: PreTrainedTokenizerFast._batch_encode_plus() got an unexpected keyword argument 'add_generation_prompt'
FAILED tests/slow/test_grpo_slow.py::TestGRPOTrainerSlow::test_training_with_transformers_paged[trl-internal-testing/tiny-MistralForCausalLM-0.2] - TypeError: PreTrainedTokenizerFast._batch_encode_plus() got an unexpected keyword argument 'add_generation_prompt'
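
For context, `add_generation_prompt` is an argument of `apply_chat_template`, not of the tokenizer's `__call__`; a plain fast tokenizer forwards unknown keyword arguments down to `_batch_encode_plus`, which rejects them on the transformers version used in this CI run (older releases may only warn). A minimal repro sketch, using the same checkpoint and kwargs as the failing test:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-MistralForCausalLM-0.2")

    try:
        # Mirrors the call in GRPOTrainer._generate_single_turn: the extra kwarg is
        # forwarded through __call__ -> batch_encode_plus -> _batch_encode_plus,
        # which does not accept it on the transformers version in CI.
        tok(
            ["Although that way may not be obvious at first unless you're"] * 3,
            add_special_tokens=False,
            truncation=True,
            max_length=512,
            add_generation_prompt=True,  # chat-template-only argument
        )
    except TypeError as e:
        print(e)  # ... got an unexpected keyword argument 'add_generation_prompt'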

Stacktrace:

_ TestGRPOTrainerSlow.test_training_with_transformers_paged[trl-internal-testing/tiny-MistralForCausalLM-0.2] _

self = <tests.slow.test_grpo_slow.TestGRPOTrainerSlow object at 0x7fbb2f63c8d0>
model_name = 'trl-internal-testing/tiny-MistralForCausalLM-0.2'

    @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
    def test_training_with_transformers_paged(self, model_name):
        """Test that training works with transformers paged implementation (requires GPU)."""
        if Version(transformers.__version__) < Version("4.57.0"):
            pytest.xfail("Upstream bug in transformers (GH#40692). Fix merged; awaiting release >= 4.57.0")
        training_args = GRPOConfig(
            output_dir=self.tmp_dir,
            learning_rate=0.1,  # increase the learning rate to speed up the test
            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
            num_generations=3,  # reduce the number of generations to reduce memory usage
            max_completion_length=8,  # reduce the completion length to reduce memory usage
            use_transformers_paged=True,  # Enable transformers paged implementation
            report_to="none",
            logging_strategy="no",
        )
    
        model = AutoModelForCausalLM.from_pretrained(model_name)
    
        trainer = GRPOTrainer(
            model=model,
            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
            args=training_args,
            train_dataset=self.train_dataset,
        )
    
        previous_trainable_params = {n: param.clone() for n, param in model.named_parameters()}
    
>       trainer.train()

tests/slow/test_grpo_slow.py:199: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
.venv/lib/python3.11/site-packages/transformers/trainer.py:2325: in train
    return inner_training_loop(
.venv/lib/python3.11/site-packages/transformers/trainer.py:2674: in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/transformers/trainer.py:4014: in training_step
    inputs = self._prepare_inputs(inputs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/extras/profiling.py:98: in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/trainer/grpo_trainer.py:1033: in _prepare_inputs
    generation_batch = self._generate_and_score_completions(generation_batch)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/trainer/grpo_trainer.py:1410: in _generate_and_score_completions
    self._generate(prompts)
trl/trainer/grpo_trainer.py:1351: in _generate
    prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts)
                                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/trainer/grpo_trainer.py:1270: in _generate_single_turn
    processor_outputs = self.processing_class(text=prompts, **processor_kwargs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:2938: in __call__
    encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:3026: in _call_one
    return self.batch_encode_plus(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = LlamaTokenizerFast(name_or_path='trl-internal-testing/tiny-MistralForCausalLM-0.2', vocab_size=32000, model_max_length...ecial=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)
batch_text_or_text_pairs = ["Although that way may not be obvious at first unless you're", "Although that way may not be obvious at first unless you're", "Although that way may not be obvious at first unless you're"]
add_special_tokens = False, padding = False, truncation = True, max_length = 512
stride = 0, is_split_into_words = False, pad_to_multiple_of = None
padding_side = None, return_tensors = None, return_token_type_ids = None
return_attention_mask = None, return_overflowing_tokens = False
return_special_tokens_mask = False, return_offsets_mapping = False
return_length = False, verbose = True, split_special_tokens = False
kwargs = {'add_generation_prompt': True}
padding_strategy = <PaddingStrategy.DO_NOT_PAD: 'do_not_pad'>
truncation_strategy = <TruncationStrategy.LONGEST_FIRST: 'longest_first'>

    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
    def batch_encode_plus(
        self,
        batch_text_or_text_pairs: Union[
            list[TextInput],
            list[TextInputPair],
            list[PreTokenizedInput],
            list[PreTokenizedInputPair],
            list[EncodedInput],
            list[EncodedInputPair],
        ],
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy, None] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        padding_side: Optional[str] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        split_special_tokens: bool = False,
        **kwargs,
    ) -> BatchEncoding:
        """
        Tokenize and prepare for the model a list of sequences or a list of pairs of sequences.
    
        <Tip warning={true}>
    
        This method is deprecated, `__call__` should be used instead.
    
        </Tip>
    
        Args:
            batch_text_or_text_pairs (`list[str]`, `list[tuple[str, str]]`, `list[list[str]]`, `list[tuple[list[str], list[str]]]`, and for not-fast tokenizers, also `list[list[int]]`, `list[tuple[list[int], list[int]]]`):
                Batch of sequences or pair of sequences to be encoded. This can be a list of
                string/string-sequences/int-sequences or a list of pair of string/string-sequences/int-sequence (see
                details in `encode_plus`).
        """
    
        # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
        padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            verbose=verbose,
            **kwargs,
        )
    
>       return self._batch_encode_plus(
            batch_text_or_text_pairs=batch_text_or_text_pairs,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            padding_side=padding_side,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            split_special_tokens=split_special_tokens,
            **kwargs,
        )
E       TypeError: PreTrainedTokenizerFast._batch_encode_plus() got an unexpected keyword argument 'add_generation_prompt'

.venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:3227: TypeError
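
One possible mitigation, sketched below with a hypothetical helper (this is not the actual TRL fix, just an illustration under the assumption that the failure comes from passing a chat-template-only kwarg to a plain tokenizer): strip `add_generation_prompt` from `processor_kwargs` before calling a plain tokenizer, since it is only meaningful for `apply_chat_template` and multimodal processors.

    from transformers import PreTrainedTokenizerBase

    def filter_processor_kwargs(processing_class, processor_kwargs):
        # Hypothetical helper, not part of TRL: plain tokenizers forward unknown
        # kwargs to _batch_encode_plus, so drop chat-template-only arguments.
        kwargs = dict(processor_kwargs)
        if isinstance(processing_class, PreTrainedTokenizerBase):
            kwargs.pop("add_generation_prompt", None)
        return kwargs

    # Sketch of how the call site in _generate_single_turn (trl/trainer/grpo_trainer.py:1270)
    # could use it:
    # processor_outputs = self.processing_class(
    #     text=prompts, **filter_processor_kwargs(self.processing_class, processor_kwargs)
    # )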

Metadata

Labels

🏋 GRPO (Related to GRPO), 🐛 bug (Something isn't working)
