Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: fixed keypoints inputs not updating to most recent logits #13

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions pytorch_lattice/layers/numerical_calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
torch.Tensor of shape `(batch_size, 1)` containing calibrated input values.
"""
if self.input_keypoints_type == InputKeypointsType.LEARNED:
softmaxed_logits = torch.nn.functional.softmax(
self._interpolation_logits, dim=-1
)
self._lengths = softmaxed_logits * self._keypoint_range
interior_keypoints = (
torch.cumsum(self._lengths, dim=-1) + self._keypoint_min
)
self._interpolation_keypoints = torch.cat(
[torch.tensor([self._keypoint_min]), interior_keypoints[:-1]]
)
self._calculate_lengths_and_interpolation_keypoints()

interpolation_weights = (x - self._interpolation_keypoints) / self._lengths
interpolation_weights = torch.minimum(interpolation_weights, torch.tensor(1.0))
Expand Down Expand Up @@ -302,6 +293,9 @@ def assert_constraints(self, eps: float = 1e-6) -> list[str]:
@torch.no_grad()
def keypoints_inputs(self) -> torch.Tensor:
"""Returns tensor of keypoint inputs."""
if self.input_keypoints_type == InputKeypointsType.LEARNED:
self._calculate_lengths_and_interpolation_keypoints()

return torch.cat(
(
self._interpolation_keypoints,
Expand Down Expand Up @@ -428,3 +422,21 @@ def _squeeze_by_scaling(
if decreasing:
bias, heights = -bias, -heights
return bias, heights

def _calculate_lengths_and_interpolation_keypoints(self) -> None:
"""Makes necessary updates according to most recent self._interpolation_logits.

If running the layer with `InputKeyPointType.LEARNED.`, this method will ensure
`self._interpolation_keypoints` and `self._lengths` are correctly updated with
regards to the most recent iteration of `self._interpolation_logits`.
"""
softmaxed_logits = torch.nn.functional.softmax(
self._interpolation_logits, dim=-1
)
self._lengths = softmaxed_logits * self._keypoint_range
interior_keypoints = (torch.cumsum(self._lengths, dim=-1) + self._keypoint_min)[
:-1
]
self._interpolation_keypoints = torch.cat(
[torch.tensor([self._keypoint_min]), interior_keypoints]
)