diff --git a/setup.py b/setup.py index 4960111a02a1..d59f0c6496a0 100644 --- a/setup.py +++ b/setup.py @@ -169,6 +169,8 @@ def get_git_sha() -> str: "ocient": [ "sqlalchemy-ocient>=1.0.0", "pyocient>=1.0.15", + "shapely", + "geojson", ], "oracle": ["cx-Oracle>8.0.0, <8.1"], "pinot": ["pinotdb>=0.3.3, <0.4"], diff --git a/superset/db_engine_specs/ocient.py b/superset/db_engine_specs/ocient.py index 5413ad9e1330..4b8a59117e95 100644 --- a/superset/db_engine_specs/ocient.py +++ b/superset/db_engine_specs/ocient.py @@ -26,13 +26,15 @@ # Need to try-catch here because pyocient may not be installed try: # Ensure pyocient inherits Superset's logging level + import geojson import pyocient + from shapely import wkt from superset import app superset_log_level = app.config["LOG_LEVEL"] pyocient.logger.setLevel(superset_log_level) -except ImportError as e: +except (ImportError, RuntimeError): pass from superset.db_engine_specs.base import BaseEngineSpec @@ -84,39 +86,93 @@ def _to_hex(data: bytes) -> str: return data.hex() -def _polygon_to_json(polygon: Any) -> str: +def _wkt_to_geo_json(geo_as_wkt: str) -> Any: """ - Converts the _STPolygon object into its JSON representation. + Converts pyocient geometry objects to their geoJSON representation. - :param data: the polygon object - :returns: JSON representation of the polygon + :param geo_as_wkt: the GIS object in WKT format + :returns: the geoJSON encoding of `geo` """ - json_value = f"{str([[p.long, p.lat] for p in polygon.exterior])}" - if polygon.holes: - for hole in polygon.holes: - json_value += f", {str([[p.long, p.lat] for p in hole])}" - json_value = f"[{json_value}]" - return json_value + # Need to try-catch here because these deps may not be installed + geo = wkt.loads(geo_as_wkt) + return geojson.Feature(geometry=geo, properties={}) -def _linestring_to_json(linestring: Any) -> str: +def _point_list_to_wkt( + points, # type: List[pyocient._STPoint] +) -> str: """ - Converts the _STLinestring object into its JSON representation. + Converts the list of pyocient._STPoint elements to a WKT LineString. - :param data: the linestring object - :returns: JSON representation of the linestring + :param points: the list of pyocient._STPoint objects + :returns: WKT LineString """ - return f"{str([[p.long, p.lat] for p in linestring.points])}" + coords = [f"{p.long} {p.lat}" for p in points] + return f"LINESTRING({', '.join(coords)})" + +def _point_to_geo_json( + point, # type: pyocient._STPoint +) -> Any: + """ + Converts the pyocient._STPolygon object to the geoJSON format + + :param point: the pyocient._STPoint instance + :returns: the geoJSON encoding of this point + """ + wkt_point = str(point) + return _wkt_to_geo_json(wkt_point) -def _point_to_comma_delimited(point: Any) -> str: + +def _linestring_to_geo_json( + linestring, # type: pyocient._STLinestring +) -> Any: + """ + Converts the pyocient._STLinestring object to a GIS format + compatible with the Superset visualization toolkit (powered + by Deck.gl). + + :param linestring: the pyocient._STLinestring instance + :returns: the geoJSON of this linestring + """ + if len(linestring.points) == 1: + # While technically an invalid linestring object, Ocient + # permits ST_LINESTRING containers to contain a single + # point. The flexibility allows the database to encode + # geometry collections as an array of the highest dimensional + # element in the collection (i.e. ST_LINESTRING[] or + # ST_POLYGON[]). + point = linestring.points[0] + return _point_to_geo_json(point) + + wkt_linestring = str(linestring) + return _wkt_to_geo_json(wkt_linestring) + + +def _polygon_to_geo_json( + polygon, # type: pyocient._STPolygon +) -> Any: """ - Returns the x and y coordinates as a comma delimited string. + Converts the pyocient._STPolygon object to a GIS format + compatible with the Superset visualization toolkit (powered + by Deck.gl). - :param data: the point object - :returns: the x and y coordinates as a comma delimited string + :param polygon: the pyocient._STPolygon instance + :returns: the geoJSON encoding of this polygon """ - return f"{point.long}, {point.lat}" + if len(polygon.exterior) > 0 and len(polygon.holes) == 0: + if len(polygon.exterior) == 1: + # The exterior ring contains a single ST_POINT + point = polygon.exterior[0] + return _point_to_geo_json(point) + if polygon.exterior[0] != polygon.exterior[-1]: + # The exterior ring contains an open ST_LINESTRING + wkt_linestring = _point_list_to_wkt(polygon.exterior) + return _wkt_to_geo_json(wkt_linestring) + # else + # This is a valid ST_POLYGON + wkt_polygon = str(polygon) + return _wkt_to_geo_json(wkt_polygon) # Sanitization function for column values @@ -145,11 +201,11 @@ def _point_to_comma_delimited(point: Any) -> str: _sanitized_ocient_type_codes: Dict[int, SanitizeFunc] = { TypeCodes.BINARY: _to_hex, - TypeCodes.ST_POINT: _point_to_comma_delimited, + TypeCodes.ST_POINT: _point_to_geo_json, TypeCodes.IP: str, TypeCodes.IPV4: str, - TypeCodes.ST_LINESTRING: _linestring_to_json, - TypeCodes.ST_POLYGON: _polygon_to_json, + TypeCodes.ST_LINESTRING: _linestring_to_geo_json, + TypeCodes.ST_POLYGON: _polygon_to_geo_json, } except ImportError as e: _sanitized_ocient_type_codes = {} diff --git a/tests/unit_tests/db_engine_specs/test_ocient.py b/tests/unit_tests/db_engine_specs/test_ocient.py index c5578fa93727..af9fd2ad1681 100644 --- a/tests/unit_tests/db_engine_specs/test_ocient.py +++ b/tests/unit_tests/db_engine_specs/test_ocient.py @@ -17,13 +17,21 @@ # pylint: disable=import-outside-toplevel -from datetime import datetime -from typing import Dict, List, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import pytest +from superset.db_engine_specs.ocient import ( + _point_list_to_wkt, + _sanitized_ocient_type_codes, +) from superset.errors import ErrorLevel, SupersetError, SupersetErrorType + +def ocient_is_installed() -> bool: + return len(_sanitized_ocient_type_codes) > 0 + + # (msg,expected) MARSHALED_OCIENT_ERRORS: List[Tuple[str, SupersetError]] = [ ( @@ -213,3 +221,188 @@ def test_connection_errors(msg: str, expected: SupersetError) -> None: result = OcientEngineSpec.extract_errors(Exception(msg)) assert result == [expected] + + +def _generate_gis_type_sanitization_test_cases() -> ( + List[Tuple[str, int, Any, Dict[str, Any]]] +): + if not ocient_is_installed(): + return [] + + from pyocient import _STLinestring, _STPoint, _STPolygon, TypeCodes + + return [ + ( + "empty_point", + TypeCodes.ST_POINT, + _STPoint(long=float("inf"), lat=float("inf")), + { + "geometry": None, + "properties": {}, + "type": "Feature", + }, + ), + ( + "valid_point", + TypeCodes.ST_POINT, + _STPoint(long=float(33), lat=float(45)), + { + "geometry": { + "coordinates": [33.0, 45.0], + "type": "Point", + }, + "properties": {}, + "type": "Feature", + }, + ), + ( + "empty_line", + TypeCodes.ST_LINESTRING, + _STLinestring([]), + { + "geometry": None, + "properties": {}, + "type": "Feature", + }, + ), + ( + "valid_line", + TypeCodes.ST_LINESTRING, + _STLinestring( + [_STPoint(long=t[0], lat=t[1]) for t in [(1, 0), (1, 1), (1, 2)]] + ), + { + "geometry": { + "coordinates": [[1, 0], [1, 1], [1, 2]], + "type": "LineString", + }, + "properties": {}, + "type": "Feature", + }, + ), + ( + "downcast_line_to_point", + TypeCodes.ST_LINESTRING, + _STLinestring([_STPoint(long=t[0], lat=t[1]) for t in [(1, 0)]]), + { + "geometry": { + "coordinates": [1, 0], + "type": "Point", + }, + "properties": {}, + "type": "Feature", + }, + ), + ( + "empty_polygon", + TypeCodes.ST_POLYGON, + _STPolygon(exterior=[], holes=[]), + { + "geometry": None, + "properties": {}, + "type": "Feature", + }, + ), + ( + "valid_polygon_no_holes", + TypeCodes.ST_POLYGON, + _STPolygon( + exterior=[ + _STPoint(long=t[0], lat=t[1]) for t in [(1, 0), (1, 1), (1, 0)] + ], + holes=[], + ), + { + "geometry": { + "coordinates": [[[1, 0], [1, 1], [1, 0]]], + "type": "Polygon", + }, + "properties": {}, + "type": "Feature", + }, + ), + ( + "valid_polygon_with_holes", + TypeCodes.ST_POLYGON, + _STPolygon( + exterior=[ + _STPoint(long=t[0], lat=t[1]) for t in [(1, 0), (1, 1), (1, 0)] + ], + holes=[ + [_STPoint(long=t[0], lat=t[1]) for t in [(2, 0), (2, 1), (2, 0)]], + [_STPoint(long=t[0], lat=t[1]) for t in [(3, 0), (3, 1), (3, 0)]], + ], + ), + { + "geometry": { + "coordinates": [ + [[1, 0], [1, 1], [1, 0]], + [[2, 0], [2, 1], [2, 0]], + [[3, 0], [3, 1], [3, 0]], + ], + "type": "Polygon", + }, + "properties": {}, + "type": "Feature", + }, + ), + ( + "downcast_poly_to_point", + TypeCodes.ST_POLYGON, + _STPolygon( + exterior=[_STPoint(long=t[0], lat=t[1]) for t in [(1, 0)]], + holes=[], + ), + { + "geometry": { + "coordinates": [1, 0], + "type": "Point", + }, + "properties": {}, + "type": "Feature", + }, + ), + ( + "downcast_poly_to_line", + TypeCodes.ST_POLYGON, + _STPolygon( + exterior=[_STPoint(long=t[0], lat=t[1]) for t in [(1, 0), (0, 1)]], + holes=[], + ), + { + "geometry": { + "coordinates": [[1, 0], [0, 1]], + "type": "LineString", + }, + "properties": {}, + "type": "Feature", + }, + ), + ] + + +@pytest.mark.skipif(not ocient_is_installed(), reason="requires ocient dependencies") +@pytest.mark.parametrize( + "name,type_code,geo,expected", _generate_gis_type_sanitization_test_cases() +) +def test_gis_type_sanitization( + name: str, type_code: int, geo: Any, expected: Any +) -> None: + # Hack to silence erroneous mypy errors + def die(any: Any) -> Callable[[Any], Any]: + pytest.fail(f"no sanitizer for type code {type_code}") + raise AssertionError() + + type_sanitizer = _sanitized_ocient_type_codes.get(type_code, die) + actual = type_sanitizer(geo) + assert expected == actual + + +@pytest.mark.skipif(not ocient_is_installed(), reason="requires ocient dependencies") +def test_point_list_to_wkt() -> None: + from pyocient import _STPoint + + wkt = _point_list_to_wkt( + [_STPoint(long=t[0], lat=t[1]) for t in [(2, 0), (2, 1), (2, 0)]] + ) + assert wkt == "LINESTRING(2 0, 2 1, 2 0)"