diff --git a/pytorch_tabnet/tab_network.py b/pytorch_tabnet/tab_network.py
index 95c2bae2..605d863b 100644
--- a/pytorch_tabnet/tab_network.py
+++ b/pytorch_tabnet/tab_network.py
@@ -170,11 +170,19 @@ def forward(self, x, prior=None):
         steps_output = []
         for step in range(self.n_steps):
             M = self.att_transformers[step](prior, att)
+            # Binary copy of M.
+            M_copy = M.clone()
+            # Set every positive element of M_copy to 1.
+            mask = M_copy > 0
+            M_copy[mask] = 1
             M_loss += torch.mean(
                 torch.sum(torch.mul(M, torch.log(M + self.epsilon)), dim=1)
             )
             # update prior
-            prior = torch.mul(self.gamma - M, prior)
+            # If gamma is 1 and an element of M_copy is 1, the corresponding
+            # prior entry becomes 0, so that feature can no longer be used in
+            # any of the following decision steps.
+            prior = torch.mul(self.gamma - M_copy, prior)
             # output
             M_feature_level = torch.matmul(M, self.group_attention_matrix)
             masked_x = torch.mul(M_feature_level, x)
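
To illustrate what this change does (this snippet is not part of the patch, just a minimal standalone sketch with made-up values), compare the existing soft prior update `gamma - M` with the proposed hard update `gamma - M_copy` for a single sample with four features and `gamma = 1`:

```python
import torch

gamma = 1.0
prior = torch.ones(1, 4)
# Hypothetical attention mask produced by one decision step.
M = torch.tensor([[0.1, 0.0, 0.9, 0.0]])

# Existing update: a feature selected with a small weight keeps most of its prior.
prior_soft = (gamma - M) * prior            # tensor([[0.9, 1.0, 0.1, 1.0]])

# Proposed update: any feature used at all (M > 0) gets prior 0 when gamma == 1,
# so it can never be selected again in later decision steps.
M_copy = (M > 0).float()                    # equivalent to the clone-and-index code in the diff
prior_hard = (gamma - M_copy) * prior       # tensor([[0., 1., 0., 1.]])

print(prior_soft)
print(prior_hard)
```

In other words, the patch turns the prior into a hard "each feature may be selected in at most one step" constraint when `gamma` is 1, instead of only discounting features in proportion to how strongly they were attended to.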