Skip to content

Commit

Permalink
fix: fix embedding net mistake
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Sep 16, 2024
1 parent 045bf5e commit bcc75db
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions sbi/neural_nets/estimators/categorical_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def __init__(
self._initialize()

Check warning on line 61 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L60-L61

Added lines #L60 - L61 were not covered by tests

def forward(self, inputs, context=None):
embedded_inputs = self.embedding_net.forward(inputs)
return super().forward(embedded_inputs, context=context)
embedded_context = self.embedding_net.forward(context)
return super().forward(inputs, context=embedded_context)

Check warning on line 65 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L64-L65

Added lines #L64 - L65 were not covered by tests

def compute_probs(self, outputs):
ps = F.softmax(outputs, dim=-1) * self.mask
Expand Down

0 comments on commit bcc75db

Please sign in to comment.