Skip to content

Commit

Permalink
Fix _process_tokens for empty prompts in KTOTrainer (#2093)
Browse files Browse the repository at this point in the history
The function _process_tokens in trl/trainers/kto_trainer.py crashes if the prompt_input_ids are an empty list.
- added a check for nonzero length
- added a check for nonzero length of answer_input_ids for consistency

The checks happen when determining when subtracting 1 from max_length (happens when BOS or EOS is already present).
  • Loading branch information
gabikadlecova authored Sep 21, 2024
1 parent 9b80f3d commit 44d998b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions trl/trainer/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@ def _process_tokens(example: Dict[str, Any], model: "PreTrainedModel" = None, **
max_length = kwargs["max_length"]
bos_token_id = kwargs["tokenizer"].bos_token_id
eos_token_id = kwargs["tokenizer"].eos_token_id
if bos_token_id != all_tokens["prompt_input_ids"][0]:
if len(all_tokens["prompt_input_ids"]) > 0 and bos_token_id != all_tokens["prompt_input_ids"][0]:
max_length -= 1
if eos_token_id != all_tokens["answer_input_ids"][-1]:
if len(all_tokens["answer_input_ids"]) > 0 and eos_token_id != all_tokens["answer_input_ids"][-1]:
max_length -= 1

# if combined sequence is too long (> max_length - 1 for BOS token - 1 for EOS), truncate the prompt
Expand Down

0 comments on commit 44d998b

Please sign in to comment.