|
| 1 | +from abc import ABC |
| 2 | + |
| 3 | +from geoalchemy2.functions import ST_Intersects, ST_Transform, ST_GeomFromEWKT, ST_Contains, ST_IsValid |
| 4 | +from sqlalchemy import Select |
| 5 | +from sqlalchemy.orm import Session, InstrumentedAttribute |
| 6 | +from sqlalchemy.sql.functions import GenericFunction |
| 7 | +from sqlean import OperationalError |
| 8 | + |
| 9 | +import database |
| 10 | +import models |
| 11 | +import schemas |
| 12 | + |
| 13 | + |
| 14 | +class BaseFilter(ABC): |
| 15 | + model_class: type(database.Base) |
| 16 | + |
| 17 | + def apply( |
| 18 | + self, |
| 19 | + request: schemas.BaseSearchRequest, |
| 20 | + db: Session, |
| 21 | + query: Select, |
| 22 | + ): |
| 23 | + if geometry_filter := request.geometry: |
| 24 | + query = self._apply_geometry_filter( |
| 25 | + geometry_filter=geometry_filter, |
| 26 | + db=db, |
| 27 | + query=query |
| 28 | + ) |
| 29 | + return query |
| 30 | + |
| 31 | + def _apply_general_boundaries_filter( |
| 32 | + self, |
| 33 | + general_boundaries_filter: schemas.GeneralBoundariesFilter, |
| 34 | + query: Select |
| 35 | + ) -> Select: |
| 36 | + if hasattr(self.model_class, 'name') and general_boundaries_filter.name: |
| 37 | + query = _filter_by_string_field( |
| 38 | + string_filter=general_boundaries_filter.name, |
| 39 | + query=query, |
| 40 | + string_field=getattr(self.model_class, 'name') |
| 41 | + ) |
| 42 | + |
| 43 | + feature_ids = general_boundaries_filter.feature_ids |
| 44 | + if feature_ids and len(general_boundaries_filter.feature_ids) > 0: |
| 45 | + query = query.filter(getattr(self.model_class, 'feature_id').in_(feature_ids)) |
| 46 | + |
| 47 | + codes = general_boundaries_filter.codes |
| 48 | + if codes and len(codes) > 0: |
| 49 | + query = query.filter(getattr(self.model_class, 'code').in_(codes)) |
| 50 | + |
| 51 | + return query |
| 52 | + |
| 53 | + def _apply_geometry_filter( |
| 54 | + self, |
| 55 | + geometry_filter: schemas.GeometryFilter, |
| 56 | + db: Session, |
| 57 | + query: Select, |
| 58 | + ) -> Select: |
| 59 | + filter_func_type = _get_filter_func(geometry_filter.method) |
| 60 | + geom_field = getattr(self.model_class, 'geom') |
| 61 | + |
| 62 | + if ewkb := geometry_filter.ewkb: |
| 63 | + query = _filter_by_geometry( |
| 64 | + db=db, |
| 65 | + query=query, |
| 66 | + field="ewkb", |
| 67 | + geom_value=ewkb, |
| 68 | + filter_func_type=filter_func_type, |
| 69 | + geom_from_func_type=database.GeomFromEWKB, |
| 70 | + geom_field=geom_field, |
| 71 | + ) |
| 72 | + |
| 73 | + if ewkt := geometry_filter.ewkt: |
| 74 | + query = _filter_by_geometry( |
| 75 | + db=db, |
| 76 | + query=query, |
| 77 | + field="ewkt", |
| 78 | + geom_value=ewkt, |
| 79 | + filter_func_type=filter_func_type, |
| 80 | + geom_from_func_type=ST_GeomFromEWKT, |
| 81 | + geom_field=geom_field, |
| 82 | + ) |
| 83 | + |
| 84 | + if geojson := geometry_filter.geojson: |
| 85 | + query = _filter_by_geometry( |
| 86 | + db=db, |
| 87 | + query=query, |
| 88 | + field="geojson", |
| 89 | + geom_value=geojson, |
| 90 | + filter_func_type=filter_func_type, |
| 91 | + geom_from_func_type=database.GeomFromGeoJSON, |
| 92 | + geom_field=geom_field, |
| 93 | + ) |
| 94 | + |
| 95 | + return query |
| 96 | + |
| 97 | + |
| 98 | +class CountiesFilter(BaseFilter): |
| 99 | + model_class = models.Counties |
| 100 | + |
| 101 | + def apply( |
| 102 | + self, |
| 103 | + request: schemas.CountiesSearchRequest, |
| 104 | + db: Session, |
| 105 | + query: Select, |
| 106 | + ): |
| 107 | + query = super().apply(request, db, query) |
| 108 | + |
| 109 | + if counties_filter := request.counties: |
| 110 | + query = self._apply_general_boundaries_filter(general_boundaries_filter=counties_filter, query=query) |
| 111 | + |
| 112 | + return query |
| 113 | + |
| 114 | + |
| 115 | +class MunicipalitiesFilter(CountiesFilter): |
| 116 | + model_class = models.Municipalities |
| 117 | + |
| 118 | + def apply( |
| 119 | + self, |
| 120 | + request: schemas.MunicipalitiesSearchRequest, |
| 121 | + db: Session, |
| 122 | + query: Select, |
| 123 | + ): |
| 124 | + query = super().apply(request, db, query) |
| 125 | + if municipalities_filter := request.municipalities: |
| 126 | + query = self._apply_general_boundaries_filter( |
| 127 | + general_boundaries_filter=municipalities_filter, |
| 128 | + query=query, |
| 129 | + ) |
| 130 | + |
| 131 | + return query |
| 132 | + |
| 133 | + |
| 134 | +class EldershipsFilter(MunicipalitiesFilter): |
| 135 | + model_class = models.Elderships |
| 136 | + |
| 137 | + def apply( |
| 138 | + self, |
| 139 | + request: schemas.EldershipsSearchRequest, |
| 140 | + db: Session, |
| 141 | + query: Select, |
| 142 | + ): |
| 143 | + query = super().apply(request, db, query) |
| 144 | + if elderships_filter := request.elderships: |
| 145 | + query = self._apply_general_boundaries_filter( |
| 146 | + general_boundaries_filter=elderships_filter, |
| 147 | + query=query, |
| 148 | + ) |
| 149 | + return query |
| 150 | + |
| 151 | + |
| 152 | +class ResidentialAreasFilter(MunicipalitiesFilter): |
| 153 | + model_class = models.ResidentialAreas |
| 154 | + |
| 155 | + def apply( |
| 156 | + self, |
| 157 | + request: schemas.ResidentialAreasSearchRequest, |
| 158 | + db: Session, |
| 159 | + query: Select, |
| 160 | + ): |
| 161 | + query = super().apply(request, db, query) |
| 162 | + if residential_areas_filter := request.residential_areas: |
| 163 | + query = self._apply_general_boundaries_filter( |
| 164 | + general_boundaries_filter=residential_areas_filter, query=query, |
| 165 | + ) |
| 166 | + |
| 167 | + return query |
| 168 | + |
| 169 | + |
| 170 | +class StreetsFilter(ResidentialAreasFilter): |
| 171 | + model_class = models.Streets |
| 172 | + |
| 173 | + def apply( |
| 174 | + self, |
| 175 | + request: schemas.StreetsSearchRequest, |
| 176 | + db: Session, |
| 177 | + query: Select, |
| 178 | + ): |
| 179 | + query = super().apply(request, db, query) |
| 180 | + if streets_filter := request.streets: |
| 181 | + query = self._apply_general_boundaries_filter( |
| 182 | + general_boundaries_filter=streets_filter, |
| 183 | + query=query, |
| 184 | + ) |
| 185 | + |
| 186 | + return query |
| 187 | + |
| 188 | + |
| 189 | +class AddressesFilter(StreetsFilter): |
| 190 | + model_class = models.Addresses |
| 191 | + |
| 192 | + def apply( |
| 193 | + self, |
| 194 | + request: schemas.AddressesSearchRequest, |
| 195 | + db: Session, |
| 196 | + query: Select, |
| 197 | + ): |
| 198 | + query = super().apply(request, db, query) |
| 199 | + |
| 200 | + return query |
| 201 | + |
| 202 | + |
| 203 | +def _is_valid_geometry(db: Session, geom: GenericFunction) -> bool: |
| 204 | + try: |
| 205 | + return db.execute(ST_IsValid(geom)).scalar() == 1 |
| 206 | + except OperationalError: |
| 207 | + return False |
| 208 | + |
| 209 | + |
| 210 | +def _filter_by_geometry( |
| 211 | + db: Session, |
| 212 | + query: Select, |
| 213 | + geom_value: str, |
| 214 | + field: str, |
| 215 | + geom_field: InstrumentedAttribute, |
| 216 | + filter_func_type: type(GenericFunction), |
| 217 | + geom_from_func_type: type(GenericFunction), |
| 218 | +): |
| 219 | + geom = ST_Transform(geom_from_func_type(geom_value), 3346) |
| 220 | + if not _is_valid_geometry(db, geom): |
| 221 | + raise InvalidFilterGeometry(message="Invalid geometry", field=field, value=geom_value) |
| 222 | + |
| 223 | + return query.where(filter_func_type(geom, geom_field)) |
| 224 | + |
| 225 | + |
| 226 | +def _get_filter_func(filter_method: schemas.GeometryFilterMethod) -> type(GenericFunction): |
| 227 | + match filter_method: |
| 228 | + case schemas.GeometryFilterMethod.intersects: |
| 229 | + return ST_Intersects |
| 230 | + case schemas.GeometryFilterMethod.contains: |
| 231 | + return ST_Contains |
| 232 | + case _: |
| 233 | + raise ValueError(f"Unknown geometry filter method: {filter_method}") |
| 234 | + |
| 235 | + |
| 236 | +def _filter_by_string_field( |
| 237 | + string_filter: schemas.StringFilter, |
| 238 | + query: Select, |
| 239 | + string_field: InstrumentedAttribute |
| 240 | +) -> Select: |
| 241 | + if string_filter.contains: |
| 242 | + query = query.filter(string_field.icontains(string_filter.contains)) |
| 243 | + if string_filter.starts: |
| 244 | + query = query.filter(string_field.istartswith(string_filter.starts)) |
| 245 | + |
| 246 | + return query |
| 247 | + |
| 248 | + |
| 249 | +class InvalidFilterGeometry(Exception): |
| 250 | + def __init__(self, message: str, field: str, value: str): |
| 251 | + self.message = message |
| 252 | + self.field = field |
| 253 | + self.value = value |
| 254 | + super().__init__(self.message) |
0 commit comments