
Using a different ref_model from model leads to incorrect results #2307

Open
2 of 4 tasks
DarshanDeshpande opened this issue Nov 1, 2024 · 1 comment
Labels
✨ enhancement New feature or request ❓ question Seeking clarification or more information

@DarshanDeshpande

System Info

  • Platform: Linux-6.8.0-47-generic-x86_64-with-glibc2.35
  • Python version: 3.10.15
  • PyTorch version: 2.4.0
  • CUDA device(s): NVIDIA H100 80GB HBM3, NVIDIA H100 80GB HBM3, NVIDIA H100 80GB HBM3, NVIDIA H100 80GB HBM3
  • Transformers version: 4.46.1
  • Accelerate version: 1.1.0
  • Accelerate config:
    • compute_environment: LOCAL_MACHINE
    • distributed_type: FSDP
    • mixed_precision: bf16
    • use_cpu: False
    • debug: True
    • num_processes: 4
    • machine_rank: 0
    • num_machines: 1
    • rdzv_backend: static
    • same_network: True
    • main_training_function: main
    • enable_cpu_affinity: False
    • fsdp_config: {'fsdp_activation_checkpointing': True, 'fsdp_auto_wrap_policy': 'TRANSFORMER_BASED_WRAP', 'fsdp_backward_prefetch': 'BACKWARD_PRE', 'fsdp_cpu_ram_efficient_loading': True, 'fsdp_forward_prefetch': True, 'fsdp_offload_params': False, 'fsdp_sharding_strategy': 'FULL_SHARD', 'fsdp_state_dict_type': 'FULL_STATE_DICT', 'fsdp_sync_module_states': True, 'fsdp_use_orig_params': True}
    • downcast_bf16: no
    • tpu_use_cluster: False
    • tpu_use_sudo: False
    • tpu_env: []
  • Datasets version: 3.1.0
  • HF Hub version: 0.26.2
  • TRL version: 0.12.0
  • bitsandbytes version: 0.44.1
  • DeepSpeed version: not installed
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: not installed
  • PEFT version: not installed

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

I noticed that the DPO trainer uses processing_class to tokenize the inputs for both model and ref_model. Is there a way to use a ref_model from a different base class that does not share the same tokenizer config as model? For example, using a Llama-3.1-8b model to align a Llama-3.2-3b model: training with this configuration currently leads to a constant loss of 1.0.
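
Roughly, the setup looks like this (model names, dataset, and output_dir are illustrative, not my exact script):

    # Illustrative reproduction sketch; model names and dataset are placeholders.
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import DPOConfig, DPOTrainer

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
    ref_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

    train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

    trainer = DPOTrainer(
        model=model,
        ref_model=ref_model,
        args=DPOConfig(output_dir="dpo-llama"),
        train_dataset=train_dataset,
        processing_class=tokenizer,  # the same tokenizer is used for both model and ref_model
    )
    trainer.train()  # loss stays pinned at 1.0 in this configuration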

Expected behavior

The trainer should accept two processing classes and allow ref_model and model to come from different model classes (with different tokenizers).
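
For example, something along these lines (the ref_processing_class argument is hypothetical and does not exist in the current API):

    # Hypothetical API sketch: ref_processing_class is not a real DPOTrainer argument today.
    trainer = DPOTrainer(
        model=model,                          # e.g. Llama-3.2-3b
        ref_model=ref_model,                  # e.g. Llama-3.1-8b
        args=DPOConfig(output_dir="dpo-llama"),
        train_dataset=train_dataset,
        processing_class=tokenizer,           # tokenizer of `model`
        ref_processing_class=ref_tokenizer,   # hypothetical: tokenizer of `ref_model`
    )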

@qgallouedec
Member

Indeed, it's not currently supported. And unless it's widely demanded, I don't think it will be.

Having said that, I think you can easily implement it. The following should work:

  1. set precompute_ref_log_probs=True in DPOConfig

  2. add a new parameter ref_processing_class in DPOTrainer

  3. in DPOTrainer.__init__, create a new tokenized dataset with ref_processing_class, something like:

    fn_kwargs = {
        "processing_class": ref_processing_class,  # <-
        "max_prompt_length": args.max_prompt_length,
        "max_completion_length": args.max_completion_length,
        # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
        "add_special_tokens": self.is_encoder_decoder,
    }
    self.ref_train_dataset = train_dataset.map(  # <-
        self.tokenize_row if not self.is_vision_model else self.process_row,
        fn_kwargs=fn_kwargs,
        num_proc=self.dataset_num_proc,
        writer_batch_size=10,
        desc="Tokenizing train dataset",
    )
  4. in the part that precomputes the ref log probs, replace

    if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
        dataloader_params = {
            "batch_size": self.args.per_device_train_batch_size,
            "collate_fn": self.data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "shuffle": False,
        }
        # prepare dataloader
        data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))

    with

    data_loader = self.accelerator.prepare(DataLoader(self.ref_train_dataset, **dataloader_params))
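
Putting 2 and 3 together, an untested sketch of a subclass (the name DPOTrainerWithRefTokenizer is illustrative, and the step-4 change above still has to be applied):

    # Untested sketch combining steps 2 and 3; step 4 (swapping in self.ref_train_dataset
    # when precomputing ref log probs) still needs the one-line change shown above.
    from trl import DPOTrainer


    class DPOTrainerWithRefTokenizer(DPOTrainer):
        def __init__(self, *args, ref_processing_class=None, train_dataset=None, **kwargs):
            # Step 2: accept an extra ref_processing_class argument, pass everything else through.
            super().__init__(*args, train_dataset=train_dataset, **kwargs)
            # Step 3: tokenize the raw training data a second time with the ref model's tokenizer.
            fn_kwargs = {
                "processing_class": ref_processing_class,
                "max_prompt_length": self.args.max_prompt_length,
                "max_completion_length": self.args.max_completion_length,
                # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
                "add_special_tokens": self.is_encoder_decoder,
            }
            self.ref_train_dataset = train_dataset.map(
                self.tokenize_row if not self.is_vision_model else self.process_row,
                fn_kwargs=fn_kwargs,
                num_proc=self.dataset_num_proc,
                writer_batch_size=10,
                desc="Tokenizing train dataset for ref_model",
            )

Instantiate it with precompute_ref_log_probs=True in the DPOConfig (step 1) and pass the ref model's tokenizer as ref_processing_class, next to the usual processing_class.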

@qgallouedec qgallouedec added ✨ enhancement New feature or request ❓ question Seeking clarification or more information labels Nov 5, 2024