The optimizer is at the heart of the gradient descent process and is a key component in training a good model. By default, PyTorch Tabular uses the Adam optimizer with a learning rate of 1e-3. This default follows a common rule of thumb and provides a good starting point.

Learning Rate Schedulers give you finer control over how the learning rate changes during the optimization process. By default, PyTorch Tabular applies no Learning Rate Scheduler.

Basic Usage

  • optimizer: str: Any of the standard optimizers from torch.optim. Defaults to Adam
  • optimizer_params: Dict: The parameters for the optimizer. If left blank, will use default parameters.
  • lr_scheduler: str: The name of the LearningRateScheduler to use, if any, from torch.optim.lr_scheduler. If None, will not use any scheduler. Defaults to None
  • lr_scheduler_params: Dict: The parameters for the LearningRateScheduler. If left blank, will use default parameters.
  • lr_scheduler_monitor_metric: str: Used with ReduceLROnPlateau, where the plateau is decided based on this metric. Defaults to val_loss
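Because the `optimizer` and `lr_scheduler` options are plain strings naming classes in `torch.optim` and `torch.optim.lr_scheduler`, one way to picture them is as attribute lookups on those modules. The snippet below is a minimal plain-PyTorch sketch of that idea, not PyTorch Tabular's internal code; the tiny `nn.Linear` model and the variable names are illustrative assumptions.

```python
import torch
from torch import nn

# Hypothetical stand-in model; any nn.Module with parameters would do.
model = nn.Linear(4, 1)

# Look up the optimizer class by name, then build it with its parameters.
optimizer_cls = getattr(torch.optim, "Adam")
optimizer = optimizer_cls(model.parameters(), lr=1e-3)

# Look up the scheduler class by name and attach it to the optimizer.
scheduler_cls = getattr(torch.optim.lr_scheduler, "StepLR")
scheduler = scheduler_cls(optimizer, step_size=10)
```

A name that does not exist in the module raises `AttributeError` in this sketch, which is why the string has to match the class name exactly (for example `"RMSprop"`, not `"rmsprop"`).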

Usage Example

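from pytorch_tabular.config import OptimizerConfig

# RMSprop optimizer with a StepLR scheduler that decays the learning rate every 10 epochs
optimizer_config = OptimizerConfig(
    optimizer="RMSprop", lr_scheduler="StepLR", lr_scheduler_params={"step_size": 10}
)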
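To see what a `StepLR` scheduler with `step_size=10` does to the learning rate, the arithmetic can be sketched in plain Python. This assumes the base learning rate of 1e-3 mentioned earlier and `StepLR`'s default `gamma` of 0.1, since the example does not set `gamma`; the helper `step_lr` is a hypothetical name used only for illustration.

```python
def step_lr(base_lr: float, epoch: int, step_size: int, gamma: float = 0.1) -> float:
    """Learning rate a StepLR schedule would use at a given 0-indexed epoch."""
    # The rate is multiplied by gamma once for every completed block of step_size epochs.
    return base_lr * gamma ** (epoch // step_size)


# Print the rate at the boundaries of the first three blocks of 10 epochs.
for epoch in (0, 9, 10, 19, 20):
    print(epoch, step_lr(1e-3, epoch, step_size=10))
```

Under these assumptions the rate stays at 1e-3 for epochs 0 to 9, drops to roughly 1e-4 for epochs 10 to 19, and to roughly 1e-5 from epoch 20.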

Advanced Usage

While the Config object restricts you to the standard Optimizers and Learning Rate Schedulers in torch.optim, you can use any custom Optimizer or Learning Rate Scheduler, as long as it is a drop-in replacement for a standard one. You can do this through the fit method of TabularModel, which lets you override the optimizer and learning rate that are set through the config.

Usage Example

from torch_optimizer import QHAdam

tabular_model.fit(
    train=train,
    validation=val,
    optimizer=QHAdam,
    optimizer_params={"nus": (0.7, 1.0), "betas": (0.95, 0.998)},
)
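For a sense of what "drop-in replacement" means here, the following is a minimal sketch of a custom optimizer that subclasses `torch.optim.Optimizer`. The class name `PlainSGD` is hypothetical, and the assumption that accepting an `lr` keyword is enough for it to be passed as `optimizer=PlainSGD` to `fit` is not confirmed by this page.

```python
import torch
from torch.optim import Optimizer


class PlainSGD(Optimizer):
    """Hypothetical minimal optimizer: vanilla gradient descent, no momentum."""

    def __init__(self, params, lr=1e-3):
        # Per-group defaults, the same pattern the standard optimizers use.
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            # Re-enable gradients so the closure can recompute the loss.
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    # In-place update: p <- p - lr * grad
                    p.add_(p.grad, alpha=-group["lr"])
        return loss


# Quick check on a single parameter with loss w^2 (gradient is 2w).
w = torch.nn.Parameter(torch.tensor([1.0]))
opt = PlainSGD([w], lr=0.1)
(w ** 2).sum().backward()
opt.step()
```

The key design point is that the class takes the parameters as its first argument and exposes `step` and `param_groups` like any optimizer in `torch.optim`, so code written against the standard interface does not need to change.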

::: pytorch_tabular.config.OptimizerConfig
    options:
        show_root_heading: yes