Conversation

@albertvillanova
Member

Refactor KTO [2/N]: Improve config validation in KTOConfig.

This PR moves validation logic from KTOTrainer.__init__() to KTOConfig.__post_init__() for earlier error detection, better user experience, and cleaner separation of concerns.

  • Principle: Fail-fast with clear, actionable error messages

Part of:

Problem

Before:

  • Validation scattered between config and trainer
  • Errors discovered late (during trainer initialization)
  • Invalid configs could be created and passed around
  • generate_during_eval validated in trainer
  • No validation for loss_type, truncation_mode, beta, weights
  • No validation for max_length relationships

After:

  • All validation centralized in KTOConfig.__post_init__()
  • Errors discovered immediately at config creation
  • Invalid configs cannot be created
  • Clear, actionable error messages with guidance
  • Comprehensive validation for all critical parameters
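The pattern described above, as a minimal illustrative sketch (a simplified stand-in dataclass with a placeholder subset of fields and defaults, not the actual KTOConfig):

from dataclasses import dataclass

@dataclass
class KTOConfigSketch:
    # Illustrative subset of fields; names mirror KTOConfig, defaults are placeholders.
    beta: float = 0.1
    max_length: int = 1024
    max_prompt_length: int = 512

    def __post_init__(self):
        # Fail fast: invalid values raise at config creation, before any trainer
        # or model is instantiated.
        if self.beta <= 0:
            raise ValueError(f"beta must be positive, got {self.beta}.")
        if self.max_prompt_length >= self.max_length:
            raise ValueError(
                f"max_prompt_length ({self.max_prompt_length}) must be smaller than "
                f"max_length ({self.max_length})."
            )

# KTOConfigSketch(beta=-0.1) raises ValueError immediately, instead of the error
# surfacing later inside a trainer's __init__().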

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines -458 to -463
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
    raise ValueError(
        "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
        " Please install `wandb` or `comet-ml` to resolve."
    )

Member Author


What do you think about this change? Before, the check happened at trainer instantiation; now it happens at config creation. This has an import-time side effect: the config module now imports the wandb/comet availability checkers immediately.
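If the check stays in the config, one way to avoid the module-level import would be to defer it (a rough sketch; it assumes is_wandb_available and is_comet_available can be imported from transformers.integrations):

from dataclasses import dataclass

@dataclass
class ConfigSketch:
    generate_during_eval: bool = False

    def __post_init__(self):
        if self.generate_during_eval:
            # Import lazily so the wandb/comet availability checkers are only
            # loaded when this option is actually enabled (assumed import path).
            from transformers.integrations import is_comet_available, is_wandb_available

            if not (is_wandb_available() or is_comet_available()):
                raise ValueError(
                    "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
                    " Please install `wandb` or `comet-ml` to resolve."
                )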

Member


In general, I prefer to keep all these kinds of checks in one place (in MyTrainer.__init__), so that the configuration remains minimal, and we avoid unintended duplication. That said, the codebase isn’t entirely consistent on this point, see for example:

if self.generation_batch_size is None and self.steps_per_generation is None:
    self.steps_per_generation = self.gradient_accumulation_steps
    self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation
elif self.generation_batch_size is not None and self.steps_per_generation is None:
    # Just ensure the value is divisible by the global batch size
    if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0:
        raise ValueError(
            f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size "
            f"({self.per_device_train_batch_size * num_processes})."
        )
    self.steps_per_generation = self.generation_batch_size // (
        self.per_device_train_batch_size * num_processes
    )
elif self.generation_batch_size is None and self.steps_per_generation is not None:
    self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation
else:
    raise ValueError(
        "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time"
    )
if self.do_eval and self.eval_strategy != "no":
    # Determine the number of generations to use for evaluation
    num_generations = self.num_generations_eval or self.num_generations
    # Just ensure the value is divisible by the global batch size
    if (self.per_device_eval_batch_size * num_processes) % num_generations != 0:
        raise ValueError(
            f"The global eval batch size ({self.per_device_eval_batch_size} * {num_processes}) must be "
            f"divisible by the number of generations used for evaluation ({num_generations})."
        )
# The generation batch must contain full prompt groups (no partials), so it must be divisible by
# num_generations.
if self.generation_batch_size % self.num_generations != 0:
    raise ValueError(
        f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations "
        f"({self.num_generations})."
    )
if self.num_generations < 2:
    raise ValueError(
        "GRPO requires at least 2 generations per prompt to calculate the advantages. You provided "
        f"{self.num_generations}, which is less than the minimum required."
    )

For this specific argument, I think it could probably be removed. See point 7 here:
#3906 (comment)

Comment on lines +284 to +301
# Validate beta
if self.beta <= 0:
    raise ValueError(
        f"beta must be positive, got {self.beta}. Higher β means less deviation from the reference model."
    )

# Validate weights
if self.desirable_weight <= 0:
    raise ValueError(
        f"desirable_weight must be positive, got {self.desirable_weight}. "
        "This weight is used to balance desirable and undesirable examples."
    )

if self.undesirable_weight <= 0:
    raise ValueError(
        f"undesirable_weight must be positive, got {self.undesirable_weight}. "
        "This weight is used to balance desirable and undesirable examples."
    )
Member


I’m generally in favor of keeping validation as minimal as possible. Users experimenting with parameters are expected to understand the semantics, and overly defensive validation tends to grow without bound. I’d rather keep checks limited to hard invariants (i.e. impossible combinations), like here:

# The generation batch must contain full prompt groups (no partials), so it must be divisible by
# num_generations.
if self.generation_batch_size % self.num_generations != 0:
    raise ValueError(
        f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations "
        f"({self.num_generations})."
    )

This one is required by the sampling logic: if you remove the check, you can end up in a failure mode that would be very hard to debug. Beyond that, I’d rather let mistakes fail naturally.

Also, strict value checks can unnecessarily block experimentation. For example, negative (un)desirable weights aren’t mathematically invalid: they change the objective (and are almost certainly not what we want), but I’m not sure we should hard-forbid them at the config level unless they actually cause breakage.
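For reference, a simplified sketch of how the weights enter the objective (illustrative only; the actual KTOTrainer aggregation differs in detail):

import torch

def kto_weighted_loss(desirable_losses, undesirable_losses, desirable_weight, undesirable_weight):
    # Simplified sketch: each weight just scales its term of the objective.
    return desirable_weight * desirable_losses.mean() + undesirable_weight * undesirable_losses.mean()

# A negative weight flips the sign of its term: the objective changes (almost
# certainly unintentionally), but nothing breaks mechanically.
print(kto_weighted_loss(torch.tensor([0.3, 0.5]), torch.tensor([0.7]), 1.0, -1.0))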
