Skip to content

Commit

Permalink
`processor(prompt, images=image)` to `processor(images=image, text=prompt)` (#2076)
Browse files Browse the repository at this point in the history

* `prompt, images=image` to `images=image, text=prompt`

* special case of model being str in BCO
  • Loading branch information
qgallouedec authored Sep 17, 2024
1 parent e74dbf2 commit c314383
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion trl/trainer/bco_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def __init__(
if type(args) is TrainingArguments:
raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.")

if ref_model is model:
if not isinstance(model, str) and ref_model is model:
raise ValueError(
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
"same as `model`, you must mass a copy of it, or `None` if you use peft."
Expand Down
4 changes: 2 additions & 2 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _process_prompt(
)
prompt_tokens = []
for prompt, image in zip(prompts, images):
tokens = processor(prompt, images=image, **processor_kwargs)
tokens = processor(images=image, text=prompt, **processor_kwargs)
tokens = {k: v[0] for k, v in tokens.items()}
if not isinstance(tokens["input_ids"], list):
tokens["input_ids"] = tokens["input_ids"].tolist()
Expand Down Expand Up @@ -302,7 +302,7 @@ def tokenize(text, images=None):
if "add_special_tokens" in inspect.signature(processor).parameters
else {}
)
tokenized = processor(text, images=images, **processor_kwargs)
tokenized = processor(images=images, text=text, **processor_kwargs)
tokenized = {k: v[0] for k, v in tokenized.items()}
if not isinstance(tokenized["input_ids"], list):
tokenized["input_ids"] = tokenized["input_ids"].tolist()
Expand Down

0 comments on commit c314383

Please sign in to comment.