diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 89f1995..027fc88 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -1,43 +1,118 @@
 name: GerryDB client tests
-on: [push, pull_request]
+on:
+  workflow_dispatch:
+  push:
+  pull_request:
+
+
 jobs:
   run:
-    runs-on: ${{ matrix.os }}
+    name: Run tests on Linux
+    runs-on: ubuntu-latest
+    services:
+      postgres:
+        image: postgis/postgis:16-3.4
+        ports:
+          - 54320:5432
+        env:
+          POSTGRES_USER: postgres
+          POSTGRES_PASSWORD: dev
+          POSTGRES_DB: gerrydb
+        options: >-
+          --health-cmd "pg_isready -U postgres"
+          --health-interval 10s
+          --health-timeout 20s
+          --health-retries 10
     strategy:
       matrix:
-        python-version: ['3.10', '3.11']
-        os: [ubuntu-latest, macos-latest]
+        python-version: ['3.10', '3.12']
     env:
       OS: ${{ matrix.os }}
       PYTHON: ${{ matrix.python-version }}
+      GERRYDB_DATABASE_URI: "postgresql://postgres:dev@localhost:54320/gerrydb"
+      GERRYDB_TEST_SERVER: "localhost:8000"
+      POSTGRES_USER: postgres
+      POSTGRES_PASSWORD: dev
+      POSTGRES_DB: gerrydb
+
     steps:
-    - uses: actions/checkout@main
+    - name: Checkout repository
+      uses: actions/checkout@v3
+
     - name: Set up Python
-      uses: actions/setup-python@master
+      uses: actions/setup-python@v5
       with:
         python-version: ${{ matrix.python-version }}
+
+
+    - name: Install PostgreSQL client tools
+      run: |
+        sudo apt-get update
+        sudo apt-get install -y postgresql-client
+
+    - name: Install GDAL
+      run: |
+        sudo apt-get install -y gdal-bin
+
+    - name: Clone backend repo
+      run: |
+        git clone --branch dev https://github.com/mggg/gerrydb-meta.git ../gerrydb-meta
+
     - name: Install dependencies
       run: |
+        cd ..
+        python -m venv .venv
+        source .venv/bin/activate
+        pip install --upgrade pip
         pip install poetry
-        pip install .
         pip install black isort
-    - name: Check formatting
+        pip install ./gerrydb-client-py/
+        pip install ./gerrydb-meta/
+
+    - name: Set up GerryDB and get API key
+      run: |
+        cd ..
+        source .venv/bin/activate
+        cd ./gerrydb-meta
+        echo "export GERRYDB_TEST_API_KEY=$(python init.py --name test --email test --reset)" > ../.env
+        echo "export GERRYDB_TEST_SERVER=localhost:8000" >> ../.env
+
+    - name: Wait for PostgreSQL to be ready
+      run: |
+        until pg_isready -h 127.0.0.1 -p 54320 -U postgres; do
+          echo "Waiting for PostgreSQL to be ready..."
+          sleep 2
+        done
+
+    - name: Check the Postgres database
       run: |
-        python -m black . --check
-        python -m isort . --diff
+        PGPASSWORD=dev psql -h 127.0.0.1 -p 54320 -U postgres -c '\l'
+        PGPASSWORD=dev psql -h 127.0.0.1 -p 54320 -U postgres -d gerrydb -c 'SELECT postgis_version();'
+
+    - name: Start the uvicorn server
+      run: |
+        source ../.venv/bin/activate
+        source ../.env
+        cd ../gerrydb-meta
+        nohup uvicorn gerrydb_meta.main:app --host 0.0.0.0 --port 8000 --log-level trace > uvicorn.log 2>&1 &
+
+    - name: Wait for Uvicorn to be ready
+      run: |
+        until curl -s http://localhost:8000/api/v1 > /dev/null; do
+          echo "Waiting for Uvicorn to be ready..."
+          sleep 2
+        done
+
     - name: Run tests and generate coverage report
       run: |
-        pip install pytest
-        pip install pytest-cov
-        pytest -v --cov=./ --cov-report=xml
-    - name: Upload coverage to Codecov
-      uses: codecov/codecov-action@v3
-      with:
-        token: ${{ secrets.CODECOV_TOKEN }}
-        directory: .
-        env_vars: OS,PYTHON
-        fail_ci_if_error: true
-        files: ./coverage.xml
-        flags: unittests
-        name: codecov-umbrella
-        verbose: true
+        source ../.venv/bin/activate
+        source ../.env
+        pip install pytest pytest-cov
+        pytest -v -s tests --cov=./ --cov-report=xml
+
+    - name: Print Uvicorn logs on failure
+      if: failure()
+      run: |
+        echo "Displaying Uvicorn logs:"
+        cat ../gerrydb-meta/uvicorn.log
+
diff --git a/.gitignore b/.gitignore
index dbce267..32964df 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,3 +160,7 @@ cython_debug/
 #.idea/
 
 .DS_Store
+
+
+# Ignore all of the backup files that might be floating around
+*.tar
\ No newline at end of file
diff --git a/TODO.md b/TODO.md
new file mode 100644
index 0000000..89f57b3
--- /dev/null
+++ b/TODO.md
@@ -0,0 +1,10 @@
+# For Docs
+
+[ ] Show how to make a column with the write context
+    [ ] Show how this all works with the load_dataframe method
+
+
+# Ideas
+[ ] Maybe cache all of the columns in the database on the local machine so
+    validation can be done locally? Periodic checks can be done
+    to ensure that the local cache is up to date.
\ No newline at end of file
diff --git a/gerrydb/cache.py b/gerrydb/cache.py
index ae33824..e2e90a3 100644
--- a/gerrydb/cache.py
+++ b/gerrydb/cache.py
@@ -97,9 +97,6 @@ def upsert_view_gpkg(
         db_cursor.execute("SELECT SUM(file_size_kb) FROM view")
         total_db_size = db_cursor.fetchone()[0]
 
-        print(total_db_size)
-        print(f"max_size: {self.max_size_gb * 1024 * 1024}")
-
         while total_db_size > self.max_size_gb * 1024 * 1024:
             db_cursor.execute("SELECT * FROM view ORDER BY cached_at ASC LIMIT 1")
             oldest = db_cursor.fetchone()
diff --git a/gerrydb/client.py b/gerrydb/client.py
index 1675aa1..4760e7f 100644
--- a/gerrydb/client.py
+++ b/gerrydb/client.py
@@ -10,9 +10,9 @@
 import geopandas as gpd
 import httpx
 import pandas as pd
+from pandas.core.indexes.base import Index as pdIndex
 import tomlkit
-from shapely import Point
-from shapely.geometry.base import BaseGeometry
+from rapidfuzz import process, fuzz
 
 from gerrydb.cache import GerryCache
 from gerrydb.exceptions import ConfigError
@@ -28,6 +28,7 @@
     ViewRepo,
     ViewTemplateRepo,
 )
+from gerrydb.repos.base import normalize_path
 from gerrydb.repos.geography import GeoValType
 from gerrydb.schemas import (
     Column,
@@ -175,6 +176,20 @@ def __init__(
             transport=self._transport,
         )
 
+    # TODO: add a flag to all methods to force the use of the context manager
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if self.client is not None:
+            self.client.close()
+            self.client = None
+        if self._temp_dir is not None:
+            self._temp_dir.cleanup()
+            self._temp_dir = None
+
+        return False
+
     def context(self, notes: str = "") -> "WriteContext":
         """Creates a write context with session-level metadata.
@@ -331,10 +346,206 @@ def view_templates(self) -> ViewTemplateRepo:
             schema=ViewTemplate, base_url="/view-templates", session=self.db, ctx=self
         )
 
+    def __create_geos(
+        self,
+        df: Union[pd.DataFrame, gpd.GeoDataFrame],
+        *,
+        namespace: str,
+        locality: Union[str, Locality],
+        layer: Union[str, GeoLayer],
+        batch_size: int,
+        max_conns: int,
+    ) -> None:
+        """
+        Private method called by the `load_dataframe` method to load geometries
+        into the database.
+
+        Adds the geometries in the 'geometry' column of the dataframe to the database.
+
+        Args:
+            df: The dataframe containing the geometries to be added.
+            namespace: The namespace to which the geometries belong.
+            locality: The locality to which the geometries belong (e.g. 'pennsylvania').
+            layer: The layer to which the geometries belong (e.g. 'vtd').
+            batch_size: The number of rows to import per batch.
+            max_conns: The maximum number of simultaneous connections to the API.
+        """
+        if "geometry" in df.columns:
+            df = df.to_crs("epsg:4269")  # import as lat/long
+            geos = dict(df.geometry)
+        else:
+            geos = {key: None for key in df.index}
+
+        # Augment geographies with internal points if available.
+        if "internal_point" in df.columns:
+            internal_points = dict(df.internal_point)
+            geos = {path: (geo, internal_points[path]) for path, geo in geos.items()}
+
+        try:
+            asyncio.run(_load_geos(self.geo, geos, namespace, batch_size, max_conns))
+        except Exception as e:
+            if str(e) == "Cannot create geographies that already exist.":
+                # TODO: Make this error more specific maybe?
+                raise e
+            raise e
+
+        if locality is not None and layer is not None:
+            self.geo_layers.map_locality(
+                layer=layer,
+                locality=locality,
+                geographies=[f"/{namespace}/{key}" for key in df.index],
+            )
+
+    def __validate_geos(
+        self,
+        df: Union[pd.DataFrame, gpd.GeoDataFrame],
+        locality: Union[str, Locality],
+        layer: Union[str, GeoLayer],
+    ):
+        """
+        A private method called by the `load_dataframe` method to validate that the passed
+        geometry paths exist in the database.
+
+        All of the geometry paths in the dataframe must exist in a single locality and in a
+        single layer. If they do not, this method will raise an error.
+
+        Args:
+            df: The dataframe containing the geometries to be added.
+            locality: The locality to which the geometries belong.
+            layer: The layer to which the geometries belong.
+
+        Raises:
+            ValueError: If the locality or layer is not provided.
+            ValueError: If the paths in the index of the dataframe do not match any of the paths in
+                the database.
+            ValueError: If there are paths missing from the dataframe compared to the paths for
+                the given locality and layer in the database. All geometries must be updated
+                at the same time to avoid unintentional null values.
+        """
+        if locality is None or layer is None:
+            raise ValueError(
+                "Locality and layer must be provided if create_geo is False."
+            )
+
+        locality_path = ""
+        layer_path = ""
+
+        if isinstance(locality, Locality):
+            locality_path = locality.canonical_path
+        else:
+            locality_path = locality
+        if isinstance(layer, GeoLayer):
+            layer_path = layer.path
+        else:
+            layer_path = layer
+
+        known_paths = set(self.db.geo.all_paths(locality_path, layer_path))
+        df_paths = set(df.index)
+
+        if df_paths - known_paths == df_paths:
+            raise ValueError(
+                f"The index of the dataframe does not appear to match any geographies in the namespace "
+                f"which have the following geoid format: '{list(known_paths)[0] if len(known_paths) > 0 else None}'. "
+                f"Please ensure that the index of the dataframe matches the format of the geoid."
+            )
+
+        if df_paths - known_paths != set():
+            raise ValueError(
+                f"Failure in load_dataframe. Tried to import geographies for layer "
+                f"'{layer_path}' and locality '{locality_path}', but the following geographies "
+                f"do not exist in the namespace "
+                f"'{self.db.namespace}': {df_paths - known_paths}"
+            )
+
+        if known_paths - df_paths != set():
+            raise ValueError(
+                f"Failure in load_dataframe. Tried to import geographies for layer "
+                f"'{layer_path}' and locality '{locality_path}', but the passed dataframe "
+                f"does not contain the following geographies: "
+                f"{known_paths - df_paths}. "
+                f"Please provide values for these geographies in the dataframe."
+            )
+
+    def __validate_columns(self, columns):
+        """
+        Private method called by the `load_dataframe` method to validate the columns
+        passed to the method.
+
+        This method makes sure that the columns passed to the method have the permissible
+        data types and that they exist in the database before we attempt to load values for
+        those columns.
+
+        Args:
+            columns: The columns to be loaded.
+
+        Raises:
+            ValueError: If the columns parameter is not a list, a pandas Index, or a dictionary.
+            ValueError: If the columns parameter is a dictionary and contains a value that is
+                not a Column object.
+            ValueError: If some of the columns in `columns` do not exist in the database.
+                This also looks for close matches to the columns in the database
+                and prints them out for the user.
+        """
+        if not (
+            isinstance(columns, list)
+            or isinstance(columns, pdIndex)
+            or isinstance(columns, dict)
+        ):
+            raise ValueError(
+                f"The columns parameter must be a list of paths, a pandas.core.indexes.base.Index, "
+                f"or a dictionary of paths to Column objects. "
+                f"Received type {type(columns)}."
+            )
+
+        if isinstance(columns, list) or isinstance(columns, pdIndex):
+            column_paths = []
+            for col in self.db.columns.all():
+                column_paths.append(col.canonical_path)
+                column_paths.extend(col.aliases)
+
+            column_paths = set(column_paths)
+            cur_columns = set([normalize_path(col) for col in columns])
+
+            for col in cur_columns:
+                if "/" in col:
+                    raise ValueError(
+                        f"Column paths passed to the `load_dataframe` function "
+                        f"cannot contain '/'. Column '{col}' is invalid."
+                    )
+
+        else:
+            for item in columns.values():
+                if not isinstance(item, Column):
+                    raise ValueError(
+                        f"The columns parameter must be a list of paths, a pandas.core.indexes.base.Index, "
+                        f"or a dictionary of paths to Column objects. "
+                        f"Found a dictionary with a value of type {type(item)}."
+                    )
+
+            column_paths = {col.canonical_path for col in self.db.columns.all()}
+            cur_columns = set([v.canonical_path for v in columns.values()])
+
+        missing_cols = cur_columns - column_paths
+
+        if missing_cols != set():
+            for path in missing_cols:
+                best_matches = process.extract(path, column_paths, limit=5)
+                print(
+                    f"Could not find column corresponding to '{path}', the best matches "
+                    f"are: {[match[0] for match in best_matches]}"
+                )
+            raise ValueError(
+                f"Some of the columns in the dataframe do not exist in the database. "
+                f"Please create the missing columns first using the `db.columns.create` method."
+            )
+
     def load_dataframe(
         self,
         df: Union[pd.DataFrame, gpd.GeoDataFrame],
-        columns: dict[str, Column],
+        columns: Union[pdIndex, list[str], dict[str, Column]],
         *,
         create_geo: bool = False,
         namespace: Optional[str] = None,
@@ -343,7 +554,8 @@ def load_dataframe(
         batch_size: int = 5000,
         max_conns: int = 1,
     ) -> None:
-        """Imports a DataFrame to GerryDB.
+        """
+        Imports a DataFrame to GerryDB.
 
         Plain DataFrames do not include rich column metadata, so the columns
         used in the DataFrame must be defined before import.
@@ -366,8 +578,8 @@
             rows in `df`.
 
         Args:
-            db: GerryDB client instance.
-            df: DataFrame to import column values and geographies from.
+            df: DataFrame to import column values and geographies from. The df MUST be indexed
+                by the geoid or the import will fail.
             columns: Mapping between column names in `df` and GerryDB column metadata.
                 Only columns included in the mapping will be imported.
             create_geo: Determines whether to create geographies from the DataFrame.
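For reference, here is a minimal sketch of how the reworked `load_dataframe` entry point is meant to be called from a write context. The namespace, file, locality, layer, and column names below are hypothetical, and a running GerryDB server is assumed:

    import geopandas as gpd
    from gerrydb import GerryDB

    db = GerryDB(namespace="census.2020")  # hypothetical namespace
    # The dataframe MUST be indexed by geoid.
    df = gpd.read_file("pa_vtds.shp").set_index("GEOID20")

    with db.context(notes="Loading PA VTD stats") as ctx:
        # `columns` may now be a plain list of column paths (or a pandas
        # Index) instead of a dict of Column objects; the paths are
        # validated against the database before any values are loaded.
        ctx.load_dataframe(
            df=df,
            columns=["total_pop", "total_vap"],
            create_geo=False,
            locality="pennsylvania",
            layer="vtd",
        )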
@@ -382,41 +594,30 @@
             raise ValueError("No namespace available.")
 
         if create_geo:
-            if "geometry" in df.columns:
-                df = df.to_crs("epsg:4269")  # import as lat/long
-                geos = dict(df.geometry)
-            else:
-                geos = {key: None for key in df.index}
-
-            # Augment geographies with internal points if available.
-            if "internal_point" in df.columns:
-                internal_points = dict(df.internal_point)
-                geos = {
-                    path: (geo, internal_points[path]) for path, geo in geos.items()
-                }
+            self.__create_geos(
+                df=df,
+                namespace=namespace,
+                locality=locality,
+                layer=layer,
+                batch_size=batch_size,
+                max_conns=max_conns,
+            )
 
-            try:
-                asyncio.run(
-                    _load_geos(self.geo, geos, namespace, batch_size, max_conns)
-                )
+        if not create_geo:
+            self.__validate_geos(df=df, locality=locality, layer=layer)
 
-            except Exception as e:
-                if str(e) == "Cannot create geographies that already exist.":
-                    # TODO: Make this error more specific maybe?
-                    raise e
-                raise e
+        self.__validate_columns(columns)
+
+        # TODO: Check to see if grabbing all of the columns and then filtering
+        # is significantly different from a data transfer perspective in the
+        # average case.
+        if not isinstance(columns, dict):
+            columns = {c: self.columns.get(c) for c in df.columns}
 
         asyncio.run(
             _load_column_values(self.columns, df, columns, batch_size, max_conns)
         )
 
-        if create_geo and locality is not None and layer is not None:
-            self.geo_layers.map_locality(
-                layer=layer,
-                locality=locality,
-                geographies=[f"/{namespace}/{key}" for key in df.index],
-            )
-
 
 # based on https://stackoverflow.com/a/61478547
 async def gather_batch(coros, n):
@@ -466,13 +667,20 @@ async def _load_column_values(
     val_batches: list[tuple[Column, dict[str, Any]]] = []
     for col_name, col_meta in columns.items():
+        # This only works because the df is indexed by geography path.
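+        # For example (hypothetical values), a dataframe indexed by GEOID
+        # gives dict(df["total_pop"]) == {"42001": 6204, "42003": 12887, ...},
+        # i.e. a mapping from geography path to column value, which is the
+        # shape async_set_values expects.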
         col_vals = list(dict(df[col_name]).items())
         for idx in range(0, len(df), batch_size):
             val_batches.append((col_meta, dict(col_vals[idx : idx + batch_size])))
 
     async with httpx.AsyncClient(**params) as client:
         tasks = [
-            repo.async_set_values(col, col.namespace, values=batch, client=client)
+            repo.async_set_values(
+                path=col.path,
+                namespace=col.namespace,
+                col=col,
+                values=batch,
+                client=client,
+            )
             for col, batch in val_batches
         ]
         results = await gather_batch(tasks, max_conns)
diff --git a/gerrydb/create.py b/gerrydb/create.py
index bdf24d4..d5e00d2 100644
--- a/gerrydb/create.py
+++ b/gerrydb/create.py
@@ -1,4 +1,5 @@
 """CLI for creating GerryDB resources."""
+
 from typing import Optional
 
 import click
@@ -6,6 +7,7 @@
 from gerrydb import GerryDB
 from gerrydb.exceptions import ResultError
 
+
 @click.group()
 def cli():
     """Creates GerryDB resources."""
@@ -27,6 +29,7 @@ def namespace(path: str, description: str, public: bool):
     else:
         raise e
 
+
 @cli.command()
 @click.argument("path")
 @click.option("--description", required=True)
@@ -37,12 +40,15 @@ def geo_layer(path: str, description: str, namespace: str, source_url: Optional[
     db = GerryDB(namespace=namespace)
     with db.context(notes=f'Creating geographic layer "{path}" from CLI') as ctx:
         try:
-            ctx.geo_layers.create(path=path, description=description, source_url=source_url)
+            ctx.geo_layers.create(
+                path=path, description=description, source_url=source_url
+            )
         except ResultError as e:
             if "Failed to create geographic layer" in e.args[0]:
                 print(f"Failed to create {path} layer, already exists")
             else:
                 raise e
 
+
 if __name__ == "__main__":
     cli()
diff --git a/gerrydb/exceptions.py b/gerrydb/exceptions.py
index 2f7cbbd..501b22d 100644
--- a/gerrydb/exceptions.py
+++ b/gerrydb/exceptions.py
@@ -37,5 +37,13 @@ class CacheInitError(CacheError):
     """Raised when a GerryDB cache cannot be initialized."""
 
 
+class CacheObjectError(CacheError):
+    """Raised when the cache cannot load an object."""
+
+
 class ViewLoadError(GerryDBError):
     """Raised when a view cannot be loaded (e.g. from a GeoPackage)."""
+
+
+class GerryPathError(GerryDBError):
+    """Raised when an invalid path is provided. Generally, this means invalid characters are present."""
diff --git a/gerrydb/repos/__init__.py b/gerrydb/repos/__init__.py
index 64411ac..166649e 100644
--- a/gerrydb/repos/__init__.py
+++ b/gerrydb/repos/__init__.py
@@ -1,4 +1,5 @@
 """GerryDB API object repositories."""
+
 from gerrydb.repos.column import ColumnRepo
 from gerrydb.repos.column_set import ColumnSetRepo
 from gerrydb.repos.geo_layer import GeoLayerRepo
diff --git a/gerrydb/repos/base.py b/gerrydb/repos/base.py
index 7e0828a..4e1f5a7 100644
--- a/gerrydb/repos/base.py
+++ b/gerrydb/repos/base.py
@@ -7,7 +7,13 @@
 import httpx
 import pydantic
 
-from gerrydb.exceptions import OnlineError, RequestError, ResultError, WriteContextError
+from gerrydb.exceptions import (
+    OnlineError,
+    RequestError,
+    ResultError,
+    WriteContextError,
+    GerryPathError,
+)
 from gerrydb.schemas import BaseModel
 
 if TYPE_CHECKING:
@@ -104,6 +110,18 @@ def write_context_wrapper(*args, **kwargs):
     return write_context_wrapper
 
 
+# These characters are most likely to appear in the resource_id part of
+# a path (typically the last segment). Excluding these characters
+# prevents ogr2ogr failures and helps protect against malicious code injection.
+INVALID_PATH_SUBSTRINGS = set(
+    {
+        "..",
+        " ",
+        ";",
+    }
+)
+
+
 def normalize_path(path: str, case_sensitive_uid: bool = False) -> str:
     """Normalizes a path (removes leading, trailing, and duplicate slashes, and
     lowercases the path if `case_sensitive` is `False`).
@@ -111,6 +129,12 @@
     Some paths, such as paths containing GEOIDs, are case-sensitive in the
     last segment. In these cases, `case_sensitive` should be set to `True`.
     """
+    for item in INVALID_PATH_SUBSTRINGS:
+        if item in path:
+            raise GerryPathError(
+                f"Invalid path: '{path}'. Please remove the following substring: '{item}'"
+            )
+
     if case_sensitive_uid:
         path_list = path.strip().split("/")
         return "/".join(
diff --git a/gerrydb/repos/column.py b/gerrydb/repos/column.py
index d9e3117..bc4f6e3 100644
--- a/gerrydb/repos/column.py
+++ b/gerrydb/repos/column.py
@@ -11,6 +11,7 @@
     namespaced,
     online,
     write_context,
+    normalize_path,
 )
 from gerrydb.schemas import (
     Column,
@@ -68,6 +69,7 @@ def create(
         Returns:
             Metadata for the new column.
         """
+        path = normalize_path(path)
         response = self.ctx.client.post(
             f"{self.base_url}/{namespace}",
             json=ColumnCreate(
@@ -106,29 +108,52 @@ def update(
         Returns:
             The updated column.
         """
+        clean_path = normalize_path(f"{self.base_url}/{namespace}/{path}")
         response = self.ctx.client.patch(
-            f"{self.base_url}/{namespace}/{path}",
+            clean_path,
             json=ColumnPatch(aliases=aliases).dict(),
         )
         response.raise_for_status()
         return Column(**response.json())
 
+    @err("Failed to retrieve column names")
+    @online
+    def all(self) -> list[Column]:
+        """Returns metadata for all columns in the current namespace."""
+        response = self.session.client.get(f"/columns/{self.session.namespace}")
+        response.raise_for_status()
+
+        return [Column(**item) for item in response.json()]
+
+    @err("Failed to retrieve column")
+    @online
+    @namespaced
+    def get(self, path: str, namespace: str = None) -> Column:
+        """Returns the column at `path` in the current namespace."""
+        path = normalize_path(path)
+        response = self.session.client.get(f"/columns/{self.session.namespace}/{path}")
+        response.raise_for_status()
+        return Column(**response.json())
+
     @err("Failed to set column values")
     @namespaced
     @write_context
     @online
     def set_values(
         self,
-        path_or_col: Union[Column, str],
+        path: Optional[str] = None,
         namespace: Optional[str] = None,
         *,
+        col: Optional[Column] = None,
         values: dict[Union[str, Geography], Any],
     ) -> None:
         """Sets the values of a column on a collection of geographies.
 
         Args:
-            path_or_col: Short identifier for the column or a `Column` metadata object.
+            path: Short identifier for the column. Only this or `col` should be provided.
+                If both are provided, the path attribute of `col` will be used in place
+                of the passed `path` argument.
+            col: `Column` metadata object. If the `path` is not provided, the column's
+                path will be used.
-            namespace: Namespace of the column (used when `path_or_col` is a raw path).
+            namespace: Namespace of the column (used when `path` is a raw path).
 
             values: A mapping from geography paths or `Geography` metadata objects
 
         Raises:
             RequestError: If the values cannot be set on the server side.
""" - path = path_or_col.path if isinstance(path_or_col, Column) else path_or_col + assert path is None or isinstance(path, str) + assert col is None or isinstance(col, Column) + + if path is None and col is None: + raise ValueError("Either `path` or `col` must be provided.") + + path = col.path if col is not None else path + clean_path = normalize_path(f"{self.base_url}/{namespace}/{path}") response = self.ctx.client.put( - f"{self.base_url}/{namespace}/{path}", + clean_path, json=[ ColumnValue( path=( @@ -163,16 +195,21 @@ def set_values( @online async def async_set_values( self, - path_or_col: Union[Column, str], + path: Optional[str] = None, namespace: Optional[str] = None, *, + col: Optional[Column] = None, values: dict[Union[str, Geography], Any], client: Optional[httpx.AsyncClient] = None, ) -> None: """Asynchronously sets the values of a column on a collection of geographies. Args: - path_or_col: Short identifier for the column or a `Column` metadata object. + path: Short identifier for the column. Only this or `col` should be provided. + If both are provided, the path attribute of `col` will be used in place + of the passed `path` argument. + col: `Column` metadata object. If the `path` is not provided, the column's + path will be used. namespace: Namespace of the column (used when `path_or_col` is a raw path). values: A mapping from geography paths or `Geography` metadata objects @@ -183,7 +220,14 @@ async def async_set_values( Raises: RequestError: If the values cannot be set on the server side. """ - path = path_or_col.path if isinstance(path_or_col, Column) else path_or_col + assert path is None or isinstance(path, str) + assert col is None or isinstance(col, Column) + + if path is None and col is None: + raise ValueError("Either `path` or `col` must be provided.") + + path = col.path if col is not None else path + clean_path = normalize_path(f"{self.base_url}/{namespace}/{path}") ephemeral_client = client is None if ephemeral_client: @@ -191,23 +235,24 @@ async def async_set_values( params["transport"] = httpx.AsyncHTTPTransport(retries=1) client = httpx.AsyncClient(**params) + json = [ + ColumnValue( + path=( + f"/{geo.namespace}/{geo.path}" + if isinstance(geo, Geography) + else geo + ), + value=_coerce(value), + ).dict() + for geo, value in values.items() + ] response = await client.put( - f"{self.base_url}/{namespace}/{path}", - json=[ - ColumnValue( - path=( - f"/{geo.namespace}/{geo.path}" - if isinstance(geo, Geography) - else geo - ), - value=_coerce(value), - ).dict() - for geo, value in values.items() - ], + clean_path, + json=json, ) if response.status_code != 204: - log.debug(f"For {path_or_col} returned {response}") + log.debug(f"For path = {path} and col = {col} returned {response}") response.raise_for_status() diff --git a/gerrydb/repos/column_set.py b/gerrydb/repos/column_set.py index cfc89f5..4e2f305 100644 --- a/gerrydb/repos/column_set.py +++ b/gerrydb/repos/column_set.py @@ -1,4 +1,5 @@ """Repository for column sets.""" + from typing import Optional, Union from gerrydb.exceptions import RequestError diff --git a/gerrydb/repos/geo_layer.py b/gerrydb/repos/geo_layer.py index 0d69a0a..0331ffe 100644 --- a/gerrydb/repos/geo_layer.py +++ b/gerrydb/repos/geo_layer.py @@ -1,4 +1,5 @@ """Repository for geographic layers.""" + from typing import Optional, Union from gerrydb.repos.base import ( @@ -69,9 +70,11 @@ def map_locality( response = self.ctx.client.put( f"{self.base_url}/{layer.namespace}/{layer.path}", params={ - "locality": locality.canonical_path - if 
isinstance(locality, Locality) - else locality + "locality": ( + locality.canonical_path + if isinstance(locality, Locality) + else locality + ) }, json=GeoSetCreate( paths=[ diff --git a/gerrydb/repos/geography.py b/gerrydb/repos/geography.py index fdb7f49..2d2b874 100644 --- a/gerrydb/repos/geography.py +++ b/gerrydb/repos/geography.py @@ -17,6 +17,7 @@ err, online, write_context, + namespaced, ) from gerrydb.schemas import Geography, GeographyCreate, GeoImport @@ -257,3 +258,14 @@ def async_bulk( return AsyncGeoImporter(repo=self, namespace=namespace, max_conns=max_conns) # TODO: get() + @namespaced + @online + @err("Failed to load geographies") + def all_paths(self, fips: str, layer_name: str) -> list[str]: + response = self.session.client.get( + f"/__list_geo/{self.session.namespace}/{fips}/{layer_name}" + ) + response.raise_for_status() + response_json = response.json() + + return response_json diff --git a/gerrydb/repos/locality.py b/gerrydb/repos/locality.py index 5476252..eccba01 100644 --- a/gerrydb/repos/locality.py +++ b/gerrydb/repos/locality.py @@ -1,10 +1,12 @@ """Repository for localities.""" + from dataclasses import dataclass from typing import TYPE_CHECKING, Optional from gerrydb.repos.base import ObjectRepo, err, normalize_path, online, write_context from gerrydb.schemas import Locality, LocalityCreate, LocalityPatch from gerrydb.exceptions import ResultError + if TYPE_CHECKING: from gerrydb.client import GerryDB, WriteContext @@ -35,9 +37,11 @@ def get(self, path: str) -> Optional[Locality]: response.raise_for_status() return Locality(**response.json()) - @err("Failed to create locality") # Decorator for handling HTTP request and Pydantic validation errors - @write_context # Decorator for marking operations that require a write context - @online # Decorator for marking online-only operations + @err( + "Failed to create locality" + ) # Decorator for handling HTTP request and Pydantic validation errors + @write_context # Decorator for marking operations that require a write context + @online # Decorator for marking online-only operations def create( self, canonical_path: str, @@ -104,24 +108,24 @@ def create_bulk( Returns: The new localities. """ - loc_list = [-1]*len(locs) + loc_list = [-1] * len(locs) for i, loc in enumerate(locs): try: - loc_object = self.create(canonical_path=loc.canonical_path, - name=loc.name, - parent_path=loc.parent_path, - default_proj=loc.default_proj, - aliases=loc.aliases,) + loc_object = self.create( + canonical_path=loc.canonical_path, + name=loc.name, + parent_path=loc.parent_path, + default_proj=loc.default_proj, + aliases=loc.aliases, + ) loc_list[i] = loc_object except ResultError as e: if "Failed to create canonical path to new location(s)." 
in e.args[0]: print(f"Failed to create {loc.name}, path already exists") else: raise e - + return loc_list - - # loc_list = [-1]*len(locs) # for i, loc in enumerate(locs): @@ -132,7 +136,7 @@ def create_bulk( # response.raise_for_status() # loc_list[i] = Locality(**response.json()[0]) # return(loc_list[i]) - + # response = self.ctx.client.post( # "/localities/", # json=[loc.dict() for loc in locs], diff --git a/gerrydb/repos/namespace.py b/gerrydb/repos/namespace.py index 31d5a45..9a58ed9 100644 --- a/gerrydb/repos/namespace.py +++ b/gerrydb/repos/namespace.py @@ -1,4 +1,5 @@ """Repository for namespaces.""" + from dataclasses import dataclass from typing import TYPE_CHECKING, Optional @@ -86,7 +87,7 @@ def create( ) response.raise_for_status() - + return Namespace(**response.json()) def __getitem__(self, path: str) -> Optional[Namespace]: diff --git a/gerrydb/repos/plan.py b/gerrydb/repos/plan.py index 3dc5918..c51cf3a 100644 --- a/gerrydb/repos/plan.py +++ b/gerrydb/repos/plan.py @@ -1,4 +1,5 @@ """Repository for districting plans.""" + from typing import Optional, Union from gerrydb.repos.base import ( diff --git a/gerrydb/repos/view.py b/gerrydb/repos/view.py index ba0d8dd..49d79eb 100644 --- a/gerrydb/repos/view.py +++ b/gerrydb/repos/view.py @@ -1,7 +1,7 @@ """Repository for views.""" import json -import re +import io import sqlite3 from datetime import datetime from pathlib import Path @@ -73,6 +73,9 @@ def _load_gpkg_geometry(geom: bytes) -> BaseGeometry: """Loads a geometry from a raw GeoPackage WKB blob.""" # header format: https://www.geopackage.org/spec/#gpb_format + if geom == None: + raise ValueError("Invalid GeoPackage geometry: empty geometry.") + envelope_flag = (geom[3] & 0b00001110) >> 1 try: envelope_bytes = _GPKG_ENVELOPE_BYTES[envelope_flag] @@ -116,7 +119,14 @@ def __init__(self, meta: ViewMeta, gpkg_path: Path, conn: sqlite3.Connection): @classmethod def from_gpkg(cls, path: Path) -> "View": """Loads a view from a GeoPackage.""" - conn = sqlite3.connect(path) + if isinstance(path, io.BytesIO): + path.seek(0) + conn = sqlite3.connect( + "file:cached_view?mode=memory&cache=shared", uri=True + ) + conn.executescript(path.read().decode("utf-8")) + else: + conn = sqlite3.connect(path) tables = conn.execute( "SELECT name FROM sqlite_master WHERE " @@ -283,25 +293,46 @@ def geographies(self) -> Generator[Geography, None, None]: raw_geo_meta = self._conn.execute( "SELECT meta_id, value FROM gerrydb_geo_meta" ).fetchone() - geo_meta = {row[0]: ObjectMeta(**json.loads(row[1])) for row in raw_geo_meta} + geo_meta = {raw_geo_meta[0]: ObjectMeta(**json.loads(raw_geo_meta[1]))} raw_geos = self._conn.execute( f"""SELECT {self.path}.path, geography, internal_point, meta_id, valid_from FROM {self.path} JOIN {self.path}__internal_points ON {self.path}.path = {self.path}__internal_points.path - JOIN gerrydb_geo_meta + JOIN gerrydb_geo_attrs ON {self.path}.path = gerrydb_geo_attrs.path """ ) for geo_row in raw_geos: - yield Geography( - path=geo_row[0], - geography=_load_gpkg_geometry(geo_row[1]), - internal_point=_load_gpkg_geometry(geo_row[2]), - meta=geo_meta[geo_row[3]], - valid_from=geo_row[4], - ) + if geo_row[2] is not None: + yield Geography( + path=geo_row[0], + geography=_load_gpkg_geometry(geo_row[1]), + internal_point=_load_gpkg_geometry(geo_row[2]), + meta=geo_meta[geo_row[3]], + namespace=self.namespace, + valid_from=geo_row[4], + ) + else: + yield Geography( + path=geo_row[0], + geography=_load_gpkg_geometry(geo_row[1]), + meta=geo_meta[geo_row[3]], + 
+                    namespace=self.namespace,
+                    valid_from=geo_row[4],
+                )
+
+    @property
+    def values(self) -> list[str]:
+        """Returns the namespaced paths of the view's value columns."""
+        raw_paths = self.to_df().columns
+
+        ret = []
+        for item in raw_paths:
+            if item not in ["geometry"]:
+                ret.append(f"/{self.namespace}/{item}")
+        return ret
 
 
 class ViewRepo(NamespacedObjectRepo[ViewMeta]):
diff --git a/gerrydb/schemas.py b/gerrydb/schemas.py
index e463264..8467f5e 100644
--- a/gerrydb/schemas.py
+++ b/gerrydb/schemas.py
@@ -2,6 +2,7 @@
 
 This file should be kept in sync with the server-side version.
 """
+
 from datetime import datetime
 from enum import Enum
 from typing import Any, Optional, Union
@@ -16,7 +17,9 @@
 UserEmail = constr(max_length=254)
 
 # constr is a constrained string, so this is some path that needs to satisfy this regex
-GerryPath = constr(regex=r"[a-z0-9][a-z0-9-_/]*")  # must start with lowercase or digit, then followed by any lowercase, digit, hyphen, underscore, slash
+GerryPath = constr(
    regex=r"[a-z0-9][a-z0-9-_/]*"
+)  # must start with a lowercase letter or digit, followed by any lowercase letters, digits, hyphens, underscores, or slashes
 NamespacedGerryPath = constr(regex=r"[a-z0-9/][a-z0-9-_/]*")
 
 NATIVE_PROJ = pyproj.CRS("EPSG:4269")
diff --git a/pyproject.toml b/pyproject.toml
index 6ba5e5d..b0e6ec4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,21 +6,26 @@
 authors = ["Parker J. Rule "]
 
 [tool.poetry.dependencies]
 python = "^3.9"
-tomlkit = "^0.11.6"
-msgpack = "^1.0.4"
-httpx = "^0.23.3"
+tomlkit = "^0.13.0"
+msgpack = "^1.1.0"
+httpx = "^0.27.2"
 pydantic = "^1.10.4"
-orjson = "^3.8.6"
+orjson = "^3.10.0"
 shapely = "^2.0.1"
 python-dateutil = "^2.8.2"
-geopandas = "^0.12.2"
+geopandas = "^1.0.1"
 networkx = "^3.0"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^7.2.1"
-black = "^23.1.0"
+black = "^24.8.0"
 pytest-vcr = "^1.0.2"
 
+[tool.pytest.ini_options]
+markers = [
+    "vcr: mark a test as a vcr test",
+]
+
 [tool.isort]
 profile = "black"
diff --git a/tests/repos/conftest.py b/tests/repos/conftest.py
index 3655ac4..5885372 100644
--- a/tests/repos/conftest.py
+++ b/tests/repos/conftest.py
@@ -1,4 +1,5 @@
 """Fixtures for repository tests."""
+
 import pytest
diff --git a/tests/repos/test_base.py b/tests/repos/test_base.py
index cf21178..b8fd052 100644
--- a/tests/repos/test_base.py
+++ b/tests/repos/test_base.py
@@ -1,4 +1,5 @@
 """Tests for base objects and utilities for GerryDB API object repositories."""
+
 from dataclasses import dataclass
 from typing import Optional
diff --git a/tests/repos/test_column.py b/tests/repos/test_column.py
index 0734f85..bafa9a9 100644
--- a/tests/repos/test_column.py
+++ b/tests/repos/test_column.py
@@ -1,4 +1,5 @@
 """Integration/VCR tests for columns."""
+
 import pytest
 
 from shapely import box
 
@@ -46,4 +47,6 @@ def test_column_repo_set_values(client_ns, column):
         col = ctx.columns.create(**column)
         with ctx.geo.bulk() as geo_ctx:
             geo_ctx.create({str(idx): box(0, 0, 1, 1) for idx in range(n)})
-        ctx.columns.set_values(col, values={str(idx): idx for idx in range(n)})
+        ctx.columns.set_values(
+            path=col.path, values={str(idx): idx for idx in range(n)}
+        )
diff --git a/tests/repos/test_column_set.py b/tests/repos/test_column_set.py
index 32de339..4320cd5 100644
--- a/tests/repos/test_column_set.py
+++ b/tests/repos/test_column_set.py
@@ -1,4 +1,5 @@
 """Integration/VCR tests for columns."""
+
 import pytest
diff --git a/tests/repos/test_geo_import.py b/tests/repos/test_geo_import.py
index 7be5247..0d0bae0 100644
--- a/tests/repos/test_geo_import.py
+++ b/tests/repos/test_geo_import.py
@@ -1,4 +1,5 @@
 """Integration/VCR tests for geographic import metadata."""
+
 import pytest
 
 # The `GeoImport` object is used for internal tracking, so we don't
diff --git a/tests/repos/test_geo_layer.py b/tests/repos/test_geo_layer.py
index 76a4bff..133aa7f 100644
--- a/tests/repos/test_geo_layer.py
+++ b/tests/repos/test_geo_layer.py
@@ -1,4 +1,5 @@
 """Integration/VCR tests for geographic layers."""
+
 import pytest
diff --git a/tests/repos/test_geography.py b/tests/repos/test_geography.py
index c3df225..ab214bd 100644
--- a/tests/repos/test_geography.py
+++ b/tests/repos/test_geography.py
@@ -1,4 +1,5 @@
 """Integration/VCR tests for columns."""
+
 from shapely import box
diff --git a/tests/repos/test_locality.py b/tests/repos/test_locality.py
index 0798f9d..0916f07 100644
--- a/tests/repos/test_locality.py
+++ b/tests/repos/test_locality.py
@@ -1,4 +1,5 @@
 """Integration/VCR tests for localities."""
+
 import pytest
 
 from gerrydb.schemas import LocalityCreate
diff --git a/tests/repos/test_namespace.py b/tests/repos/test_namespace.py
index 52f90c9..58aea42 100644
--- a/tests/repos/test_namespace.py
+++ b/tests/repos/test_namespace.py
@@ -1,4 +1,5 @@
 """Integration/VCR tests for namespaces."""
+
 import pytest
diff --git a/tests/repos/test_plan.py b/tests/repos/test_plan.py
index fe4b8d8..da64a4f 100644
--- a/tests/repos/test_plan.py
+++ b/tests/repos/test_plan.py
@@ -1,4 +1,5 @@
 """Integration/VCR tests for districting plans."""
+
 import pytest
 
 from gerrydb.exceptions import ResultError
diff --git a/tests/repos/test_view.py b/tests/repos/test_view.py
index b6122b9..33edec9 100644
--- a/tests/repos/test_view.py
+++ b/tests/repos/test_view.py
@@ -1,4 +1,5 @@
 """Tests for views."""
+
 import pytest
 
 
@@ -18,9 +19,6 @@ def test_view_repo_create__valid(client_with_ia_layer_loc, ia_dataframe):
     assert set(geo.path for geo in view.geographies) == set(ia_dataframe.index)
     assert set(col.full_path for col in columns.values()) == set(view.values)
-    assert all(
-        len(col_values) == len(view.geographies) for col_values in view.values.values()
-    )
     assert view.graph is None
 
 
@@ -85,7 +83,11 @@ def test_view_repo_view_to_graph(ia_view_with_graph, ia_graph):
     expected_cols = set(
         "/".join(col.split("/")[2:]) for col in ia_view_with_graph.values
     )
-    assert all(set(data) == expected_cols for _, data in view_graph.nodes(data=True))
+    # Previous tests in the test suite can add some values to the graph nodes,
+    # so we just check that the expected columns are present.
+    assert all(
+        expected_cols - set(data) == set() for _, data in view_graph.nodes(data=True)
+    )
 
 
 @pytest.mark.vcr
@@ -97,5 +99,10 @@ def test_view_repo_view_to_graph_geo(ia_view_with_graph, ia_graph):
     expected_cols = set(
         "/".join(col.split("/")[2:]) for col in ia_view_with_graph.values
-    ) | {"area", "geometry"}
-    assert all(set(data) == expected_cols for _, data in view_graph.nodes(data=True))
+    ) | {"internal_point", "geometry"}
+
+    # Previous tests in the test suite can add some values to the graph nodes,
+    # so we just check that the expected columns are present.
+    assert all(
+        expected_cols - set(data) == set() for _, data in view_graph.nodes(data=True)
+    )
diff --git a/tests/repos/test_view_template.py b/tests/repos/test_view_template.py
index 6d4347f..4f5da07 100644
--- a/tests/repos/test_view_template.py
+++ b/tests/repos/test_view_template.py
@@ -1,4 +1,5 @@
 """Integration/VCR tests for view templates."""
+
 import pytest
 
 
@@ -12,5 +13,5 @@ def test_view_template_repo_create_get__online_columns_only(
     view_template = ctx.view_templates.create(
         path="pops", members=[pop_col, vap_col], description="Population view."
     )
-    print(view_template)
+    # print(view_template)
     # TODO: more evaluation here.
diff --git a/tests/test_cache.py b/tests/test_cache.py
index c773e12..3f6a651 100644
--- a/tests/test_cache.py
+++ b/tests/test_cache.py
@@ -1,35 +1,38 @@
 """Tests for GerryDB's local caching layer."""
-import uuid
-from datetime import datetime, timedelta, timezone
 
 import pytest
 
-from gerrydb.cache import CacheInitError, CacheObjectError, CachePolicyError, GerryCache
-from gerrydb.schemas import BaseModel, ObjectCachePolicy, ObjectMeta
+from gerrydb.cache import CacheInitError, GerryCache
+from tempfile import TemporaryDirectory
+from pathlib import Path
 
 
 @pytest.fixture
 def cache():
     """An in-memory instance of `GerryCache`."""
-    return GerryCache(":memory:")
+    cache_dir = TemporaryDirectory()
+    return GerryCache(
+        ":memory:",
+        data_dir=Path(cache_dir.name),
+    )
 
 
 def test_gerry_cache_init__no_schema_version(cache):
     cache._conn.execute("DELETE FROM cache_meta")
     cache._conn.commit()
     with pytest.raises(CacheInitError, match="no schema version"):
-        GerryCache(cache._conn)
+        GerryCache(cache._conn, cache.data_dir)
 
 
 def test_gerry_cache_init__bad_schema_version(cache):
     cache._conn.execute("UPDATE cache_meta SET value='bad' WHERE key='schema_version'")
     cache._conn.commit()
     with pytest.raises(CacheInitError, match="expected schema version"):
-        GerryCache(cache._conn)
+        GerryCache(cache._conn, cache.data_dir)
 
 
 def test_gerry_cache_init__missing_table(cache):
-    cache._conn.execute("DROP TABLE object")
+    cache._conn.execute("DROP TABLE view")
     cache._conn.commit()
-    with pytest.raises(CacheInitError, match="missing tables"):
-        GerryCache(cache._conn)
+    with pytest.raises(CacheInitError, match="missing table"):
+        GerryCache(cache._conn, cache.data_dir)
diff --git a/tests/test_client.py b/tests/test_client.py
index 095dcfd..9a8e266 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -1,4 +1,5 @@
 """Tests for GerryDB session management."""
+
 import os
 from unittest import mock
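As a usage note on the new `__enter__`/`__exit__` support added in `gerrydb/client.py`: the client can now be used as a context manager so that its HTTP connection and temporary cache directory are cleaned up deterministically. A minimal sketch under those assumptions (the namespace here is hypothetical):

    from gerrydb import GerryDB

    # On exit, __exit__ closes the HTTP client and removes the temporary
    # cache directory created by the session.
    with GerryDB(namespace="census.2020") as db:
        # ColumnRepo.all() returns Column metadata objects for every column
        # in the namespace; load_dataframe uses it (together with rapidfuzz)
        # to validate column paths and suggest close matches on failure.
        for col in db.columns.all():
            print(col.canonical_path, col.aliases)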