diff --git a/scripts/config_default.yaml b/scripts/config_default.yaml index 8c79934f..aea068d8 100644 --- a/scripts/config_default.yaml +++ b/scripts/config_default.yaml @@ -173,6 +173,26 @@ losses: # nan removal value. # (in prob; heatmaps with max prob values are removed) prob_threshold: 0.05 + # loss = kl loss between heatmaps in successive timepoints + temporal_heatmap_kl: + # weight in front of temporal heatmap loss + log_weight: 7.0 + # for epsilon insensitive rectification + # (in kl loss values; diffs below this are not penalized) + epsilon: 5.0 + # nan removal threshold value. + # (in prob; heatmaps with confidences below this value are removed) + prob_threshold: 0.05 + # loss = mse between heatmaps in successive timepoints + temporal_heatmap_mse: + # weight in front of temporal heatmap loss + log_weight: 5.0 + # for epsilon insensitive rectification + # (in mse values; diffs below this are not penalized) + epsilon: 20.0 + # nan removal threshold value. + # (in prob; heatmaps with confidences below this value are removed) + prob_threshold: 0.05 # loss = mse between model heatmap and idealized gaussian heatmap centered on softargmax unimodal_mse: # weight in front of unimodal_mse loss diff --git a/scripts/configs_mirror-mouse/config_mirror-mouse.yaml b/scripts/configs_mirror-mouse/config_mirror-mouse.yaml index 411c1587..6aaddab1 100644 --- a/scripts/configs_mirror-mouse/config_mirror-mouse.yaml +++ b/scripts/configs_mirror-mouse/config_mirror-mouse.yaml @@ -194,6 +194,7 @@ losses: epsilon: 20.0 # nan removal threshold value. # (in prob; heatmaps with confidences below this value are removed) + prob_threshold: 0.05 # loss = mse between unimodal heatmap and predicted heatmap unimodal_mse: # weight in front of unimodal loss diff --git a/tests/losses/test_losses.py b/tests/losses/test_losses.py index e943ab95..1177b784 100644 --- a/tests/losses/test_losses.py +++ b/tests/losses/test_losses.py @@ -337,6 +337,97 @@ def test_temporal_loss_multi_epsilon_rectification(): assert torch.allclose(rectified[1, :], torch.tensor([0.0, 0.0, 0.0])) +def test_temporal_heatmap_loss(): + + from lightning_pose.losses.losses import TemporalHeatmapLoss + from lightning_pose.data.utils import generate_heatmaps + from kornia.losses import kl_div_loss_2d + + temporal_heatmap_kl_loss = TemporalHeatmapLoss(loss_name="temporal_heatmap_kl", epsilon=0.0) + + # make sure zero is returned for constant predictions (along dim 0) + m = 100 # max pixel + predicted_keypoints = torch.ones(size=(12, 16, 2), device=device) + predicted_keypoints[:, 1] = 2 + predicted_keypoints[:, 2] = 4 + predicted_keypoints[:, 3] = 8 + predicted_heatmaps = generate_heatmaps(predicted_keypoints, height=m, width=m, output_shape=(64, 64)) + confidences = torch.ones(size=(12, 16), device=device) + loss, logs = temporal_heatmap_kl_loss(predicted_heatmaps, confidences=confidences, stage=stage) + assert loss.shape == torch.Size([]) + assert loss == 0.0 + assert logs[0]["name"] == "%s_temporal_heatmap_kl_loss" % stage + assert logs[0]["value"] == loss / temporal_heatmap_kl_loss.weight + assert logs[1]["name"] == "temporal_heatmap_kl_weight" + assert logs[1]["value"] == temporal_heatmap_kl_loss.weight + + # make sure non-negative scalar is returned + predicted_keypoints = torch.rand(size=(12, 16, 2), device=device) + predicted_heatmaps = generate_heatmaps(predicted_keypoints, height=m, width=m, output_shape=(64, 64)) + confidences = torch.ones(size=(12, 16), device=device) + loss, logs = temporal_heatmap_kl_loss(predicted_heatmaps, confidences=confidences, stage=stage) + assert loss.shape == torch.Size([]) + assert loss > 0.0 + + # check against actual kl value + predicted_keypoints = torch.Tensor([[[0.0, 0.0]], [[2.0, 2.0]]]) + predicted_heatmaps = generate_heatmaps(predicted_keypoints, height=m, width=m, output_shape=(64, 64)) + confidences = torch.ones(size=(2, 1), device=device) + loss, logs = temporal_heatmap_kl_loss(predicted_heatmaps, confidences=confidences, stage=stage) + kl = kl_div_loss_2d(predicted_heatmaps[0].unsqueeze(0) + 1e-10, + predicted_heatmaps[1].unsqueeze(0) + 1e-10, + reduction="none") + assert loss.item() - kl < 1e-6 + + # check higher heatmap overlap has lower loss than lower heatmap overlap + predicted_keypoints_close = torch.Tensor( + [[[0.0, 0.0]], [[1.0, 1.0]]], device=device + ) + predicted_keypoints_far = torch.Tensor( + [[[0.0, 0.0]], [[2.0, 2.0]]], device=device + ) + predicted_heatmaps_close = generate_heatmaps(predicted_keypoints_close, height=m, width=m, output_shape=(64, 64)) + predicted_heatmaps_far = generate_heatmaps(predicted_keypoints_far, height=m, width=m, output_shape=(64, 64)) + confidences = torch.ones(size=(2, 1), device=device) + loss_close, _ = temporal_heatmap_kl_loss(predicted_heatmaps_close, confidences=confidences, stage=stage) + loss_far, _ = temporal_heatmap_kl_loss(predicted_heatmaps_far, confidences=confidences, stage=stage) + assert loss_close.item() < loss_far.item() + + # test epsilon + s2 = 1.0 + s3 = 4.0 + predicted_keypoints = torch.Tensor( + [[[0.0, 0.0]], [[s2, s2]], [[s3 + s2, s3 + s2]]], device=device + ) + predicted_heatmaps = generate_heatmaps(predicted_keypoints, height=m, width=m, output_shape=(64, 64)) + confidences = torch.ones(size=(3, 1), device=device) + # [s2, s2] -> ~ 0.3816 + # [s3, s3] -> ~ 9.9069 + loss, logs = temporal_heatmap_kl_loss(predicted_heatmaps, confidences=confidences, stage=stage) + kl_s2 = kl_div_loss_2d(predicted_heatmaps[0].unsqueeze(0) + 1e-10, + predicted_heatmaps[1].unsqueeze(0) + 1e-10, + reduction="none") + kl_s3 = kl_div_loss_2d(predicted_heatmaps[1].unsqueeze(0) + 1e-10, + predicted_heatmaps[2].unsqueeze(0) + 1e-10, + reduction="none") + assert loss.item() - (kl_s2 + kl_s3) < 1e-6 + + temporal_heatmap_kl_loss = TemporalHeatmapLoss(loss_name="temporal_heatmap_kl", epsilon=3.0) + loss, logs = temporal_heatmap_kl_loss(predicted_heatmaps, confidences=confidences, stage=stage) + # due to epsilon the "s2" entry will be zeroed out + assert (loss.item() - kl_s3) < 1e-6 + + # check confidence masking working properly + predicted_keypoints = torch.Tensor( + [[[0.0, 0.0]], [[2.0, 2.0]]], device=device + ) + predicted_heatmaps = generate_heatmaps(predicted_keypoints, height=m, width=m, output_shape=(64, 64)) + confidences = torch.ones(size=(2, 1), device=device) + confidences[1, 0] = 0.0 + loss, logs = temporal_heatmap_kl_loss(predicted_heatmaps, confidences=confidences, stage=stage, prob_threshold=0.05) + assert loss.item() == 0 + + def test_unimodal_mse_loss(): from lightning_pose.losses.losses import UnimodalLoss