From f9f08c902c91f9d30b6a26f1f7939ab19ca06874 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 17 Sep 2024 09:45:05 +0000 Subject: [PATCH 1/2] `prompt, images=image` to `images=image, text=prompt` --- trl/trainer/dpo_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 2189580c2e..3c9ff4624b 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -123,7 +123,7 @@ def _process_prompt( ) prompt_tokens = [] for prompt, image in zip(prompts, images): - tokens = processor(prompt, images=image, **processor_kwargs) + tokens = processor(images=image, text=prompt, **processor_kwargs) tokens = {k: v[0] for k, v in tokens.items()} if not isinstance(tokens["input_ids"], list): tokens["input_ids"] = tokens["input_ids"].tolist() @@ -302,7 +302,7 @@ def tokenize(text, images=None): if "add_special_tokens" in inspect.signature(processor).parameters else {} ) - tokenized = processor(text, images=images, **processor_kwargs) + tokenized = processor(images=images, text=text, **processor_kwargs) tokenized = {k: v[0] for k, v in tokenized.items()} if not isinstance(tokenized["input_ids"], list): tokenized["input_ids"] = tokenized["input_ids"].tolist() From 0423e3a422ee765220c1eb1d9c73c3dc12cc4546 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 17 Sep 2024 09:53:56 +0000 Subject: [PATCH 2/2] special case of model being str in BCO --- trl/trainer/bco_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index 44a4ca03a2..c4ed97d3de 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -336,7 +336,7 @@ def __init__( if type(args) is TrainingArguments: raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.") - if ref_model is model: + if not isinstance(model, str) and ref_model is model: raise ValueError( "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " "same as `model`, you must mass a copy of it, or `None` if you use peft."