
objectmask loss throws error for VPGNet_v2 model #6

Open
sandeepnmenon opened this issue Mar 10, 2021 · 1 comment

Comments

@sandeepnmenon

The target and output tensors for the objectmask task both have the shape [batch_size, 2, 96, 128]. With these shapes, the CrossEntropy loss calculation fails:

```
loss = loss_fn(pred, gt[:,:,:,:])
```

It raises the following error:

```
    loss = loss_fn(pred, gt[:,:,:,:])
  File "/home/deepenai/SandeepMenon/venv-e3d/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/deepenai/SandeepMenon/venv-e3d/lib/python3.6/site-packages/torch/nn/modules/loss.py", line 962, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "/home/deepenai/SandeepMenon/venv-e3d/lib/python3.6/site-packages/torch/nn/functional.py", line 2468, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/home/deepenai/SandeepMenon/venv-e3d/lib/python3.6/site-packages/torch/nn/functional.py", line 2266, in nll_loss
    ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: : [4, 2, 96, 128]
```
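A minimal sketch that reproduces the failure, under stated assumptions: random logits stand in for the VPGNet_v2 output, and gt is assumed to be a one-hot tensor built with torch.nn.functional.one_hot (the issue does not say how the objectmask target is encoded). The exact message may depend on the PyTorch version: older releases report the "only batches of spatial targets supported" error above, while newer releases, which accept class-probability targets, reject an integer target of the input's shape with a different RuntimeError.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Shapes from the issue: batch of 4, 2 classes, 96x128 masks.
batch_size, num_classes, height, width = 4, 2, 96, 128

loss_fn = nn.CrossEntropyLoss()
pred = torch.randn(batch_size, num_classes, height, width)  # stand-in logits

# Assumption: gt is one-hot along dim 1, i.e. the same 4-D shape as pred.
labels = torch.randint(0, num_classes, (batch_size, height, width))
gt = F.one_hot(labels, num_classes).permute(0, 3, 1, 2)  # [B, C, H, W], int64

error = None
try:
    loss_fn(pred, gt[:, :, :, :])
except RuntimeError as exc:
    error = str(exc)  # message text varies between PyTorch versions

print(error)
```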
@sandeepnmenon (Author)

According to this discussion (https://discuss.pytorch.org/t/runtimeerror-1only-batches-of-spatial-targets-supported-3d-tensors-but-got-targets-of-size-1-3-96-128/95030), the target tensor for multi-class segmentation with nn.CrossEntropyLoss or nn.NLLLoss should have the shape [batch_size, height, width] and contain class indices in the range [0, nb_classes-1]. The [batch_size, 2, 96, 128] objectmask target therefore has to be converted to a [batch_size, 96, 128] tensor of class indices before it is passed to the loss.
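A minimal sketch of that conversion, assuming gt is one-hot encoded along the channel dimension (the issue does not confirm the encoding; if the two channels hold something else, such as a single binary mask, the conversion would differ). Taking argmax over dim 1 yields the [batch_size, height, width] long tensor of class indices that nn.CrossEntropyLoss expects. The tensors here are random stand-ins, not VPGNet_v2 outputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, num_classes, height, width = 4, 2, 96, 128

loss_fn = nn.CrossEntropyLoss()
pred = torch.randn(batch_size, num_classes, height, width)  # logits, [B, C, H, W]

# Assumption: gt is one-hot along dim 1, shape [B, C, H, W].
labels = torch.randint(0, num_classes, (batch_size, height, width))
gt = F.one_hot(labels, num_classes).permute(0, 3, 1, 2)

# Collapse the class dimension into indices: [B, H, W], int64, values in [0, C-1].
target = gt.argmax(dim=1)

loss = loss_fn(pred, target)  # scalar with the default reduction="mean"
print(target.shape, target.dtype, loss.item())
```

As an alternative on newer PyTorch releases, nn.CrossEntropyLoss also accepts floating-point class-probability targets with the same shape as the input, so loss_fn(pred, gt.float()) would work there; on older releases the class-index form above is required.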
