From 5adb80482c8242dde7b7942529db94fa9ccbfe48 Mon Sep 17 00:00:00 2001
From: Optimox
Date: Tue, 25 May 2021 10:19:25 +0200
Subject: [PATCH] feat: pretraining matches paper

---
 README.md                        | 26 +++++++++++++++++++++-----
 pretraining_example.ipynb        |  7 +++++--
 pytorch_tabnet/abstract_model.py |  2 ++
 pytorch_tabnet/tab_network.py    | 25 ++++++++++++++-----------
 4 files changed, 42 insertions(+), 18 deletions(-)

diff --git a/README.md b/README.md
index a91651b1..4d3c00dd 100644
--- a/README.md
+++ b/README.md
@@ -155,6 +155,18 @@ A complete example can be found within the notebook `pretraining_example.ipynb`.
 
 /!\ : current implementation is trying to reconstruct the original inputs, but Batch Normalization applies a random transformation that can't be deduced by a single line, making the reconstruction harder. Lowering the `batch_size` might make the pretraining easier.
 
+# Easy saving and loading
+
+It's really easy to save and re-load a trained model; this makes TabNet production-ready.
+```
+# save tabnet model
+saving_path_name = "./tabnet_model_test_1"
+saved_filepath = clf.save_model(saving_path_name)
+
+# define new model with basic parameters and load state dict weights
+loaded_clf = TabNetClassifier()
+loaded_clf.load_model(saved_filepath)
+```
 
 # Useful links
 
@@ -251,10 +263,6 @@ A complete example can be found within the notebook `pretraining_example.ipynb`.
 
     Name of the model used for saving in disk, you can customize this to easily retrieve and reuse your trained models.
 
-- `saving_path` : str (default = './')
-
-    Path defining where to save models.
-
 - `verbose` : int (default=1)
 
     Verbosity for notebooks plots, set to 1 to see every epoch, 0 to get None.
 
@@ -263,7 +271,15 @@ A complete example can be found within the notebook `pretraining_example.ipynb`.
 
     'cpu' for cpu training, 'gpu' for gpu training, 'auto' to automatically detect gpu.
 
 - `mask_type: str` (default='sparsemax')
-    Either "sparsemax" or "entmax" : this is the masking function to use for selecting features
+    Either "sparsemax" or "entmax" : this is the masking function to use for selecting features.
+
+- `n_shared_decoder` : int (default=1)
+
+    Number of shared GLU blocks in the decoder; this is only useful for `TabNetPretrainer`.
+
+- `n_indep_decoder` : int (default=1)
+
+    Number of independent GLU blocks in the decoder; this is only useful for `TabNetPretrainer`.
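+
+As a minimal usage sketch (all other constructor arguments are left at their defaults here), both options are passed directly to `TabNetPretrainer`:
+
+```
+from pytorch_tabnet.pretraining import TabNetPretrainer
+
+# decoder capacity for the reconstruction task, set independently of the encoder
+unsupervised_model = TabNetPretrainer(
+    n_shared_decoder=1,  # number of shared GLU blocks in the decoder
+    n_indep_decoder=1,   # number of independent GLU blocks in the decoder
+)
+```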
 
 ## Fit parameters
diff --git a/pretraining_example.ipynb b/pretraining_example.ipynb
index 6dab9c3b..2594403d 100644
--- a/pretraining_example.ipynb
+++ b/pretraining_example.ipynb
@@ -178,7 +178,9 @@
     "    cat_emb_dim=3,\n",
     "    optimizer_fn=torch.optim.Adam,\n",
     "    optimizer_params=dict(lr=2e-2),\n",
-    "    mask_type='entmax' # \"sparsemax\"\n",
+    "    mask_type='entmax', # \"sparsemax\",\n",
+    "    n_shared_decoder=1, # nb shared glu for decoding\n",
+    "    n_indep_decoder=1, # nb independent glu for decoding\n",
     ")"
    ]
   },
@@ -214,6 +216,7 @@
     "    num_workers=0,\n",
     "    drop_last=False,\n",
     "    pretraining_ratio=0.8,\n",
+    "\n",
     ") "
    ]
   },
@@ -492,7 +495,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.5"
+   "version": "3.7.6"
   },
   "toc": {
    "base_numbering": 1,
diff --git a/pytorch_tabnet/abstract_model.py b/pytorch_tabnet/abstract_model.py
index 50d99bb1..29be89f9 100644
--- a/pytorch_tabnet/abstract_model.py
+++ b/pytorch_tabnet/abstract_model.py
@@ -60,6 +60,8 @@ class TabModel(BaseEstimator):
     input_dim: int = None
     output_dim: int = None
     device_name: str = "auto"
+    n_shared_decoder: int = 1
+    n_indep_decoder: int = 1
 
     def __post_init__(self):
         self.batch_size = 1024
diff --git a/pytorch_tabnet/tab_network.py b/pytorch_tabnet/tab_network.py
index 3deabdc4..769995ec 100644
--- a/pytorch_tabnet/tab_network.py
+++ b/pytorch_tabnet/tab_network.py
@@ -206,8 +206,8 @@ def __init__(
         input_dim,
         n_d=8,
         n_steps=3,
-        n_independent=2,
-        n_shared=2,
+        n_independent=1,
+        n_shared=1,
         virtual_batch_size=128,
         momentum=0.02,
     ):
@@ -228,9 +228,9 @@ def __init__(
         gamma : float
             Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0)
         n_independent : int
-            Number of independent GLU layer in each GLU block (default 2)
+            Number of independent GLU layers in each GLU block (default 1)
         n_shared : int
-            Number of independent GLU layer in each GLU block (default 2)
+            Number of shared GLU layers in each GLU block (default 1)
         virtual_batch_size : int
             Batch size for Ghost Batch Normalization
         momentum : float
@@ -245,7 +245,6 @@ def __init__(
         self.virtual_batch_size = virtual_batch_size
 
         self.feat_transformers = torch.nn.ModuleList()
-        self.reconstruction_layers = torch.nn.ModuleList()
 
         if self.n_shared > 0:
             shared_feat_transform = torch.nn.ModuleList()
@@ -268,16 +267,16 @@ def __init__(
                 momentum=momentum,
             )
             self.feat_transformers.append(transformer)
-            reconstruction_layer = Linear(n_d, self.input_dim, bias=False)
-            initialize_non_glu(reconstruction_layer, n_d, self.input_dim)
-            self.reconstruction_layers.append(reconstruction_layer)
+
+        self.reconstruction_layer = Linear(n_d, self.input_dim, bias=False)
+        initialize_non_glu(self.reconstruction_layer, n_d, self.input_dim)
 
     def forward(self, steps_output):
         res = 0
         for step_nb, step_output in enumerate(steps_output):
             x = self.feat_transformers[step_nb](step_output)
-            x = self.reconstruction_layers[step_nb](step_output)
             res = torch.add(res, x)
+        res = self.reconstruction_layer(res)
         return res
 
 
@@ -299,6 +298,8 @@ def __init__(
         virtual_batch_size=128,
         momentum=0.02,
         mask_type="sparsemax",
+        n_shared_decoder=1,
+        n_indep_decoder=1,
     ):
         super(TabNetPretraining, self).__init__()
 
@@ -316,6 +317,8 @@ def __init__(
         self.n_shared = n_shared
         self.mask_type = mask_type
         self.pretraining_ratio = pretraining_ratio
+        self.n_shared_decoder = n_shared_decoder
+        self.n_indep_decoder = n_indep_decoder
 
         if self.n_steps <= 0:
             raise ValueError("n_steps should be a positive integer.")
@@ -345,8 +348,8 @@
             self.post_embed_dim,
             n_d=n_d,
             n_steps=n_steps,
-            n_independent=n_independent,
-            n_shared=n_shared,
+            n_independent=self.n_indep_decoder,
+            n_shared=self.n_shared_decoder,
             virtual_batch_size=virtual_batch_size,
             momentum=momentum,
         )
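The `TabNetDecoder` change above, in one sketch: the per-step reconstruction layers are removed, the per-step feature-transformer outputs are summed, and a single shared reconstruction layer maps the sum back to `input_dim`, replacing the previous one-reconstruction-layer-per-step behaviour (this is what the subject line refers to as matching the paper). A minimal standalone illustration in plain PyTorch; `decoder_forward` and its arguments are simplified stand-ins for the patched `TabNetDecoder.forward`, not code from the patch:

```
import torch

def decoder_forward(steps_output, feat_transformers, reconstruction_layer):
    # steps_output: list of (batch_size, n_d) tensors, one per decision step
    res = 0
    for step_nb, step_output in enumerate(steps_output):
        x = feat_transformers[step_nb](step_output)  # per-step feature transformer
        res = torch.add(res, x)                      # sum the transformed step outputs
    # single shared reconstruction layer applied once to the summed representation
    return reconstruction_layer(res)
```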