Model Sweep (#366)
* Added model comparator
tweaked DANet default virtual batch size

* added test cases

* added lite config
added model sweep
added model sweep tutorial

* added documentation

* update documentation for model sweep

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
manujosephv and pre-commit-ci[bot] authored Jan 4, 2024
1 parent 96d6c06 commit 24ce9e5
Showing 17 changed files with 1,437 additions and 1,795 deletions.
6 changes: 6 additions & 0 deletions docs/apidocs_coreclasses.md
@@ -6,3 +6,9 @@
::: pytorch_tabular.TabularDatamodule
options:
heading_level: 3
::: pytorch_tabular.TabularModelTuner
options:
heading_level: 3
::: pytorch_tabular.model_sweep
options:
heading_level: 3
23 changes: 22 additions & 1 deletion docs/apidocs_utils.md
@@ -18,6 +18,15 @@
::: pytorch_tabular.utils.get_gaussian_centers
options:
heading_level: 3
::: pytorch_tabular.utils.load_covertype_dataset
options:
heading_level: 3
::: pytorch_tabular.utils.make_mixed_dataset
options:
heading_level: 3
::: pytorch_tabular.utils.print_metrics
options:
heading_level: 3

## NN Utilities
::: pytorch_tabular.utils._initialize_layers
@@ -38,7 +47,10 @@
::: pytorch_tabular.utils.to_one_hot
options:
heading_level: 3

::: pytorch_tabular.utils.count_parameters
options:
heading_level: 3

## Python Utilities
::: pytorch_tabular.utils.getattr_nested
options:
@@ -55,3 +67,12 @@
::: pytorch_tabular.utils.generate_doc_dataclass
options:
heading_level: 3
::: pytorch_tabular.utils.suppress_lightning_logs
options:
heading_level: 3
::: pytorch_tabular.utils.enable_lightning_logs
options:
heading_level: 3
::: pytorch_tabular.utils.int_to_human_readable
options:
heading_level: 3
39 changes: 39 additions & 0 deletions docs/tabular_model.md
@@ -30,6 +30,45 @@ tabular_model = TabularModel(
)
```

### Model Sweep

PyTorch Tabular also provides an easy way to compare the performance of different models and configurations on a given dataset, through the `model_sweep` function. It takes in a list of model configs or one of the presets defined in ``pytorch_tabular.MODEL_SWEEP_PRESETS``, trains each model on the data, ranks the models by the chosen metric, and returns a summary dataframe along with the best model.

The major arguments are:
- ``task``: The type of prediction task. Either 'classification' or 'regression'.
- ``train``: The training data.
- ``test``: The test data on which performance is evaluated.
- ``data_config``, ``optimizer_config``, ``trainer_config``: The config objects, each of which can be passed either as the object itself or as the path to a YAML file.
- ``model_list``: The list of models to compare. Either one of the presets defined in ``pytorch_tabular.MODEL_SWEEP_PRESETS`` or a list of ``ModelConfig`` objects (see the sketch after the usage example below).
- ``metrics``: The list of metrics to track during training. Each metric should be one of the functional metrics implemented in ``torchmetrics``. Defaults to accuracy for classification and mean_squared_error for regression.
- ``metrics_prob_input``: A mandatory parameter for classification metrics defined in the config. It defines whether the input to each metric function is the probability or the class. Its length should be the same as the number of metrics. Defaults to None.
- ``metrics_params``: The parameters to be passed to each metric function.
- ``rank_metric``: The metric to use for ranking the models, as a tuple whose first element is the metric name and whose second element is the direction. Defaults to ``('loss', 'lower_is_better')``.
- ``return_best_model``: If True, returns the best model. Defaults to True.

#### Usage Example

```python
from pytorch_tabular import model_sweep

# The train/test dataframes and the config objects (data_config,
# optimizer_config, trainer_config, head_config) are assumed to be
# defined beforehand.
sweep_df, best_model = model_sweep(
    task="classification",  # One of "classification", "regression"
    train=train,
    test=test,
    data_config=data_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    model_list="lite",  # One of the presets defined in pytorch_tabular.MODEL_SWEEP_PRESETS
    common_model_args=dict(head="LinearHead", head_config=head_config),
    metrics=["accuracy", "f1_score"],  # The metrics to track during training
    metrics_params=[{}, {"average": "weighted"}],
    metrics_prob_input=[False, True],
    rank_metric=("accuracy", "higher_is_better"),  # The metric to use for ranking the models
    progress_bar=True,  # If True, shows a progress bar
    verbose=False,  # If True, prints the results of each model
)
```
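
Instead of a preset string, `model_list` can also be an explicit list of ``ModelConfig`` objects. A minimal sketch, assuming the same data and configs as above; the config fields beyond `task` are left at their defaults for brevity:

```python
from pytorch_tabular.models import (
    CategoryEmbeddingModelConfig,
    TabNetModelConfig,
)

# Compare two specific architectures instead of a preset.
model_list = [
    CategoryEmbeddingModelConfig(task="classification"),
    TabNetModelConfig(task="classification"),
]

sweep_df, best_model = model_sweep(
    task="classification",
    train=train,
    test=test,
    data_config=data_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    model_list=model_list,  # explicit configs instead of a preset
    rank_metric=("accuracy", "higher_is_better"),
)
```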

For more examples, check out the tutorial notebook - [Model Sweep](tutorials/13-Model%20Sweep.ipynb).

### Advanced Usage

- `config`: DictConfig: Another way of initializing `TabularModel` is with a `DictConfig` from `omegaconf`. Although not recommended, you can create a plain dictionary with all the parameters dumped into it, build a `DictConfig` from it with `omegaconf`, and pass it here. The downside is that you skip all validation (both type validation and logical validation). This is primarily used internally to load a saved model from a checkpoint. A hedged sketch follows below.
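
A minimal sketch of this path, assuming an already-initialized model exposes its merged config as `tabular_model.config` (the round-trip through a plain dict is illustrative, not the recommended workflow):

```python
from omegaconf import OmegaConf

from pytorch_tabular import TabularModel

# Round-trip the merged config of an existing model through a plain dict
# and rebuild. No type or logical validation happens on this path.
config_dict = OmegaConf.to_container(tabular_model.config)
new_model = TabularModel(config=OmegaConf.create(config_dict))
```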
402 changes: 402 additions & 0 deletions docs/tutorials/13-Model Leaderboard copy.ipynb

Large diffs are not rendered by default.

