
Comparing changes

Choose two branches to see what’s changed or to start a new pull request.
base repository: binli123/dsmil-wsi
base: master
head repository: GeorgeBatch/dsmil-wsi-public-fork
compare: simplify-critical-instance-choice
Able to merge. These branches can be automatically merged.
  • 3 commits
  • 1 file changed
  • 2 contributors

Commits on Jun 22, 2023

  1. find indices of maximum class scores instead of sorting (descending) and taking the 0th element
     GeorgeBatch committed Jun 22, 2023 (05dc425)
  2. 0da1493

Commits on Apr 29, 2024

  1. 1ea14cc (Verified: created on GitHub.com and signed with GitHub’s verified signature)
Showing with 2 additions and 2 deletions.
  1. +2 −2 dsmil.py
@@ -49,8 +49,8 @@ def forward(self, feats, c): # N x K, N x C
         Q = self.q(feats).view(feats.shape[0], -1) # N x Q, unsorted

         # handle multiple classes without for loop
-        _, m_indices = torch.sort(c, 0, descending=True) # sort class scores along the instance dimension, m_indices in shape N x C
-        m_feats = torch.index_select(feats, dim=0, index=m_indices[0, :]) # select critical instances, m_feats in shape C x K
+        _, m_indices = torch.max(c, dim=0) # find the index of the maximum class score along the instance dimension, m_indices in shape C
+        m_feats = feats[m_indices, :] # select critical instances, m_feats in shape C x K
         q_max = self.q(m_feats) # compute queries of critical instances, q_max in shape C x Q
         A = torch.mm(Q, q_max.transpose(0, 1)) # compute inner product of Q to each entry of q_max, A in shape N x C, each column contains unnormalized attention scores
         A = F.softmax( A / torch.sqrt(torch.tensor(Q.shape[1], dtype=torch.float32, device=device)), 0) # normalize attention scores, A in shape N x C,
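
For context, a minimal standalone sketch (not part of the PR; the shapes N, K, C and the random tensors are illustrative assumptions) of why the change is behaviour-preserving: the row-0 indices of a descending sort and the indices returned by torch.max select the same critical instances, differing only in tie-breaking when two instances have exactly equal scores, while torch.max avoids the O(N log N) sort.

import torch

# Illustrative shapes (assumed): N instances, K feature dimensions, C classes
N, K, C = 8, 16, 2
feats = torch.randn(N, K)   # N x K instance features
c = torch.randn(N, C)       # N x C per-instance class scores

# Old approach: full descending sort, then row 0 of the index matrix
_, sort_indices = torch.sort(c, 0, descending=True)                        # N x C
m_feats_sort = torch.index_select(feats, dim=0, index=sort_indices[0, :])  # C x K

# New approach: indices of the maximum along the instance dimension
_, max_indices = torch.max(c, dim=0)                                       # shape C
m_feats_max = feats[max_indices, :]                                        # C x K

# With real-valued scores (ties almost surely absent), both pick the same instances
assert torch.equal(m_feats_sort, m_feats_max)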