Skip to content

Commit

Permalink
Saved at 2025-01-04 10:51:01 (Sat)
Browse files Browse the repository at this point in the history
  • Loading branch information
dycw committed Jan 4, 2025
1 parent fd056fb commit e4d68e4
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 28 deletions.
45 changes: 25 additions & 20 deletions src/tests/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@
text_clean,
zoned_datetimes,
)
from utilities.math import is_integral
from utilities.math import MAX_INT32, MIN_INT32, is_integral, round_to_float
from utilities.zoneinfo import UTC, HongKong, Tokyo

if TYPE_CHECKING:
Expand All @@ -129,6 +129,7 @@
from utilities.types import Number


@mark.skip
class TestAddDuration:
@given(date=dates(), days=integers())
def test_date(self, *, date: dt.date, days: int) -> None:
Expand Down Expand Up @@ -215,7 +216,6 @@ def test_datetime(self, *, datetime: dt.datetime) -> None:
check_zoned_datetime(datetime)


@mark.only
class TestDateDurationToInt:
@given(n=integers())
def test_int(self, *, n: int) -> None:
Expand Down Expand Up @@ -253,7 +253,6 @@ def test_error_timedelta(self, *, n: int, frac: dt.timedelta) -> None:
_ = date_duration_to_int(timedelta)


@mark.only
class TestDateDurationToTimeDelta:
@given(n=integers())
def test_int(self, *, n: int) -> None:
Expand Down Expand Up @@ -310,28 +309,34 @@ def test_main(self, *, date: dt.date) -> None:


class TestDateTimeDurationToFloat:
@given(duration=integers(0, 10) | floats(0.0, 10.0))
def test_number(self, *, duration: Number) -> None:
result = datetime_duration_to_float(duration)
assert result == duration
@given(n=int32s())
def test_int(self, *, n: int) -> None:
result = datetime_duration_to_float(n)
assert result == n

@given(duration=timedeltas())
def test_timedelta(self, *, duration: dt.timedelta) -> None:
result = datetime_duration_to_float(duration)
assert result == duration.total_seconds()
@given(n=floats(allow_nan=False, allow_infinity=False))
def test_float(self, *, n: Number) -> None:
result = datetime_duration_to_float(n)
assert result == n

@given(timedelta=timedeltas())
def test_timedelta(self, *, timedelta: dt.timedelta) -> None:
result = datetime_duration_to_float(timedelta)
assert result == timedelta.total_seconds()


class TestDateTimeDurationToTimedelta:
@given(duration=integers(0, 10))
def test_int(self, *, duration: int) -> None:
result = datetime_duration_to_timedelta(duration)
assert result.total_seconds() == duration
@given(n=int32s())
def test_int(self, *, n: int) -> None:
result = datetime_duration_to_timedelta(n)
assert result.total_seconds() == n

@given(duration=floats(0.0, 10.0))
def test_float(self, *, duration: float) -> None:
duration = round(10 * duration) / 10
result = datetime_duration_to_timedelta(duration)
assert isclose(result.total_seconds(), duration)
@given(n=floats(min_value=MIN_INT32, max_value=MAX_INT32))
def test_float(self, *, n: float) -> None:
n = round_to_float(n, 1e-6)
with assume_does_not_raise(OverflowError):
result = datetime_duration_to_timedelta(n)
assert isclose(result.total_seconds(), n)

@given(duration=timedeltas())
def test_timedelta(self, *, duration: dt.timedelta) -> None:
Expand Down
24 changes: 16 additions & 8 deletions src/utilities/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,18 +242,26 @@ def __str__(self) -> str:

def datetime_duration_to_float(duration: Duration, /) -> float:
"""Ensure a datetime duration is a float."""
if isinstance(duration, int):
return float(duration)
if isinstance(duration, float):
return duration
return duration.total_seconds()
match duration:
case int():
return float(duration)
case float():
return duration
case dt.timedelta():
return duration.total_seconds()
case _ as never: # pyright: ignore[reportUnnecessaryComparison]
assert_never(never)


def datetime_duration_to_timedelta(duration: Duration, /) -> dt.timedelta:
"""Ensure a datetime duration is a timedelta."""
if isinstance(duration, int | float):
return dt.timedelta(seconds=duration)
return duration
match duration:
case int() | float():
return dt.timedelta(seconds=duration)
case dt.timedelta():
return duration
case _ as never: # pyright: ignore[reportUnnecessaryComparison]
assert_never(never)


##
Expand Down

0 comments on commit e4d68e4

Please sign in to comment.