Description
I was training a multiclass model with the focal loss provided by SMP, but it proved to be a training bottleneck: switching from torch's CrossEntropyLoss to SMP's FocalLoss increased training time roughly sixfold.
I suspect this is caused by a suboptimal implementation of the multiclass mode, in which the function iterates over the classes and calls focal_loss_with_logits separately for each one. This can be vectorised by calling focal_loss_with_logits once for all classes, with the target being the one-hot encoding of the segmentation mask.
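As a rough illustration of the proposed change, here is a minimal sketch in plain PyTorch. It does not use SMP's actual code. The helper names (`binary_focal_loss`, `focal_loss_per_class_loop`, `focal_loss_vectorized`) are hypothetical, and the loss is a simplified sigmoid focal loss with only a `gamma` parameter (no `alpha`, `ignore_index`, or normalization). It assumes logits of shape `(N, C, H, W)` and an integer mask of shape `(N, H, W)`. The loop version sums one per-class binary loss per iteration, and the vectorised version computes the same quantity in a single pass over a one-hot target.

```python
import torch
import torch.nn.functional as F


def binary_focal_loss(logits, target, gamma=2.0):
    # Simplified sigmoid focal loss (hypothetical helper), averaged over all elements.
    bce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    p_t = torch.exp(-bce)  # bce = -log(p_t), so p_t = exp(-bce)
    return ((1 - p_t) ** gamma * bce).mean()


def focal_loss_per_class_loop(logits, target, gamma=2.0):
    # Loop-based pattern: one binary focal loss call per class, summed.
    total = logits.new_zeros(())
    for c in range(logits.shape[1]):
        cls_target = (target == c).float()
        total = total + binary_focal_loss(logits[:, c], cls_target, gamma)
    return total


def focal_loss_vectorized(logits, target, gamma=2.0):
    # Single pass over all classes using a one-hot target.
    num_classes = logits.shape[1]
    # F.one_hot needs an int64 mask and returns (N, H, W, C); move classes to dim 1.
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
    bce = F.binary_cross_entropy_with_logits(logits, one_hot, reduction="none")
    p_t = torch.exp(-bce)
    loss = (1 - p_t) ** gamma * bce
    # Mean over batch and spatial dims per class, then sum over classes,
    # which matches summing the per-class means in the loop version.
    return loss.mean(dim=(0, 2, 3)).sum()


torch.manual_seed(0)
logits = torch.randn(2, 4, 8, 8)
target = torch.randint(0, 4, (2, 8, 8))
loop_loss = focal_loss_per_class_loop(logits, target)
vec_loss = focal_loss_vectorized(logits, target)
```

The trade-off is memory: the one-hot target materialises a `(N, C, H, W)` float tensor, which can be sizeable when there are many classes or large masks.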