Disable dropout in DPO Training (#639)
* disable dropout in dpo

* quick fix docs

* pre-commit

* add disable_dropout_in_model to DPOTrainer

* disable_dropout -> disable_dropout_in_model

* .

* .
NouamaneTazi authored Aug 14, 2023
1 parent 3b2c820 commit 98120d6
Showing 4 changed files with 17 additions and 2 deletions.
examples/dpo.py: 1 addition & 0 deletions
@@ -101,6 +101,7 @@ def split_prompt_and_responses(sample) -> Dict[str, str]:
]

model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)

tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token

trl/trainer/__init__.py: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
ConstantLengthDataset,
DataCollatorForCompletionOnlyLM,
RunningMoments,
+disable_dropout_in_model,
)

# isort: on
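With the re-export above in place, the helper should be importable directly from `trl.trainer`. A minimal sanity check — the import path is assumed from the diff, and the toy model is purely illustrative:

    import torch
    from trl.trainer import disable_dropout_in_model  # re-export added in this commit

    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Dropout(p=0.1))
    disable_dropout_in_model(model)
    # Every nn.Dropout submodule should now have its drop probability zeroed.
    assert all(m.p == 0 for m in model.modules() if isinstance(m, torch.nn.Dropout))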
trl/trainer/dpo_trainer.py: 8 additions & 1 deletion
@@ -24,7 +24,7 @@
from transformers.trainer_callback import TrainerCallback

from ..import_utils import is_peft_available
-from .utils import DPODataCollatorWithPadding, pad_to_length
+from .utils import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length


if is_peft_available():
@@ -73,6 +73,8 @@ class DPOTrainer(Trainer):
The maximum length of the prompt. This argument is required if you want to use the default data collator.
peft_config (`Dict`, defaults to `None`):
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
+disable_dropout (`bool`, defaults to `True`):
+    Whether or not to disable dropouts in `model` and `ref_model`.
"""

def __init__(
@@ -98,6 +100,7 @@ def __init__(
max_length: Optional[int] = None,
max_prompt_length: Optional[int] = None,
peft_config: Optional[Dict] = None,
+disable_dropout: bool = True,
):
if not is_peft_available() and peft_config is not None:
raise ValueError(
@@ -150,6 +153,10 @@ def __init__(
else:
self.use_dpo_data_collator = False

+if disable_dropout:
+    disable_dropout_in_model(model)
+    disable_dropout_in_model(ref_model)

self.label_pad_token_id = label_pad_token_id
self.padding_value = padding_value

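The rationale for the new flag: with dropout active, the policy and reference forward passes each sample different dropout masks, so the log-probabilities entering the DPO loss are noisy and the two models are not evaluated under identical conditions; zeroing dropout makes both passes deterministic. A sketch of the resulting constructor call — `model`, `model_ref`, `training_args`, `train_dataset`, and `tokenizer` are assumed placeholders set up as in examples/dpo.py:

    # Sketch only: all names below are assumed to be defined as in examples/dpo.py.
    trainer = DPOTrainer(
        model,                   # policy model being optimized
        model_ref,               # frozen reference model
        args=training_args,
        beta=0.1,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        disable_dropout=True,    # default; zeroes every nn.Dropout in both models
    )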
trl/trainer/utils.py: 7 additions & 1 deletion
@@ -248,7 +248,7 @@ class DPODataCollatorWithPadding:
padding_value (`int`, defaults to 0):
The value used for padding.
truncation_mode: (`str`, defaults to "keep_end"):
-The truncation mode to use when truncating the prompt + chosen/rejected responses.
+The truncation mode to use when truncating the prompt.
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str] = True
@@ -586,3 +586,9 @@ def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float
],
dim=dim,
)


+def disable_dropout_in_model(model: torch.nn.Module) -> None:
+    for module in model.modules():
+        if isinstance(module, torch.nn.Dropout):
+            module.p = 0
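Setting `p = 0` keeps each `nn.Dropout` module in the graph but turns it into a pass-through even in training mode, since nothing is dropped and the inverse scaling factor 1/(1 - p) becomes 1. A quick check of that behavior:

    import torch

    drop = torch.nn.Dropout(p=0.5)
    drop.p = 0
    drop.train()  # even in train mode, p == 0 drops nothing and applies no rescaling
    x = torch.ones(4)
    assert torch.equal(drop(x), x)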
