
About loss function #5

Open
wudongming97 opened this issue Jun 28, 2022 · 13 comments

@wudongming97

Hi, I found that the loss used in this repo is a cross-entropy loss between prediction and mask.

```python
loss = F.binary_cross_entropy_with_logits(pred, mask)
```

But the loss mentioned in the paper is a contrastive loss between visual and textual features.
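For what it's worth, the two can coincide: if each pixel's text-pixel dot product is treated as a logit, BCE-with-logits reduces to `-log σ(z_t·z_v)` for pixels inside the mask and `-log(1-σ(z_t·z_v))` outside, which matches the per-pixel contrastive terms as I read Eq. 9-10. A minimal stdlib check (the values are arbitrary):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bce_with_logits(logit, target):
    # naive single-value form of F.binary_cross_entropy_with_logits
    p = sigmoid(logit)
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

# Eq. 9-style terms: a positive pixel (inside the mask, target = 1)
# and a negative pixel (outside the mask, target = 0)
pos = -math.log(sigmoid(2.0))        # -log sigma(z_t . z_v)
neg = -math.log(1 - sigmoid(-1.0))   # -log(1 - sigma(z_t . z_v))

assert abs(pos - bce_with_logits(2.0, 1.0)) < 1e-9
assert abs(neg - bce_with_logits(-1.0, 0.0)) < 1e-9
```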

@Deepayan137

I have the same query. Can the authors please clarify?

@Deepayan137

Deepayan137 commented Jul 9, 2022

Hello! I wrote the contrastive learning part by following the instructions in the paper. However, when training the model only with the contrastive loss, the training IOU doesn't seem to improve. Below, I am attaching the code snippet and the training IOU and precision curves. The training is done only for 1 epoch. The brown plots are for cross-entropy loss while the blue plots are for contrastive loss. I would be grateful if you could let me know what I am doing wrong and also if the contrastive loss is supposed to be used in addition to cross-entropy loss.
Thanks

```python
def forward(self, x, word, mask):
    x = self.vis(x)
    B, C, H, W = x.size()
    word = self.txt(word)
    x = x.permute(0, 2, 3, 1)  # B, H, W, C
    out = torch.einsum('nhwc,nc->nhw', x, word).unsqueeze(1)
    out = torch.sigmoid(out)  # sigmoid of zt dot zv
    loss = torch.zeros((H, W)).cuda()
    pos_count, neg_count = 0, 0
    for i in range(word.size(0)):
        zt = word[i].unsqueeze(0)        # 1, C
        for j in range(x.size(0)):
            # reshape(-1, C).t() keeps channels last, consistent
            # with the permute above, before transposing to C, H*W
            zv = x[j].reshape(-1, C).t()  # C, H*W
            prod = torch.mm(zt, zv).squeeze().reshape(H, W)
            if i == j:
                # positive pair: -log sigmoid(zt . zv)
                loss += -torch.log(torch.sigmoid(prod))
                pos_count += 1
            else:
                # negative pair: -log(1 - sigmoid(zt . zv))
                loss += -torch.log(1 - torch.sigmoid(prod))
                neg_count += 1
    total = pos_count + neg_count
    loss = torch.mean(loss)
    if out.shape[-2:] != mask.shape[-2:]:
        mask = F.interpolate(mask, out.shape[-2:],
                             mode='nearest').detach()
    return out, loss / total, mask
```

[W&B charts, 09/07/2022: training IoU and precision curves]
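As a side note, the double loop above can be collapsed into a single einsum; a vectorized sketch under the same assumptions (the function name is mine, not from the repo), using BCE-with-logits for numerical stability instead of explicit `log`/`sigmoid`:

```python
import torch
import torch.nn.functional as F

def all_pairs_contrastive(word, x):
    # word: (B, C) text embeddings; x: (B, C, H, W) pixel embeddings
    B, C, H, W = x.shape
    # all B*B text/pixel-map pairs at once
    logits = torch.einsum('ic,jchw->ijhw', word, x)            # B, B, H, W
    # i == j pairs are positives, everything else negative
    targets = torch.eye(B).view(B, B, 1, 1).expand(B, B, H, W)
    # reproduces -log sigmoid(.) for positives and -log(1 - sigmoid(.))
    # for negatives, but without overflow for large logits
    return F.binary_cross_entropy_with_logits(logits, targets)
```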

@DerrickWang005
Owner

Please follow our implementation:

```python
class Projector(nn.Module):
    def __init__(self, word_dim=1024, in_dim=256, kernel_size=3):
        super().__init__()
        self.in_dim = in_dim
        self.kernel_size = kernel_size
        # visual projector
        self.vis = nn.Sequential(  # os16 -> os4
            nn.Upsample(scale_factor=2, mode='bilinear'),
            conv_layer(in_dim * 2, in_dim * 2, 3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            conv_layer(in_dim * 2, in_dim, 3, padding=1),
            nn.Conv2d(in_dim, in_dim, 1))
        # textual projector
        out_dim = 1 * in_dim * kernel_size * kernel_size + 1
        self.txt = nn.Linear(word_dim, out_dim)

    def forward(self, x, word):
        '''
        x: b, 512, 26, 26
        word: b, 512
        '''
        x = self.vis(x)
        B, C, H, W = x.size()
        # 1, b*256, 104, 104
        x = x.reshape(1, B * C, H, W)
        # txt: b, (256*3*3 + 1) -> b, 256, 3, 3 / b
        word = self.txt(word)
        weight, bias = word[:, :-1], word[:, -1]
        weight = weight.reshape(B, C, self.kernel_size, self.kernel_size)
        # Conv2d - 1, b*256, 104, 104 -> 1, b, 104, 104
        out = F.conv2d(x,
                       weight,
                       padding=self.kernel_size // 2,
                       groups=weight.size(0),
                       bias=bias)
        out = out.transpose(0, 1)
        # b, 1, 104, 104
        return out
```

@Deepayan137

Hello Derrick,

I had seen this implementation. In your paper, equations 9 and 10 describe the contrastive loss between pixel embeddings and text features. I don't understand how that is handled in the code snippet above.

@tiger990111

I have the same query. Can the authors please clarify?

@FabianRitter

No follow-up? The released code looks like plain supervised learning; I assume something is missing from it.

@Starboy-at-earth

@DerrickWang005 Could you please release the code snippet for the contrastive learning loss?

@clownrat6

The implementation is actually in line with the paper's description. However, it is not standard contrastive learning.

@Fake10086

You may take a closer look at the code posted by the author above: the `conv2d` effectively acts as a per-pixel product between the text and image features, which can be seen as equations 9 and 10.
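To illustrate that point, here is a minimal check (shapes and variable names are hypothetical; `kernel_size=1` for clarity) that a grouped `conv2d` with a text-derived kernel reduces to the dot product between the text vector and every pixel embedding, i.e. the `sim(z_t, z_v)` term:

```python
import torch
import torch.nn.functional as F

B, C, H, W = 2, 8, 5, 5
x = torch.randn(B, C, H, W)   # pixel embeddings z_v
word = torch.randn(B, C)      # text embeddings z_t

# grouped conv, as in the Projector: one group per sample
out_conv = F.conv2d(x.reshape(1, B * C, H, W),
                    word.reshape(B, C, 1, 1),
                    groups=B).squeeze(0)          # B, H, W

# explicit dot product at every spatial location
out_dot = torch.einsum('bchw,bc->bhw', x, word)

assert torch.allclose(out_conv, out_dot, atol=1e-5)
```

With `kernel_size=3` the same conv additionally mixes a 3x3 neighborhood around each pixel, but the per-sample grouping is identical.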

@lyu-yx

lyu-yx commented Oct 19, 2023

I have the same question. Could the authors release the latest version of the code? @DerrickWang005

@DerrickWang005
Owner

I think this article can answer your question to some extent. @lyu-yx
https://arxiv.org/pdf/2303.15343.pdf
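For context, the linked paper (SigLIP) replaces the softmax-based InfoNCE objective with a pairwise sigmoid loss over all image-text pairs in the batch. A rough sketch of my reading of it (not the authors' code; `t` and `b` stand for the learnable temperature and bias):

```python
import torch
import torch.nn.functional as F

def sigmoid_pairwise_loss(img_emb, txt_emb, t=10.0, b=-10.0):
    # img_emb, txt_emb: (B, D), assumed L2-normalized
    logits = img_emb @ txt_emb.t() * t + b         # B x B similarities
    labels = 2 * torch.eye(logits.size(0)) - 1     # +1 on diagonal, -1 off
    # independent binary decision per pair, no batch-wide softmax
    return -F.logsigmoid(labels * logits).mean()
```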

@ccccchenllll

I have the same question. I couldn't find the code about contrastive loss.

@Shaosifan

> I have the same question. I couldn't find the code about contrastive loss.

Me too...
