Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Temporal heatmap tests + small config fixes #71

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions scripts/config_default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,26 @@ losses:
# nan removal value.
# (in prob; heatmaps with max prob values are removed)
prob_threshold: 0.05
# loss = kl loss between heatmaps in successive timepoints
temporal_heatmap_kl:
# weight in front of temporal heatmap loss
log_weight: 7.0
# for epsilon insensitive rectification
# (in kl loss values; diffs below this are not penalized)
epsilon: 5.0
# nan removal threshold value.
# (in prob; heatmaps with confidences below this value are removed)
prob_threshold: 0.05
# loss = mse between heatmaps in successive timepoints
temporal_heatmap_mse:
# weight in front of temporal heatmap loss
log_weight: 5.0
# for epsilon insensitive rectification
# (in mse values; diffs below this are not penalized)
epsilon: 20.0
# nan removal threshold value.
# (in prob; heatmaps with confidences below this value are removed)
prob_threshold: 0.05
# loss = mse between model heatmap and idealized gaussian heatmap centered on softargmax
unimodal_mse:
# weight in front of unimodal_mse loss
Expand Down
1 change: 1 addition & 0 deletions scripts/configs_mirror-mouse/config_mirror-mouse.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ losses:
epsilon: 20.0
# nan removal threshold value.
# (in prob; heatmaps with confidences below this value are removed)
prob_threshold: 0.05
# loss = mse between unimodal heatmap and predicted heatmap
unimodal_mse:
# weight in front of unimodal loss
Expand Down
91 changes: 91 additions & 0 deletions tests/losses/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,97 @@ def test_temporal_loss_multi_epsilon_rectification():
assert torch.allclose(rectified[1, :], torch.tensor([0.0, 0.0, 0.0]))


def test_temporal_heatmap_loss():

from lightning_pose.losses.losses import TemporalHeatmapLoss
from lightning_pose.data.utils import generate_heatmaps
from kornia.losses import kl_div_loss_2d

temporal_heatmap_kl_loss = TemporalHeatmapLoss(loss_name="temporal_heatmap_kl", epsilon=0.0)

# make sure zero is returned for constant predictions (along dim 0)
m = 100 # max pixel
predicted_keypoints = torch.ones(size=(12, 16, 2), device=device)
predicted_keypoints[:, 1] = 2
predicted_keypoints[:, 2] = 4
predicted_keypoints[:, 3] = 8
predicted_heatmaps = generate_heatmaps(predicted_keypoints, height=m, width=m, output_shape=(64, 64))
confidences = torch.ones(size=(12, 16), device=device)
loss, logs = temporal_heatmap_kl_loss(predicted_heatmaps, confidences=confidences, stage=stage)
assert loss.shape == torch.Size([])
assert loss == 0.0
assert logs[0]["name"] == "%s_temporal_heatmap_kl_loss" % stage
assert logs[0]["value"] == loss / temporal_heatmap_kl_loss.weight
assert logs[1]["name"] == "temporal_heatmap_kl_weight"
assert logs[1]["value"] == temporal_heatmap_kl_loss.weight

# make sure non-negative scalar is returned
predicted_keypoints = torch.rand(size=(12, 16, 2), device=device)
predicted_heatmaps = generate_heatmaps(predicted_keypoints, height=m, width=m, output_shape=(64, 64))
confidences = torch.ones(size=(12, 16), device=device)
loss, logs = temporal_heatmap_kl_loss(predicted_heatmaps, confidences=confidences, stage=stage)
assert loss.shape == torch.Size([])
assert loss > 0.0

# check against actual kl value
predicted_keypoints = torch.Tensor([[[0.0, 0.0]], [[2.0, 2.0]]])
predicted_heatmaps = generate_heatmaps(predicted_keypoints, height=m, width=m, output_shape=(64, 64))
confidences = torch.ones(size=(2, 1), device=device)
loss, logs = temporal_heatmap_kl_loss(predicted_heatmaps, confidences=confidences, stage=stage)
kl = kl_div_loss_2d(predicted_heatmaps[0].unsqueeze(0) + 1e-10,
predicted_heatmaps[1].unsqueeze(0) + 1e-10,
reduction="none")
assert loss.item() - kl < 1e-6

# check higher heatmap overlap has lower loss than lower heatmap overlap
predicted_keypoints_close = torch.Tensor(
[[[0.0, 0.0]], [[1.0, 1.0]]], device=device
)
predicted_keypoints_far = torch.Tensor(
[[[0.0, 0.0]], [[2.0, 2.0]]], device=device
)
predicted_heatmaps_close = generate_heatmaps(predicted_keypoints_close, height=m, width=m, output_shape=(64, 64))
predicted_heatmaps_far = generate_heatmaps(predicted_keypoints_far, height=m, width=m, output_shape=(64, 64))
confidences = torch.ones(size=(2, 1), device=device)
loss_close, _ = temporal_heatmap_kl_loss(predicted_heatmaps_close, confidences=confidences, stage=stage)
loss_far, _ = temporal_heatmap_kl_loss(predicted_heatmaps_far, confidences=confidences, stage=stage)
assert loss_close.item() < loss_far.item()

# test epsilon
s2 = 1.0
s3 = 4.0
predicted_keypoints = torch.Tensor(
[[[0.0, 0.0]], [[s2, s2]], [[s3 + s2, s3 + s2]]], device=device
)
predicted_heatmaps = generate_heatmaps(predicted_keypoints, height=m, width=m, output_shape=(64, 64))
confidences = torch.ones(size=(3, 1), device=device)
# [s2, s2] -> ~ 0.3816
# [s3, s3] -> ~ 9.9069
loss, logs = temporal_heatmap_kl_loss(predicted_heatmaps, confidences=confidences, stage=stage)
kl_s2 = kl_div_loss_2d(predicted_heatmaps[0].unsqueeze(0) + 1e-10,
predicted_heatmaps[1].unsqueeze(0) + 1e-10,
reduction="none")
kl_s3 = kl_div_loss_2d(predicted_heatmaps[1].unsqueeze(0) + 1e-10,
predicted_heatmaps[2].unsqueeze(0) + 1e-10,
reduction="none")
assert loss.item() - (kl_s2 + kl_s3) < 1e-6

temporal_heatmap_kl_loss = TemporalHeatmapLoss(loss_name="temporal_heatmap_kl", epsilon=3.0)
loss, logs = temporal_heatmap_kl_loss(predicted_heatmaps, confidences=confidences, stage=stage)
# due to epsilon the "s2" entry will be zeroed out
assert (loss.item() - kl_s3) < 1e-6

# check confidence masking working properly
predicted_keypoints = torch.Tensor(
[[[0.0, 0.0]], [[2.0, 2.0]]], device=device
)
predicted_heatmaps = generate_heatmaps(predicted_keypoints, height=m, width=m, output_shape=(64, 64))
confidences = torch.ones(size=(2, 1), device=device)
confidences[1, 0] = 0.0
loss, logs = temporal_heatmap_kl_loss(predicted_heatmaps, confidences=confidences, stage=stage, prob_threshold=0.05)
assert loss.item() == 0


def test_unimodal_mse_loss():

from lightning_pose.losses.losses import UnimodalLoss
Expand Down