Refactor KTO [2/N]: Improve config validation in KTOConfig #4787
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
    raise ValueError(
        "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
        " Please install `wandb` or `comet-ml` to resolve."
    )
```
What do you think about this change? Previously this was checked at trainer instantiation; now it happens at config creation. This has an import-time side effect: it invokes the wandb/comet availability checkers immediately.
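One way to keep the check in the config while avoiding the import-time cost is to run the availability probe only inside `__post_init__`, and only when the flag is actually set. A minimal sketch — `_is_available` here is a hypothetical stand-in (using `importlib.util.find_spec`) for the real trl/transformers helpers, and the class is a reduced model of `KTOConfig`:

```python
import importlib.util
from dataclasses import dataclass


def _is_available(package: str) -> bool:
    # Lightweight probe: consults the import machinery without importing the package.
    return importlib.util.find_spec(package) is not None


@dataclass
class KTOConfig:  # simplified stand-in for the real KTOConfig
    generate_during_eval: bool = False

    def __post_init__(self):
        # The probe runs only when the flag is set, so merely creating a
        # default config (or importing this module) has no wandb/comet side effects.
        if self.generate_during_eval and not (
            _is_available("wandb") or _is_available("comet_ml")
        ):
            raise ValueError(
                "`generate_during_eval=True` requires Weights and Biases or Comet "
                "to be installed. Please install `wandb` or `comet-ml` to resolve."
            )
```

This keeps the fail-fast behavior of config-level validation without making module import depend on the logger integrations.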
In general, I prefer to keep all these kinds of checks in one place (in `MyTrainer.__init__`), so that the configuration remains minimal and we avoid unintended duplication. That said, the codebase isn’t entirely consistent on this point; see for example:
trl/trl/trainer/grpo_config.py
Lines 845 to 888 in 1a93971
```python
if self.generation_batch_size is None and self.steps_per_generation is None:
    self.steps_per_generation = self.gradient_accumulation_steps
    self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation
elif self.generation_batch_size is not None and self.steps_per_generation is None:
    # Just ensure the value is divisible by the global batch size
    if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0:
        raise ValueError(
            f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size "
            f"({self.per_device_train_batch_size * num_processes})."
        )
    self.steps_per_generation = self.generation_batch_size // (
        self.per_device_train_batch_size * num_processes
    )
elif self.generation_batch_size is None and self.steps_per_generation is not None:
    self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation
else:
    raise ValueError(
        "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time"
    )
if self.do_eval and self.eval_strategy != "no":
    # Determine the number of generations to use for evaluation
    num_generations = self.num_generations_eval or self.num_generations
    # Just ensure the value is divisible by the global batch size
    if (self.per_device_eval_batch_size * num_processes) % num_generations != 0:
        raise ValueError(
            f"The global eval batch size ({self.per_device_eval_batch_size} * {num_processes}) must be "
            f"divisible by the number of generations used for evaluation ({num_generations})."
        )
# The generation batch must contain full prompt groups (no partials), so it must be divisible by
# num_generations.
if self.generation_batch_size % self.num_generations != 0:
    raise ValueError(
        f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations "
        f"({self.num_generations})."
    )
if self.num_generations < 2:
    raise ValueError(
        "GRPO requires at least 2 generations per prompt to calculate the advantages. You provided "
        f"{self.num_generations}, which is less than the minimum required."
    )
```
For this specific argument, I think it could probably be removed. See point 7 here:
#3906 (comment)
```python
# Validate beta
if self.beta <= 0:
    raise ValueError(
        f"beta must be positive, got {self.beta}. Higher β means less deviation from the reference model."
    )

# Validate weights
if self.desirable_weight <= 0:
    raise ValueError(
        f"desirable_weight must be positive, got {self.desirable_weight}. "
        "This weight is used to balance desirable and undesirable examples."
    )

if self.undesirable_weight <= 0:
    raise ValueError(
        f"undesirable_weight must be positive, got {self.undesirable_weight}. "
        "This weight is used to balance desirable and undesirable examples."
    )
```
I’m generally in favor of keeping validation as minimal as possible. Users experimenting with parameters are expected to understand the semantics, and overly defensive validation tends to grow without bound. I’d rather keep checks limited to hard invariants (i.e. impossible combinations), like here:
trl/trl/trainer/grpo_config.py
Lines 876 to 882 in 1a93971
```python
# The generation batch must contain full prompt groups (no partials), so it must be divisible by
# num_generations.
if self.generation_batch_size % self.num_generations != 0:
    raise ValueError(
        f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations "
        f"({self.num_generations})."
    )
```
This one is required by the sampling logic: if you remove the check, you can end up in a failure mode that would be very hard to debug. Beyond that, I’d rather let mistakes fail naturally.
Also, strict value checks can unnecessarily block experimentation. For example, negative (un)desirable weights aren’t mathematically invalid: they change the objective (and are almost certainly not what we want), but I’m not sure we should hard-forbid them at the config level unless they actually cause breakage.
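To make the point concrete: a toy illustration of how such weights typically enter a weighted objective (this is a schematic of the weighting pattern, not the exact KTO loss). A negative weight is numerically fine — it just flips the sign of that term, turning a penalty into a reward, which changes the objective rather than crashing anything:

```python
import math


def weighted_loss(loss_desirable: float, loss_undesirable: float,
                  desirable_weight: float, undesirable_weight: float) -> float:
    # Schematic combination: each per-example loss is scaled by its weight.
    return desirable_weight * loss_desirable + undesirable_weight * loss_undesirable


balanced = weighted_loss(0.5, 0.8, 1.0, 1.0)   # standard positive weights
flipped = weighted_loss(0.5, 0.8, 1.0, -1.0)   # negative weight: no error, different objective
assert math.isclose(balanced, 1.3)
assert math.isclose(flipped, -0.3)
```

So the `> 0` checks forbid configurations that are unusual but well-defined, which is the kind of validation that can block experimentation.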
Refactor KTO [2/N]: Improve config validation in KTOConfig.
This PR moves validation logic from `KTOTrainer.__init__()` to `KTOConfig.__post_init__()` for earlier error detection, a better user experience, and a cleaner separation of concerns. Part of:
Problem
Before:
After:
`KTOConfig.__post_init__()`
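The effect of the move can be sketched as follows — a reduced model of `KTOConfig` (not the full class) showing that an invalid value now fails at config creation, before any model or trainer is constructed:

```python
from dataclasses import dataclass


@dataclass
class KTOConfig:  # reduced sketch of the validation moved into __post_init__
    beta: float = 0.1
    desirable_weight: float = 1.0
    undesirable_weight: float = 1.0

    def __post_init__(self):
        if self.beta <= 0:
            raise ValueError(f"beta must be positive, got {self.beta}.")
        if self.desirable_weight <= 0 or self.undesirable_weight <= 0:
            raise ValueError("desirable_weight and undesirable_weight must be positive.")


# The error surfaces immediately at config creation:
try:
    KTOConfig(beta=-0.1)
except ValueError as e:
    print(e)  # beta must be positive, got -0.1.
```

Previously the same mistake would only be reported once `KTOTrainer` was instantiated, i.e. after the user had already loaded models and datasets.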