From 5d76e9920e6906fda2874bf61f54374a066fa3e8 Mon Sep 17 00:00:00 2001 From: sciengineer <40659034+sciengineer@users.noreply.github.com> Date: Tue, 26 Sep 2023 00:20:11 +0800 Subject: [PATCH] Update tab_network.py The mask tensor M needs to be transformed so that "a feature is enforced to be used only at one decision step", when gamma is 1. --- pytorch_tabnet/tab_network.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/pytorch_tabnet/tab_network.py b/pytorch_tabnet/tab_network.py index 95c2bae2..605d863b 100644 --- a/pytorch_tabnet/tab_network.py +++ b/pytorch_tabnet/tab_network.py @@ -170,11 +170,19 @@ def forward(self, x, prior=None): steps_output = [] for step in range(self.n_steps): M = self.att_transformers[step](prior, att) + # copied from M + M_copy = M.clone() + # Set the element of M_copy to 1 if it is positive. + mask = M_copy>0 + M_copy[mask] = 1 M_loss += torch.mean( torch.sum(torch.mul(M, torch.log(M + self.epsilon)), dim=1) ) # update prior - prior = torch.mul(self.gamma - M, prior) + # If gamma is 1 and the element of a sample in M_copy is equal to 1, + # then the prior will be 0 and the corresponding feature will be enforced + # not to be used in all the following decision steps. + prior = torch.mul(self.gamma - M_copy, prior) # output M_feature_level = torch.matmul(M, self.group_attention_matrix) masked_x = torch.mul(M_feature_level, x)