Skip to content

Commit

Permalink
Merge pull request #23 from govlt/allow-setting-geometry-output
Browse files Browse the repository at this point in the history
Allow setting geometry output
  • Loading branch information
vycius authored Jul 9, 2024
2 parents 9eef2bc + 770b8c2 commit d8cf4af
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 35 deletions.
19 changes: 19 additions & 0 deletions api/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from fastapi import Query
from fastapi.openapi.models import Example

import schemas

query_srid: Query = Query(
3346,
openapi_examples={
Expand All @@ -20,6 +22,23 @@
description="A spatial reference identifier (SRID) for geometry output."
)

query_geometry_output_type: schemas.GeometryOutputFormat = Query(
schemas.GeometryOutputFormat.ewkt,
openapi_examples={
"example_ewkt": {
"summary": "EWKT",
"description": "Extended Well-Known Text (EWKT) format for representing geometric data.",
"value": schemas.GeometryOutputFormat.ewkt
},
"example_ewkb": {
"summary": "EWKB",
"description": "Extended Well-Known Binary (EWKB) format for representing geometric data.",
"value": schemas.GeometryOutputFormat.ewkb
},
},
description="Specify the format for geometry output."
)

openapi_examples_geometry_filtering: Dict[str, Example] = {
"example_ewkb": {
"summary": "Filter using EWKB and 'intersects'",
Expand Down
26 changes: 16 additions & 10 deletions api/database.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Optional

import geoalchemy2
import sqlean
from geoalchemy2 import load_spatialite
from geoalchemy2 import load_spatialite, Geometry, WKTElement
from geoalchemy2.functions import GenericFunction
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker, Session
Expand Down Expand Up @@ -39,14 +41,18 @@ class GeomFromGeoJSON(GenericFunction):
inherit_cache = True


class GeomFromEWKB(GenericFunction):
"""
Returns geometric object given its EWKB Representation
see https://www.gaia-gis.it/gaia-sins/spatialite-sql-5.1.0.html
class EWKTGeometry(Geometry):
# We need to override constructor only to set extended to True
def __init__(self, geometry_type: Optional[str] = "GEOMETRY", srid=-1, dimension=2, spatial_index=True,
use_N_D_index=False, use_typmod: Optional[bool] = None, from_text: Optional[str] = None,
name: Optional[str] = None, nullable=True, _spatial_index_reflected=None) -> None:
super().__init__(geometry_type, srid, dimension, spatial_index, use_N_D_index, use_typmod, from_text, name,
nullable, _spatial_index_reflected)
self.extended = True

Return type: :class:`geoalchemy2.types.Geometry`.
"""
name = "geometry"
from_text = 'ST_GeomFromEWKT'
as_binary = 'AsEWKT'
ElementType = WKTElement

type = geoalchemy2.types.Geometry()
inherit_cache = True
cache_ok = False
34 changes: 24 additions & 10 deletions api/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def boundaries_search(
sort_order=sort_order,
request=request,
boundaries_filter=boundaries_filter,
srid=None
srid=None,
geometry_output_format=None,
)

@router.get(
Expand Down Expand Up @@ -105,9 +106,10 @@ def get_with_geometry(
),
db: Session = Depends(database.get_db),
srid: int = constants.query_srid,
geometry_output_format: schemas.GeometryOutputFormat = constants.query_geometry_output_type,
service: services.BaseBoundariesService = Depends(service_class),
):
if row := service.get_by_code(db=db, code=code, srid=srid):
if row := service.get_by_code(db=db, code=code, srid=srid, geometry_output_format=geometry_output_format):
return row
else:
raise HTTPException(
Expand Down Expand Up @@ -183,6 +185,7 @@ def addresses_search(
],
sort_by: schemas.SearchSortBy = Query(default=schemas.SearchSortBy.code),
sort_order: schemas.SearchSortOrder = Query(default=schemas.SearchSortOrder.asc),
geometry_output_format: schemas.GeometryOutputFormat = constants.query_geometry_output_type,
srid: int = constants.query_srid,
db: Session = Depends(database.get_db),
addresses_filter: filters.AddressesFilter = Depends(filters.AddressesFilter),
Expand All @@ -195,6 +198,7 @@ def addresses_search(
request=request,
srid=srid,
boundaries_filter=addresses_filter,
geometry_output_format=geometry_output_format
)


Expand All @@ -212,15 +216,19 @@ def addresses_search(
def get(
code: int = Path(
description="The code of the address to retrieve",
examples=[
155218235
]
openapi_examples={
"example_address_code": {
"summary": "Example address code",
"value": 155218235
},
},
),
srid: int = constants.query_srid,
geometry_output_format: schemas.GeometryOutputFormat = constants.query_geometry_output_type,
db: Session = Depends(database.get_db),
service: services.AddressesService = Depends(services.AddressesService),
):
if item := service.get_by_code(db=db, code=code, srid=srid):
if item := service.get_by_code(db=db, code=code, srid=srid, geometry_output_format=geometry_output_format):
return item
else:
raise HTTPException(
Expand Down Expand Up @@ -257,6 +265,7 @@ def rooms_search(
sort_by: schemas.SearchSortBy = Query(default=schemas.SearchSortBy.code),
sort_order: schemas.SearchSortOrder = Query(default=schemas.SearchSortOrder.asc),
srid: int = constants.query_srid,
geometry_output_format: schemas.GeometryOutputFormat = Query(default=schemas.GeometryOutputFormat.ewkt),
db: Session = Depends(database.get_db),
rooms_filter: filters.RoomsFilter = Depends(filters.RoomsFilter),
service: services.RoomsService = Depends(services.RoomsService),
Expand All @@ -268,6 +277,7 @@ def rooms_search(
request=request,
srid=srid,
boundaries_filter=rooms_filter,
geometry_output_format=geometry_output_format
)


Expand All @@ -285,15 +295,19 @@ def rooms_search(
def get(
code: int = Path(
description="The code of the room to retrieve",
examples=[
194858325
]
openapi_examples={
"example_room_code": {
"summary": "Example room code",
"value": 194858325
},
},
),
srid: int = constants.query_srid,
geometry_output_format: schemas.GeometryOutputFormat = constants.query_geometry_output_type,
db: Session = Depends(database.get_db),
service: services.RoomsService = Depends(services.RoomsService),
):
if item := service.get_by_code(db=db, code=code, srid=srid):
if item := service.get_by_code(db=db, code=code, srid=srid, geometry_output_format=geometry_output_format):
return item
else:
raise HTTPException(
Expand Down
5 changes: 5 additions & 0 deletions api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ class SearchSortOrder(str, enum.Enum):
desc = 'desc'


class GeometryOutputFormat(str, enum.Enum):
ewkt = 'ewkt'
ewkb = 'ewkb'


class Geometry(BaseModel):
srid: int = Field(description="Spatial Reference Identifier (SRID) for the geometry")
data: str = Field(description="Geometry data in WKB (Well-Known Binary) format, represented as a hex string")
Expand Down
80 changes: 65 additions & 15 deletions api/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

from fastapi_pagination import Page
from fastapi_pagination.ext.sqlalchemy import paginate
from geoalchemy2 import Geometry
from geoalchemy2.functions import ST_Transform
from sqlalchemy import select, Select, func, text, Row, Label
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Session, InstrumentedAttribute
from sqlalchemy.sql import operators

import database
import filters
import models
import schemas
Expand Down Expand Up @@ -70,16 +72,33 @@ class BaseBoundariesService(abc.ABC):
model_class: Type[models.BaseBoundaries]

@abc.abstractmethod
def _get_select_query(self, srid: Optional[int]) -> Select:
def _get_select_query(self, srid: Optional[int],
geometry_output_format: Optional[schemas.GeometryOutputFormat]) -> Select:
pass

@abc.abstractmethod
def _filter_by_code(self, query: Select, code: int) -> Select:
pass

@staticmethod
def _get_geometry_field(field: InstrumentedAttribute, srid: int) -> Label:
return ST_Transform(field, srid).label("geometry")
def _get_geometry_output_type(geometry_output_format: schemas.GeometryOutputFormat):
match geometry_output_format:
case schemas.GeometryOutputFormat.ewkt:
return database.EWKTGeometry
case schemas.GeometryOutputFormat.ewkb:
return Geometry()
case _:
raise ValueError(f"Unable to map geometry output format {geometry_output_format}")

@staticmethod
def _get_geometry_field(
field: InstrumentedAttribute,
srid: int,
geometry_output_format: schemas.GeometryOutputFormat
) -> Label:
geometry_output_type = BaseBoundariesService._get_geometry_output_type(geometry_output_format)

return ST_Transform(field, srid, type_=geometry_output_type).label("geometry")

def search(
self,
Expand All @@ -89,8 +108,9 @@ def search(
request: schemas.BaseSearchRequest,
boundaries_filter: filters.BaseFilter,
srid: Optional[int],
geometry_output_format: Optional[schemas.GeometryOutputFormat]
) -> Page:
query = self._get_select_query(srid=srid)
query = self._get_select_query(srid=srid, geometry_output_format=geometry_output_format)

query = boundaries_filter.apply(request, db, query)

Expand All @@ -106,9 +126,10 @@ def get_by_code(
self,
db: Session,
code: int,
geometry_output_format: Optional[schemas.GeometryOutputFormat] = None,
srid: Optional[int] = None,
) -> Row | None:
query = self._get_select_query(srid=srid)
query = self._get_select_query(srid=srid, geometry_output_format=geometry_output_format)
query = self._filter_by_code(code=code, query=query)

return db.execute(query).first()
Expand All @@ -117,15 +138,20 @@ def get_by_code(
class CountiesService(BaseBoundariesService):
model_class = models.Counties

def _get_select_query(self, srid: Optional[int]) -> Select:
def _get_select_query(
self,
srid: Optional[int],
geometry_output_format: Optional[schemas.GeometryOutputFormat],
) -> Select:
columns = [
models.Counties.code,
models.Counties.feature_id,
models.Counties.name,
models.Counties.area_ha,
models.Counties.area_ha,
models.Counties.created_at,
] + ([self._get_geometry_field(models.Counties.geom, srid)] if srid else [])
] + ([self._get_geometry_field(models.Counties.geom, srid,
geometry_output_format)] if srid and geometry_output_format else [])

return select(*columns).select_from(models.Counties)

Expand All @@ -136,7 +162,11 @@ def _filter_by_code(self, query: Select, code: int) -> Select:
class MunicipalitiesService(BaseBoundariesService):
model_class = models.Municipalities

def _get_select_query(self, srid: Optional[int]) -> Select:
def _get_select_query(
self,
srid: Optional[int],
geometry_output_format: Optional[schemas.GeometryOutputFormat],
) -> Select:
columns = [
models.Municipalities.code,
models.Municipalities.feature_id,
Expand All @@ -157,7 +187,11 @@ def _filter_by_code(self, query: Select, code: int) -> Select:
class EldershipsService(BaseBoundariesService):
model_class = models.Elderships

def _get_select_query(self, srid: Optional[int]) -> Select:
def _get_select_query(
self,
srid: Optional[int],
geometry_output_format: Optional[schemas.GeometryOutputFormat],
) -> Select:
columns = [
models.Elderships.code,
models.Elderships.feature_id,
Expand All @@ -178,7 +212,10 @@ def _filter_by_code(self, query: Select, code: int) -> Select:
class ResidentialAreasService(BaseBoundariesService):
model_class = models.ResidentialAreas

def _get_select_query(self, srid: Optional[int]) -> Select:
def _get_select_query(
self, srid: Optional[int],
geometry_output_format: Optional[schemas.GeometryOutputFormat],
) -> Select:
columns = [
models.ResidentialAreas.code,
models.ResidentialAreas.feature_id,
Expand All @@ -199,7 +236,11 @@ def _filter_by_code(self, query: Select, code: int) -> Select:
class StreetsService(BaseBoundariesService):
model_class = models.Streets

def _get_select_query(self, srid: Optional[int]) -> Select:
def _get_select_query(
self,
srid: Optional[int],
geometry_output_format: Optional[schemas.GeometryOutputFormat],
) -> Select:
columns = [
models.Streets.code,
models.Streets.feature_id,
Expand All @@ -221,7 +262,11 @@ def _filter_by_code(self, query: Select, code: int) -> Select:
class AddressesService(BaseBoundariesService):
model_class = models.Addresses

def _get_select_query(self, srid: Optional[int]) -> Select:
def _get_select_query(
self,
srid: Optional[int],
geometry_output_format: Optional[schemas.GeometryOutputFormat],
) -> Select:
columns = [
models.Addresses.feature_id,
models.Addresses.code,
Expand All @@ -232,7 +277,7 @@ def _get_select_query(self, srid: Optional[int]) -> Select:
_flat_residential_area_object,
_municipality_object,
_flat_street_object,
self._get_geometry_field(models.Addresses.geom, srid)
self._get_geometry_field(models.Addresses.geom, srid, geometry_output_format)
]

return select(*columns).select_from(models.Addresses) \
Expand All @@ -248,13 +293,18 @@ def _filter_by_code(self, query: Select, code: int) -> Select:
class RoomsService(BaseBoundariesService):
model_class = models.Rooms

def _get_select_query(self, srid: Optional[int]) -> Select:
def _get_select_query(
self,
srid: Optional[int],
geometry_output_format: Optional[schemas.GeometryOutputFormat],
) -> Select:
columns = [
models.Rooms.code,
models.Rooms.room_number,
models.Rooms.created_at,
_address_short_object,
] + ([self._get_geometry_field(models.Addresses.geom, srid)] if srid else [])
] + ([self._get_geometry_field(models.Addresses.geom, srid,
geometry_output_format)] if srid and geometry_output_format else [])

return select(*columns).select_from(models.Rooms) \
.outerjoin(models.Rooms.address) \
Expand Down

0 comments on commit d8cf4af

Please sign in to comment.