diff --git a/datkit/_points.py b/datkit/_points.py index a25c7ed..4c1436d 100644 --- a/datkit/_points.py +++ b/datkit/_points.py @@ -6,6 +6,8 @@ # import numpy as np +import datkit + def index(times, t, ttol=1e-9): """ @@ -47,22 +49,21 @@ def index_near(times, t): """ # Check t is within range if t < times[0]: - dt = sampling_interval(times) - if 2 * (t - times[0]) < dt: + dt = datkit.sampling_interval(times) + if 2 * (times[0] - t) < dt: return 0 raise ValueError( f'Time t is too far outside the provided range: {t} < {times[0]}') elif t > times[-1]: - dt = sampling_interval(times) - if 2 * (times[-1] - t) < dt: + dt = datkit.sampling_interval(times) + if 2 * (t - times[-1]) < dt: return len(times) - 1 raise ValueError( f'Time t is too far outside the provided range: {t} > {times[-1]}') # Find index and return - if lpad is None and rpad is None: - i = np.searchsorted(times, t) # times[i - 1] < t <= times[i] - return i if i == 0 or times[i] - t < t - times[i - 1] else i - 1 + i = np.searchsorted(times, t) # times[i - 1] < t <= times[i] + return i if i == 0 or times[i] - t < t - times[i - 1] else i - 1 def index_on(times, t0, t1, include_left=True, include_right=False): @@ -76,10 +77,17 @@ def index_on(times, t0, t1, include_left=True, include_right=False): If any of the points are outside of range, the interval returned will be smaller or even empty. """ - if t1 <= t0: - raise ValueError('Time t1 must be greater than t0.') + if len(times) < 1: + raise ValueError('Times must contain at least one value.') + if t1 < t0: + raise ValueError('Time t1 must be greater than or equal to t0.') i = np.searchsorted(times, t0) - j = np.searchsorted(times, t1, side='right') + j = np.searchsorted(times, t1) + if (not include_left) and i < len(times) and times[i] == t0: + i += 1 + if include_right and j < len(times) and times[j] == t1: + j += 1 + return i, j def mean_on(times, values, t0, t1, include_left=True, include_right=False): @@ -89,7 +97,7 @@ def mean_on(times, values, t0, t1, include_left=True, include_right=False): By default, the interval is taken as ``t0 <= times < t1``, but this can be customized using ``include_left`` and ``include_right``. """ - i, j = index_interval(times, t0, t1, include_left, include_right) + i, j = index_on(times, t0, t1, include_left, include_right) return np.mean(values[i:j]) diff --git a/datkit/tests/test_points.py b/datkit/tests/test_points.py index 925db3a..64b7446 100755 --- a/datkit/tests/test_points.py +++ b/datkit/tests/test_points.py @@ -77,6 +77,110 @@ def test_index(self): self.assertEqual(d.index(times, 7.3 + 9e-10), 49) self.assertRaisesRegex(ValueError, 'range', d.index, times, 7.3 + 2e-9) + def test_index_near(self): + + # Exact matches + times = np.arange(0, 10) + self.assertEqual(d.index_near(times, 0), 0) + self.assertEqual(d.index_near(times, 4), 4) + self.assertEqual(d.index_near(times, 9), 9) + + # Near matches + self.assertEqual(d.index_near(times, 0.1), 0) + self.assertEqual(d.index_near(times, 2.1), 2) + self.assertEqual(d.index_near(times, 3.5), 3) + + # Outside of range + times = np.arange(0, 20) / 2 + self.assertEqual(d.index_near(times, -0.1), 0) + self.assertEqual(d.index_near(times, -0.24999), 0) + self.assertRaisesRegex( + ValueError, 'range', d.index_near, times, -0.251) + self.assertEqual(d.index_near(times, 9.6), 19) + self.assertEqual(d.index_near(times, 9.7499), 19) + self.assertRaisesRegex(ValueError, 'range', d.index_near, times, 9.751) + + def test_index_on(self): + t = np.arange(0, 10) + self.assertEqual(d.index_on(t, 2, 4), (2, 4)) + self.assertEqual(d.index_on(t, 2, 4.1), (2, 5)) + self.assertEqual(d.index_on(t, 0.1, 5), (1, 5)) + self.assertEqual(d.index_on(t, -5, 4), (0, 4)) + + self.assertEqual(d.index_on(t, 2, 4, include_left=False), (3, 4)) + self.assertEqual(d.index_on(t, -5, 4, include_left=False), (0, 4)) + self.assertEqual(d.index_on(t, 0, 4, include_left=False), (1, 4)) + self.assertEqual(d.index_on(t, 2, 4, include_right=True), (2, 5)) + self.assertEqual(d.index_on(t, 2, 3.9, include_right=True), (2, 4)) + self.assertEqual(d.index_on(t, 2, 100), (2, 10)) + self.assertEqual(d.index_on(t, 2, 100, include_right=True), (2, 10)) + + self.assertEqual(d.index_on(t, 3, 8, True, True), (3, 9)) + self.assertEqual(d.index_on(t, 3, 8, False, True), (4, 9)) + self.assertEqual(d.index_on(t, 3, 8, True, False), (3, 8)) + self.assertEqual(d.index_on(t, 3, 8, False, False), (4, 8)) + + self.assertEqual(d.index_on(t, -3, 88, True, True), (0, 10)) + self.assertEqual(d.index_on(t, -3, 88, False, True), (0, 10)) + self.assertEqual(d.index_on(t, -3, 88, True, False), (0, 10)) + self.assertEqual(d.index_on(t, -3, 88, False, False), (0, 10)) + + self.assertEqual(d.index_on(t, -9, -3, True, True), (0, 0)) + self.assertEqual(d.index_on(t, -9, -3, False, True), (0, 0)) + self.assertEqual(d.index_on(t, -9, -3, True, False), (0, 0)) + self.assertEqual(d.index_on(t, -9, -3, False, False), (0, 0)) + + self.assertEqual(d.index_on(t, 12, 18, True, True), (10, 10)) + self.assertEqual(d.index_on(t, 12, 18, False, True), (10, 10)) + self.assertEqual(d.index_on(t, 12, 18, True, False), (10, 10)) + self.assertEqual(d.index_on(t, 12, 18, False, False), (10, 10)) + + self.assertEqual(d.index_on(t, 0, 0), (0, 0)) + self.assertEqual(d.index_on(t, 4, 4), (4, 4)) + self.assertEqual(d.index_on(t, -4, -4), (0, 0)) + self.assertEqual(d.index_on(t, 10, 10), (10, 10)) + self.assertEqual(d.index_on(t, 12, 12), (10, 10)) + + self.assertRaisesRegex( + ValueError, 'at least one', d.index_on, [], 2, 4) + self.assertEqual(d.index_on([3], 2, 4), (0, 1)) + self.assertEqual(d.index_on([3], 2, 3), (0, 0)) + self.assertRaisesRegex(ValueError, 'greater than', d.index_on, t, 3, 2) + + t = np.arange(4, 40, 2) + self.assertEqual(d.index_on(t, 8, 16), (2, 6)) + t = np.arange(-6, 18, 3) + self.assertEqual(d.index_on(t, -3, 9), (1, 5)) + + def test_mean_on(self): + t = np.arange(1, 11) + print(t) + self.assertEqual(d.mean_on(t, t, 1, 11), 5.5) + self.assertEqual(d.mean_on(t, t, 4, 8), 5.5) + self.assertEqual(d.mean_on(t, t, 4, 8, False), 6) + self.assertEqual(d.mean_on(t, t, 4, 8, True, True), 6) + + def test_value_at(self): + t = np.arange(0, 10) + self.assertEqual(d.value_at(t, t, 0), 0) + self.assertEqual(d.value_at(t, t, 5), 5) + self.assertEqual(d.value_at(t, t, 9), 9) + v = 20 + 2 * t + self.assertEqual(d.value_at(t, v, 0), 20) + self.assertEqual(d.value_at(t, v, 5), 30) + + def test_value_near(self): + t = np.arange(0, 10) + self.assertEqual(d.value_near(t, t, 0), 0) + self.assertEqual(d.value_near(t, t, 5), 5) + self.assertEqual(d.value_near(t, t, 9), 9) + self.assertEqual(d.value_near(t, t, 0.1), 0) + self.assertEqual(d.value_near(t, t, 5.7), 6) + self.assertEqual(d.value_near(t, t, 8.9), 9) + v = 20 + 2 * t + self.assertEqual(d.value_at(t, v, 0), 20) + self.assertEqual(d.value_at(t, v, 5), 30) + if __name__ == '__main__': unittest.main()