Skip to content

Commit

Permalink
Fix typing of DataFileInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
lemonyte committed Jan 22, 2025
1 parent 46bcbea commit 397c383
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ repos:
- id: ruff-format

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.369
rev: v1.1.392.post0
hooks:
- id: pyright
12 changes: 7 additions & 5 deletions src/ferry_planner/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Annotated
from typing import Annotated, Generic, TypeVar

from pydantic import AfterValidator, BaseModel, field_serializer, field_validator
from pydantic_settings import (
Expand All @@ -12,6 +12,8 @@
from ferry_planner.connection import AirConnection, BusConnection, CarConnection, Connection, FerryConnection
from ferry_planner.location import Airport, BusStop, City, Location, Terminal

DataFileT = TypeVar("DataFileT", bound=Location | Connection)


def check_is_file(path: Path, /) -> Path:
if not path.is_file():
Expand Down Expand Up @@ -45,9 +47,9 @@ def check_is_dir(path: Path, /) -> Path:
}


class DataFileInfo(BaseModel):
class DataFileInfo(BaseModel, Generic[DataFileT]):
path: FilePath
cls: type[Location | Connection]
cls: type[DataFileT]

@field_serializer("cls", when_used="json")
def _serialize_cls(self, value: type[Location | Connection]) -> str:
Expand All @@ -68,8 +70,8 @@ def _validate_cls(cls, value: str | type | None) -> type[Location | Connection]:


class DataConfig(BaseModel):
location_files: tuple[DataFileInfo, ...]
connection_files: tuple[DataFileInfo, ...]
location_files: tuple[DataFileInfo[Location], ...]
connection_files: tuple[DataFileInfo[Connection], ...]


class SchedulesConfig(BaseModel):
Expand Down
39 changes: 19 additions & 20 deletions src/ferry_planner/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@
import json
from typing import TYPE_CHECKING, Any, TypeVar

from ferry_planner.location import Location, LocationId

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, MutableMapping, Sequence

from ferry_planner.config import DataFileInfo
from ferry_planner.config import DataFileInfo, DataFileT
from ferry_planner.connection import Connection, ConnectionId
from ferry_planner.location import Location, LocationId

LocationT = TypeVar("LocationT", bound=Location)
OriginT = TypeVar("OriginT", bound=Location)
DestinationT = TypeVar("DestinationT", bound=Location)
LocationT = TypeVar("LocationT", bound="Location")
OriginT = TypeVar("OriginT", bound="Location")
DestinationT = TypeVar("DestinationT", bound="Location")


class LocationNotFoundError(Exception):
Expand All @@ -28,7 +27,12 @@ def __init__(self, connection_id: ConnectionId, *args: Iterable) -> None:
super().__init__(f"Connection not found with ID {connection_id}", *args)


def load_from_json(data_file: DataFileInfo, /, *, context: dict[str, Any] | None = None) -> Iterator[DataFileInfo.cls]:
def load_from_json(
data_file: DataFileInfo[DataFileT],
/,
*,
context: dict[str, Any] | None = None,
) -> Iterator[DataFileT]:
data = json.loads(data_file.path.read_text(encoding="utf-8"))
if not isinstance(data, dict):
msg = f"data file '{data_file.path}' must contain a dictionary"
Expand All @@ -41,10 +45,8 @@ def __init__(self, locations: Iterable[Location], /) -> None:
self._locations: dict[LocationId, Location] = {location.id: location for location in locations}

@classmethod
def from_files(cls, data_files: Sequence[DataFileInfo], /) -> LocationDB:
locations = []
for data_file in data_files:
locations.extend(load_from_json(data_file))
def from_files(cls, data_files: Sequence[DataFileInfo[Location]], /) -> LocationDB:
locations = [location for data_file in data_files for location in load_from_json(data_file)]
return cls(locations)

def dict(self) -> MutableMapping[LocationId, Location]:
Expand All @@ -65,15 +67,12 @@ def __init__(self, connections: Iterable[Connection], /) -> None:
self._connections: dict[ConnectionId, Connection] = {connection.id: connection for connection in connections}

@classmethod
def from_files(cls, data_files: Sequence[DataFileInfo], /, *, location_db: LocationDB) -> ConnectionDB:
connections = []
for data_file in data_files:
connections.extend(
load_from_json(
data_file,
context={"location_db": location_db},
),
)
def from_files(cls, data_files: Sequence[DataFileInfo[Connection]], /, *, location_db: LocationDB) -> ConnectionDB:
connections = [
connection
for data_file in data_files
for connection in load_from_json(data_file, context={"location_db": location_db})
]
return cls(connections)

def dict(self) -> MutableMapping[ConnectionId, Connection[Location, Location]]:
Expand Down

0 comments on commit 397c383

Please sign in to comment.