diff --git a/pytorch_lattice/layers/numerical_calibrator.py b/pytorch_lattice/layers/numerical_calibrator.py index 6a15a6b..3283619 100644 --- a/pytorch_lattice/layers/numerical_calibrator.py +++ b/pytorch_lattice/layers/numerical_calibrator.py @@ -151,16 +151,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: torch.Tensor of shape `(batch_size, 1)` containing calibrated input values. """ if self.input_keypoints_type == InputKeypointsType.LEARNED: - softmaxed_logits = torch.nn.functional.softmax( - self._interpolation_logits, dim=-1 - ) - self._lengths = softmaxed_logits * self._keypoint_range - interior_keypoints = ( - torch.cumsum(self._lengths, dim=-1) + self._keypoint_min - ) - self._interpolation_keypoints = torch.cat( - [torch.tensor([self._keypoint_min]), interior_keypoints[:-1]] - ) + self._calculate_lengths_and_interpolation_keypoints() interpolation_weights = (x - self._interpolation_keypoints) / self._lengths interpolation_weights = torch.minimum(interpolation_weights, torch.tensor(1.0)) @@ -302,6 +293,9 @@ def assert_constraints(self, eps: float = 1e-6) -> list[str]: @torch.no_grad() def keypoints_inputs(self) -> torch.Tensor: """Returns tensor of keypoint inputs.""" + if self.input_keypoints_type == InputKeypointsType.LEARNED: + self._calculate_lengths_and_interpolation_keypoints() + return torch.cat( ( self._interpolation_keypoints, @@ -428,3 +422,21 @@ def _squeeze_by_scaling( if decreasing: bias, heights = -bias, -heights return bias, heights + + def _calculate_lengths_and_interpolation_keypoints(self) -> None: + """Makes necessary updates according to most recent self._interpolation_logits. + + If running the layer with `InputKeyPointType.LEARNED.`, this method will ensure + `self._interpolation_keypoints` and `self._lengths` are correctly updated with + regards to the most recent iteration of `self._interpolation_logits`. + """ + softmaxed_logits = torch.nn.functional.softmax( + self._interpolation_logits, dim=-1 + ) + self._lengths = softmaxed_logits * self._keypoint_range + interior_keypoints = (torch.cumsum(self._lengths, dim=-1) + self._keypoint_min)[ + :-1 + ] + self._interpolation_keypoints = torch.cat( + [torch.tensor([self._keypoint_min]), interior_keypoints] + )