diff --git a/pytorch_tabnet/tab_network.py b/pytorch_tabnet/tab_network.py index a50151ca..3deabdc4 100644 --- a/pytorch_tabnet/tab_network.py +++ b/pytorch_tabnet/tab_network.py @@ -794,10 +794,19 @@ def __init__(self, input_dim, cat_dims, cat_idxs, cat_emb_dim): If int, the same embedding dimension will be used for all categorical features """ super(EmbeddingGenerator, self).__init__() - if cat_dims == [] or cat_idxs == []: + if cat_dims == [] and cat_idxs == []: self.skip_embedding = True self.post_embed_dim = input_dim return + elif (cat_dims == []) ^ (cat_idxs == []): + if cat_dims == []: + msg = "If cat_idxs is non-empty, cat_dims must be defined as a list of same length." + else: + msg = "If cat_dims is non-empty, cat_idxs must be defined as a list of same length." + raise ValueError(msg) + elif len(cat_dims) != len(cat_idxs): + msg = "The lists cat_dims and cat_idxs must have the same length." + raise ValueError(msg) self.skip_embedding = False if isinstance(cat_emb_dim, int):