diff --git a/pyproject.toml b/pyproject.toml index c455fcc..174415e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ keywords=["stac", "pydantic", "validation"] authors=[{ name = "Arturo Engineering", email = "engineering@arturo.ai"}] license= { text = "MIT" } requires-python=">=3.8" -dependencies = ["click>=8.1.7", "pydantic>=2.4.1", "geojson-pydantic>=1.0.0"] +dependencies = ["click>=8.1.7", "pydantic>=2.4.1", "geojson-pydantic>=1.0.0", "ciso8601~=2.3"] dynamic = ["version", "readme"] [project.scripts] diff --git a/stac_pydantic/api/search.py b/stac_pydantic/api/search.py index bc32099..a005369 100644 --- a/stac_pydantic/api/search.py +++ b/stac_pydantic/api/search.py @@ -1,6 +1,7 @@ from datetime import datetime as dt from typing import Any, Dict, List, Optional, Tuple, Union, cast +from ciso8601 import parse_rfc3339 from geojson_pydantic.geometries import ( # type: ignore GeometryCollection, LineString, @@ -16,7 +17,6 @@ from stac_pydantic.api.extensions.query import Operator from stac_pydantic.api.extensions.sort import SortExtension from stac_pydantic.shared import BBox -from stac_pydantic.utils import parse_datetime Intersection = Union[ Point, @@ -50,16 +50,16 @@ def start_date(self) -> Optional[dt]: return None if values[0] == ".." or values[0] == "": return None - return parse_datetime(values[0]) + return parse_rfc3339(values[0]) @property def end_date(self) -> Optional[dt]: values = (self.datetime or "").split("/") if len(values) == 1: - return parse_datetime(values[0]) + return parse_rfc3339(values[0]) if values[1] == ".." or values[1] == "": return None - return parse_datetime(values[1]) + return parse_rfc3339(values[1]) # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information. @model_validator(mode="before") @@ -109,17 +109,18 @@ def validate_datetime(cls, v: str) -> str: # Single date is interpreted as end date values = ["..", v] - dates = [] + dates: List[dt] = [] for value in values: if value == ".." or value == "": - dates.append("..") continue - parse_datetime(value) - dates.append(value) + dates.append(parse_rfc3339(value)) - if ".." not in dates: - if parse_datetime(dates[0]) > parse_datetime(dates[1]): + if len(values) > 2: + raise ValueError("Invalid datetime range, must match format (begin_date, end_date)") + + if not {"..", ""}.intersection(set(values)): + if dates[0] > dates[1]: raise ValueError( "Invalid datetime range, must match format (begin_date, end_date)" ) diff --git a/stac_pydantic/item.py b/stac_pydantic/item.py index 8d958b9..452b51a 100644 --- a/stac_pydantic/item.py +++ b/stac_pydantic/item.py @@ -1,6 +1,7 @@ from datetime import datetime as dt from typing import Any, Dict, List, Optional, Union +from ciso8601 import parse_rfc3339 from geojson_pydantic import Feature from pydantic import ( AnyUrl, @@ -19,7 +20,6 @@ StacBaseModel, StacCommonMetadata, ) -from stac_pydantic.utils import parse_datetime from stac_pydantic.version import STAC_VERSION @@ -47,13 +47,13 @@ def validate_datetime(cls, data: Dict[str, Any]) -> Dict[str, Any]: ) if isinstance(datetime, str): - data["datetime"] = parse_datetime(datetime) + data["datetime"] = parse_rfc3339(datetime) if isinstance(start_datetime, str): - data["start_datetime"] = parse_datetime(start_datetime) + data["start_datetime"] = parse_rfc3339(start_datetime) if isinstance(end_datetime, str): - data["end_datetime"] = parse_datetime(end_datetime) + data["end_datetime"] = parse_rfc3339(end_datetime) return data diff --git a/stac_pydantic/utils.py b/stac_pydantic/utils.py index ef8cf51..1bdca60 100644 --- a/stac_pydantic/utils.py +++ b/stac_pydantic/utils.py @@ -1,9 +1,5 @@ -import json -from datetime import datetime from enum import Enum -from typing import Any, Callable, List - -from pydantic import TypeAdapter +from typing import Any, List class AutoValueEnum(Enum): @@ -11,8 +7,3 @@ def _generate_next_value_( # type: ignore name: str, start: int, count: int, last_values: List[Any] ) -> Any: return name - - -parse_datetime: Callable[[Any], datetime] = lambda x: TypeAdapter( - datetime -).validate_json(json.dumps(x)) diff --git a/tests/api/test_search.py b/tests/api/test_search.py index fc04a9c..ab44566 100644 --- a/tests/api/test_search.py +++ b/tests/api/test_search.py @@ -1,5 +1,5 @@ import time -from datetime import datetime, timezone +from datetime import datetime, timezone, timedelta import pytest from pydantic import ValidationError @@ -93,6 +93,12 @@ def test_invalid_temporal_search(): with pytest.raises(ValidationError): Search(collections=["collection1"], datetime=utcnow) + t1 = datetime.utcnow() + t2 = t1 + timedelta(seconds=100) + t3 = t2 + timedelta(seconds=100) + with pytest.raises(ValidationError): + Search(collections=["collection1"], datetime=f"{t1.strftime(DATETIME_RFC339)}/{t2.strftime(DATETIME_RFC339)}/{t3.strftime(DATETIME_RFC339)}",) + # End date is before start date start = datetime.utcnow() time.sleep(2)