Commit
Update load_dataframe to do better validation and give GerryDB a context to properly close sockets

peterrrock2 committed Sep 20, 2024
1 parent 032c43b commit e2c74bd
Showing 2 changed files with 251 additions and 77 deletions.
312 changes: 235 additions & 77 deletions gerrydb/client.py
@@ -10,7 +10,9 @@
import geopandas as gpd
import httpx
import pandas as pd
from pandas.core.indexes.base import Index as pdIndex
import tomlkit
from rapidfuzz import process, fuzz

from gerrydb.cache import GerryCache
from gerrydb.exceptions import ConfigError
@@ -26,6 +28,7 @@
ViewRepo,
ViewTemplateRepo,
)
from gerrydb.repos.base import normalize_path
from gerrydb.repos.geography import GeoValType
from gerrydb.schemas import (
Column,
@@ -173,6 +176,20 @@ def __init__(
transport=self._transport,
)

# TODO: add a flag to all methods to force the use of the context manager
def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
if self.client is not None:
self.client.close()
self.client = None
if self._temp_dir is not None:
self._temp_dir.cleanup()
self._temp_dir = None

return False
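    # Usage sketch (the constructor arguments here are hypothetical; check
    # your local configuration):
    #
    #     with GerryDB(namespace="census.2020") as db:
    #         ...  # make API calls
    #
    # On exit, the underlying httpx client is closed (releasing its sockets)
    # and the temporary cache directory is cleaned up.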

def context(self, notes: str = "") -> "WriteContext":
"""Creates a write context with session-level metadata.
@@ -329,10 +346,206 @@ def view_templates(self) -> ViewTemplateRepo:
schema=ViewTemplate, base_url="/view-templates", session=self.db, ctx=self
)

def __create_geos(
self,
df: Union[pd.DataFrame, gpd.GeoDataFrame],
*,
namespace: str,
locality: Union[str, Locality],
layer: Union[str, GeoLayer],
batch_size: int,
max_conns: int,
) -> None:
"""
Private method called by the `load_dataframe` method to load geometries
into the database.
Adds the geometries in the 'geometry' column of the dataframe to the database.
Args:
df: The dataframe containing the geometries to be added.
namespace: The namespace to which the geometries belong.
locality: The locality to which the geometries belong. (e.g. 'pennsylvania')
layer: The layer to which the geometries belong. (e.g. 'vtd')
batch_size: The number of rows to import per batch.
max_conns: The maximum number of simultaneous connections to the API.
"""
if "geometry" in df.columns:
df = df.to_crs("epsg:4269") # import as lat/long
geos = dict(df.geometry)
else:
geos = {key: None for key in df.index}

# Augment geographies with internal points if available.
if "internal_point" in df.columns:
internal_points = dict(df.internal_point)
geos = {path: (geo, internal_points[path]) for path, geo in geos.items()}
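            # At this point `geos` maps geography path -> geometry (or None when
            # the dataframe has no 'geometry' column), or path -> (geometry,
            # internal_point) when internal points are available.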

try:
asyncio.run(_load_geos(self.geo, geos, namespace, batch_size, max_conns))

except Exception as e:
if str(e) == "Cannot create geographies that already exist.":
# TODO: Make this error more specific maybe?
raise e
raise e

if locality is not None and layer is not None:
self.geo_layers.map_locality(
layer=layer,
locality=locality,
geographies=[f"/{namespace}/{key}" for key in df.index],
)

def __validate_geos(
self,
df: Union[pd.DataFrame, gpd.GeoDataFrame],
locality: Union[str, Locality],
layer: Union[str, GeoLayer],
):
"""
A private method called by the `load_dataframe` method to validate that the passed
geometry paths exist in the database.
All of the geometry paths in the dataframe must exist in a single locality and in a
single layer. If they do not, this method will raise an error.
Args:
df: The dataframe containing the geometries to be added.
locality: The locality to which the geometries belong.
layer: The layer to which the geometries belong.
Raises:
ValueError: If the locality or layer is not provided.
ValueError: If the paths in the index of the dataframe do not match any of the paths in
the database.
ValueError: If there are paths missing from the dataframe compared to the paths for
the given locality and layer in the database. All geometries must be updated
at the same time to avoid unintentional null values.
"""
if locality is None or layer is None:
raise ValueError(
"Locality and layer must be provided if create_geo is False."
)

locality_path = ""
layer_path = ""

if isinstance(locality, Locality):
locality_path = locality.canonical_path
else:
locality_path = locality
if isinstance(layer, GeoLayer):
layer_path = layer.path
else:
layer_path = layer

known_paths = set(self.db.geo.all_paths(locality_path, layer_path))
df_paths = set(df.index)

if df_paths - known_paths == df_paths:
raise ValueError(
f"The index of the dataframe does not appear to match any geographies in the namespace "
f"which have the following geoid format: '{list(known_paths)[0] if len(known_paths) > 0 else None}'. "
f"Please ensure that the index of the dataframe matches the format of the geoid."
)

if df_paths - known_paths != set():
raise ValueError(
f"Failure in load_dataframe. Tried to import geographies for layer "
f"'{layer_path}' and locality '{locality_path}', but the following geographies "
f"do not exist in the namespace "
f"'{self.db.namespace}': {df_paths - known_paths}"
)

if known_paths - df_paths != set():
raise ValueError(
f"Failure in load_dataframe. Tried to import geographies for layer "
f"'{layer_path}' and locality '{locality_path}', but the passed dataframe "
f"does not contain the following geographies: "
f"{known_paths - df_paths}. "
f"Please provide values for these geographies in the dataframe."
)
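    # The set checks above, in brief, with df_paths = set(df.index) and
    # known_paths = the paths already stored for the locality/layer:
    #     df_paths - known_paths == df_paths  -> no overlap at all; the index is
    #                                            likely in the wrong geoid format
    #     df_paths - known_paths != set()     -> some paths are unknown to the namespace
    #     known_paths - df_paths != set()     -> the dataframe is missing known paths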

def __validate_columns(self, columns):
"""
Private method called by the `load_dataframe` method to validate the columns
passed to the method.
This method makes sure that the columns passed to the method have the permissible
data types and that they exist in the database before we attempt to load values for
those columns.
Args:
columns: The columns to be loaded.
Raises:
ValueError: If the columns parameter is not a list, a pandas Index, or a dictionary.
ValueError: If the columns parameter is a dictionary and contains a value that is
not a Column object.
ValueError: If some of the columns in `columns` do not exist in the database.
This also looks for close matches to the columns in the database
and prints them out for the user.
"""

if not (
isinstance(columns, list)
or isinstance(columns, pdIndex)
or isinstance(columns, dict)
):
raise ValueError(
f"The columns parameter must be a list of paths, a pandas.core.indexes.base.Index, "
f"or a dictionary of paths to Column objects. "
f"Received type {type(columns)}."
)

if isinstance(columns, list) or isinstance(columns, pdIndex):
column_paths = []
for col in self.db.columns.all():
column_paths.append(col.canonical_path)
column_paths.extend(col.aliases)

column_paths = set(column_paths)
cur_columns = set([normalize_path(col) for col in columns])

for col in cur_columns:
if "/" in col:
raise ValueError(
f"Column paths passed to the `load_dataframe` function "
f"cannot contain '/'. Column '{col}' is invalid."
)

else:
for item in columns.values():
if not isinstance(item, Column):
raise ValueError(
f"The columns parameter must be a list of paths, a pandas.core.indexes.base.Index, "
f"or a dictionary of paths to Column objects. "
f"Found a dictionary with a value of type {type(item)}."
)

column_paths = {col.canonical_path for col in self.db.columns.all()}
cur_columns = set([v.canonical_path for v in columns.values()])

missing_cols = cur_columns - column_paths

if missing_cols != set():
for path in missing_cols:
best_matches = process.extract(path, column_paths, limit=5)
print(
f"Could not find column corresponding to '{path}', the best matches "
f"are: {[match[0] for match in best_matches]}"
)
raise ValueError(
f"Some of the columns in the dataframe do not exist in the database. "
f"Please create the missing columns first using the `db.columns.create` method."
)
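    # Note: `process.extract` above is rapidfuzz's fuzzy matcher and returns
    # (candidate, score, key) tuples, so a near-miss such as a hypothetical
    # 'total_pip' would surface suggestions like 'total_pop' before the
    # ValueError is raised.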

def load_dataframe(
self,
df: Union[pd.DataFrame, gpd.GeoDataFrame],
columns: Union[pdIndex, list[str], dict[str, Column]],
*,
create_geo: bool = False,
namespace: Optional[str] = None,
@@ -341,7 +554,8 @@ def load_dataframe(
batch_size: int = 5000,
max_conns: int = 1,
) -> None:
"""Imports a DataFrame to GerryDB.
"""
Imports a DataFrame to GerryDB.
Plain DataFrames do not include rich column metadata, so the columns used
in the DataFrame must be defined before import.
@@ -364,10 +578,8 @@
rows in `df`.
Args:
df: DataFrame to import column values and geographies from. The df MUST be indexed
    by the geoid or the import will fail.
columns: The columns to import: a list (or pandas Index) of column paths, or a
    mapping from column names in `df` to GerryDB column metadata. Only these
    columns will be imported.
create_geo: Determines whether to create geographies from the DataFrame.
@@ -382,81 +594,27 @@
raise ValueError("No namespace available.")

        if create_geo:
            self.__create_geos(
                df=df,
                namespace=namespace,
                locality=locality,
                layer=layer,
                batch_size=batch_size,
                max_conns=max_conns,
            )

        if not create_geo:
            self.__validate_geos(df=df, locality=locality, layer=layer)

        self.__validate_columns(columns)

        # TODO: Check to see if grabbing all of the columns and then filtering
        # is significantly different from a data transfer perspective in the
        # average case.
        if not isinstance(columns, dict):
            columns = {c: self.columns.get(c) for c in df.columns}

asyncio.run(
_load_column_values(self.columns, df, columns, batch_size, max_conns)
)
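For reference, a minimal usage sketch of the refactored flow. The namespace, file, locality, layer, and column names below are hypothetical, and the constructor and context-manager signatures are assumptions based on this diff rather than guaranteed API:

    import geopandas as gpd

    from gerrydb import GerryDB

    with GerryDB(namespace="census.2020") as db:
        # The dataframe index MUST be the geoid for the import to succeed.
        gdf = gpd.read_file("pa_vtds.shp").set_index("GEOID20")

        with db.context(notes="Load PA VTD demographics") as ctx:
            ctx.load_dataframe(
                gdf,
                columns=["total_pop", "total_vap"],  # validated before any values load
                create_geo=True,  # create the geographies; set False to validate instead
                locality="pennsylvania",
                layer="vtd",
            )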
16 changes: 16 additions & 0 deletions gerrydb/repos/column.py
@@ -114,6 +114,22 @@ def update(

return Column(**response.json())

@err("Failed to retrieve column names")
@online
    def all(self) -> list[Column]:
        """Fetches metadata for all columns in the namespace."""
        response = self.session.client.get(f"/columns/{self.session.namespace}")
response.raise_for_status()

return [Column(**item) for item in response.json()]

@err("Failed to retrieve column")
@online
    def get(self, path: str) -> Column:
        """Fetches metadata for a single column by path."""
        path = normalize_path(path)
response = self.session.client.get(f"/columns/{self.session.namespace}/{path}")
response.raise_for_status()
return Column(**response.json())

@err("Failed to set column values")
@namespaced
@write_context
