diff --git a/src/tests/test_datetime.py b/src/tests/test_datetime.py index d7c4fb9cd..3662b112a 100644 --- a/src/tests/test_datetime.py +++ b/src/tests/test_datetime.py @@ -120,7 +120,7 @@ text_clean, zoned_datetimes, ) -from utilities.math import is_integral +from utilities.math import MAX_INT32, MIN_INT32, is_integral, round_to_float from utilities.zoneinfo import UTC, HongKong, Tokyo if TYPE_CHECKING: @@ -129,6 +129,7 @@ from utilities.types import Number +@mark.skip class TestAddDuration: @given(date=dates(), days=integers()) def test_date(self, *, date: dt.date, days: int) -> None: @@ -215,7 +216,6 @@ def test_datetime(self, *, datetime: dt.datetime) -> None: check_zoned_datetime(datetime) -@mark.only class TestDateDurationToInt: @given(n=integers()) def test_int(self, *, n: int) -> None: @@ -253,7 +253,6 @@ def test_error_timedelta(self, *, n: int, frac: dt.timedelta) -> None: _ = date_duration_to_int(timedelta) -@mark.only class TestDateDurationToTimeDelta: @given(n=integers()) def test_int(self, *, n: int) -> None: @@ -310,28 +309,34 @@ def test_main(self, *, date: dt.date) -> None: class TestDateTimeDurationToFloat: - @given(duration=integers(0, 10) | floats(0.0, 10.0)) - def test_number(self, *, duration: Number) -> None: - result = datetime_duration_to_float(duration) - assert result == duration + @given(n=int32s()) + def test_int(self, *, n: int) -> None: + result = datetime_duration_to_float(n) + assert result == n - @given(duration=timedeltas()) - def test_timedelta(self, *, duration: dt.timedelta) -> None: - result = datetime_duration_to_float(duration) - assert result == duration.total_seconds() + @given(n=floats(allow_nan=False, allow_infinity=False)) + def test_float(self, *, n: Number) -> None: + result = datetime_duration_to_float(n) + assert result == n + + @given(timedelta=timedeltas()) + def test_timedelta(self, *, timedelta: dt.timedelta) -> None: + result = datetime_duration_to_float(timedelta) + assert result == timedelta.total_seconds() class TestDateTimeDurationToTimedelta: - @given(duration=integers(0, 10)) - def test_int(self, *, duration: int) -> None: - result = datetime_duration_to_timedelta(duration) - assert result.total_seconds() == duration + @given(n=int32s()) + def test_int(self, *, n: int) -> None: + result = datetime_duration_to_timedelta(n) + assert result.total_seconds() == n - @given(duration=floats(0.0, 10.0)) - def test_float(self, *, duration: float) -> None: - duration = round(10 * duration) / 10 - result = datetime_duration_to_timedelta(duration) - assert isclose(result.total_seconds(), duration) + @given(n=floats(min_value=MIN_INT32, max_value=MAX_INT32)) + def test_float(self, *, n: float) -> None: + n = round_to_float(n, 1e-6) + with assume_does_not_raise(OverflowError): + result = datetime_duration_to_timedelta(n) + assert isclose(result.total_seconds(), n) @given(duration=timedeltas()) def test_timedelta(self, *, duration: dt.timedelta) -> None: diff --git a/src/utilities/datetime.py b/src/utilities/datetime.py index af74dfc49..c319ae0ff 100644 --- a/src/utilities/datetime.py +++ b/src/utilities/datetime.py @@ -242,18 +242,26 @@ def __str__(self) -> str: def datetime_duration_to_float(duration: Duration, /) -> float: """Ensure a datetime duration is a float.""" - if isinstance(duration, int): - return float(duration) - if isinstance(duration, float): - return duration - return duration.total_seconds() + match duration: + case int(): + return float(duration) + case float(): + return duration + case dt.timedelta(): + return duration.total_seconds() + case _ as never: # pyright: ignore[reportUnnecessaryComparison] + assert_never(never) def datetime_duration_to_timedelta(duration: Duration, /) -> dt.timedelta: """Ensure a datetime duration is a timedelta.""" - if isinstance(duration, int | float): - return dt.timedelta(seconds=duration) - return duration + match duration: + case int() | float(): + return dt.timedelta(seconds=duration) + case dt.timedelta(): + return duration + case _ as never: # pyright: ignore[reportUnnecessaryComparison] + assert_never(never) ##