feat: raise error in case cat_dims and cat_idxs are incoherent
eduardocarvp authored and Optimox committed May 5, 2021
1 parent 5e4e809 commit 8c3b795
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion pytorch_tabnet/tab_network.py
@@ -794,10 +794,19 @@ def __init__(self, input_dim, cat_dims, cat_idxs, cat_emb_dim):
             If int, the same embedding dimension will be used for all categorical features
         """
         super(EmbeddingGenerator, self).__init__()
-        if cat_dims == [] or cat_idxs == []:
+        if cat_dims == [] and cat_idxs == []:
             self.skip_embedding = True
             self.post_embed_dim = input_dim
             return
+        elif (cat_dims == []) ^ (cat_idxs == []):
+            if cat_dims == []:
+                msg = "If cat_idxs is non-empty, cat_dims must be defined as a list of same length."
+            else:
+                msg = "If cat_dims is non-empty, cat_idxs must be defined as a list of same length."
+            raise ValueError(msg)
+        elif len(cat_dims) != len(cat_idxs):
+            msg = "The lists cat_dims and cat_idxs must have the same length."
+            raise ValueError(msg)

         self.skip_embedding = False
         if isinstance(cat_emb_dim, int):
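For illustration, below is a minimal standalone sketch of the validation introduced by this commit, mirroring the checks added to EmbeddingGenerator.__init__ in pytorch_tabnet/tab_network.py. The helper name check_cat_dims_idxs is hypothetical and exists only for this example; in the library itself the checks run inline inside the constructor.

def check_cat_dims_idxs(cat_dims, cat_idxs):
    """Validate that cat_dims and cat_idxs describe the same categorical columns."""
    if cat_dims == [] and cat_idxs == []:
        # No categorical features at all: embeddings can simply be skipped.
        return
    elif (cat_dims == []) ^ (cat_idxs == []):
        # Exactly one of the two lists is empty: the specification is incoherent.
        if cat_dims == []:
            msg = "If cat_idxs is non-empty, cat_dims must be defined as a list of same length."
        else:
            msg = "If cat_dims is non-empty, cat_idxs must be defined as a list of same length."
        raise ValueError(msg)
    elif len(cat_dims) != len(cat_idxs):
        # Both lists are non-empty but their lengths differ.
        raise ValueError("The lists cat_dims and cat_idxs must have the same length.")

# Two categorical columns declared by index, but only one cardinality given:
check_cat_dims_idxs(cat_dims=[4], cat_idxs=[0, 3])  # raises ValueError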
