diff --git a/pytorch_lattice/layers/numerical_calibrator.py b/pytorch_lattice/layers/numerical_calibrator.py
index 6a15a6b..3283619 100644
--- a/pytorch_lattice/layers/numerical_calibrator.py
+++ b/pytorch_lattice/layers/numerical_calibrator.py
@@ -151,16 +151,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             torch.Tensor of shape `(batch_size, 1)` containing calibrated input values.
         """
         if self.input_keypoints_type == InputKeypointsType.LEARNED:
-            softmaxed_logits = torch.nn.functional.softmax(
-                self._interpolation_logits, dim=-1
-            )
-            self._lengths = softmaxed_logits * self._keypoint_range
-            interior_keypoints = (
-                torch.cumsum(self._lengths, dim=-1) + self._keypoint_min
-            )
-            self._interpolation_keypoints = torch.cat(
-                [torch.tensor([self._keypoint_min]), interior_keypoints[:-1]]
-            )
+            self._calculate_lengths_and_interpolation_keypoints()
 
         interpolation_weights = (x - self._interpolation_keypoints) / self._lengths
         interpolation_weights = torch.minimum(interpolation_weights, torch.tensor(1.0))
@@ -302,6 +293,9 @@ def assert_constraints(self, eps: float = 1e-6) -> list[str]:
     @torch.no_grad()
     def keypoints_inputs(self) -> torch.Tensor:
         """Returns tensor of keypoint inputs."""
+        if self.input_keypoints_type == InputKeypointsType.LEARNED:
+            self._calculate_lengths_and_interpolation_keypoints()
+
         return torch.cat(
             (
                 self._interpolation_keypoints,
@@ -428,3 +422,21 @@ def _squeeze_by_scaling(
         if decreasing:
             bias, heights = -bias, -heights
         return bias, heights
+
+    def _calculate_lengths_and_interpolation_keypoints(self) -> None:
+        """Makes necessary updates according to most recent self._interpolation_logits.
+
+        If running the layer with `InputKeyPointType.LEARNED.`, this method will ensure
+        `self._interpolation_keypoints` and `self._lengths` are correctly updated with
+        regards to the most recent iteration of `self._interpolation_logits`.
+        """
+        softmaxed_logits = torch.nn.functional.softmax(
+            self._interpolation_logits, dim=-1
+        )
+        self._lengths = softmaxed_logits * self._keypoint_range
+        interior_keypoints = (torch.cumsum(self._lengths, dim=-1) + self._keypoint_min)[
+            :-1
+        ]
+        self._interpolation_keypoints = torch.cat(
+            [torch.tensor([self._keypoint_min]), interior_keypoints]
+        )