Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change of bin loss computation to avoid learning from empty annotations. #1011

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions src/lib/models/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,20 @@ def compute_res_loss(output, target):

# TODO: weight
def compute_bin_loss(output, target, mask):
mask = mask.expand_as(output)
output = output * mask.float()
return F.cross_entropy(output, target, reduction='elementwise_mean')
"""
Compute loss from classifying if angle is in this bin or in the other bin with cross entropy.
Use the prediction if its in this bin (=1) AND if its in the other bin (=0) because bins overlap
and the angle can be in both bins.
Mask predictions by wether or not the annotations even have a gt angle labeled or not.
Don't learn from making a prediction when there is no target.
"""
nonzero_idx = mask.nonzero()[:,0].long()
if nonzero_idx.shape[0] > 0: # if there are any annotations with a labeled angle
output_nz = output.index_select(0, nonzero_idx)
target_nz = target.index_select(0, nonzero_idx)
return F.cross_entropy(output_nz, target_nz, reduction='mean')
else: # loss would be NaN if computed normally when no annotation is given
return torch.tensor(0.0).cuda() # set to different grad_fn but not relevant since loss is zero

def compute_rot_loss(output, target_bin, target_res, mask):
# output: (B, 128, 8) [bin1_cls[0], bin1_cls[1], bin1_sin, bin1_cos,
Expand Down