Update load_dataframe #1

Merged (10 commits) on Oct 10, 2024
4 changes: 4 additions & 0 deletions .gitignore
@@ -160,3 +160,7 @@ cython_debug/
#.idea/

.DS_Store


# Ignore all of the backup files that might be floating around
*.tar
10 changes: 10 additions & 0 deletions TODO.md
@@ -0,0 +1,10 @@
# For Docs

[ ] Show how to make a column with the write context
[ ] Show how this all works with the load_dataframe method


# Ideas
[ ] Maybe cache all of the columns in the database on the local machine so
validation can be done locally? Periodic checks can be done to ensure that
the local cache is up to date.
3 changes: 0 additions & 3 deletions gerrydb/cache.py
@@ -97,9 +97,6 @@ def upsert_view_gpkg(
db_cursor.execute("SELECT SUM(file_size_kb) FROM view")
total_db_size = db_cursor.fetchone()[0]

print(total_db_size)
print(f"max_size: {self.max_size_gb * 1024 * 1024}")

while total_db_size > self.max_size_gb * 1024 * 1024:
db_cursor.execute("SELECT * FROM view ORDER BY cached_at ASC LIMIT 1")
oldest = db_cursor.fetchone()
278 changes: 243 additions & 35 deletions gerrydb/client.py
@@ -10,9 +10,9 @@
import geopandas as gpd
import httpx
import pandas as pd
from pandas.core.indexes.base import Index as pdIndex
import tomlkit
from shapely import Point
from shapely.geometry.base import BaseGeometry
from rapidfuzz import process, fuzz

from gerrydb.cache import GerryCache
from gerrydb.exceptions import ConfigError
@@ -28,6 +28,7 @@
ViewRepo,
ViewTemplateRepo,
)
from gerrydb.repos.base import normalize_path
from gerrydb.repos.geography import GeoValType
from gerrydb.schemas import (
Column,
@@ -175,6 +176,20 @@ def __init__(
transport=self._transport,
)

# TODO: add a flag to all methods to force the use of the context manager
def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
if self.client is not None:
self.client.close()
self.client = None
if self._temp_dir is not None:
self._temp_dir.cleanup()
self._temp_dir = None

return False
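
For readers following the diff: the new `__enter__`/`__exit__` pair makes the client usable in a `with` block, closing the HTTP client and cleaning up the temporary cache directory on exit. A minimal sketch of that usage follows; the class name `GerryDB` and the bare constructor call are assumptions rather than values confirmed by this diff.

from gerrydb import GerryDB  # assumed import path and class name

# Minimal sketch of the new context-manager support.
with GerryDB() as db:  # __enter__ returns the client itself
    ...  # issue reads/writes through `db` here
# On exit, __exit__ closes db.client and cleans up db._temp_dir, then returns False
# so any exception raised inside the block still propagates.
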

def context(self, notes: str = "") -> "WriteContext":
"""Creates a write context with session-level metadata.

@@ -331,10 +346,206 @@ def view_templates(self) -> ViewTemplateRepo:
schema=ViewTemplate, base_url="/view-templates", session=self.db, ctx=self
)

def __create_geos(
self,
df: Union[pd.DataFrame, gpd.GeoDataFrame],
*,
namespace: str,
locality: Union[str, Locality],
layer: Union[str, GeoLayer],
batch_size: int,
max_conns: int,
) -> None:
"""
Private method called by the `load_dataframe` method to load geometries
into the database.

Adds the geometries in the 'geometry' column of the dataframe to the database.

Args:
df: The dataframe containing the geometries to be added.
namespace: The namespace to which the geometries belong.
locality: The locality to which the geometries belong. (e.g. 'pennsylvania')
layer: The layer to which the geometries belong. (e.g. 'vtd')
batch_size: The number of rows to import per batch.
max_conns: The maximum number of simultaneous connections to the API.

"""
if "geometry" in df.columns:
df = df.to_crs("epsg:4269") # import as lat/long
geos = dict(df.geometry)
else:
geos = {key: None for key in df.index}

# Augment geographies with internal points if available.
if "internal_point" in df.columns:
internal_points = dict(df.internal_point)
geos = {path: (geo, internal_points[path]) for path, geo in geos.items()}

try:
asyncio.run(_load_geos(self.geo, geos, namespace, batch_size, max_conns))

except Exception as e:
if str(e) == "Cannot create geographies that already exist.":
# TODO: Make this error more specific maybe?
raise e
raise e

if locality is not None and layer is not None:
self.geo_layers.map_locality(
layer=layer,
locality=locality,
geographies=[f"/{namespace}/{key}" for key in df.index],
)
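
A sketch of the frame shape `__create_geos` expects may help here: the index carries the geography paths, `geometry` holds the shapes, and `internal_point` is optional. Everything below is illustrative; the geoid and coordinates are made up.

import geopandas as gpd
from shapely.geometry import Point, Polygon

# Illustrative frame for __create_geos: indexed by geography path (geoid),
# with a 'geometry' column and an optional 'internal_point' column.
gdf = gpd.GeoDataFrame(
    {
        "geometry": [Polygon([(0, 0), (1, 0), (1, 1)])],
        "internal_point": [Point(0.5, 0.25)],
    },
    index=["42001000101"],  # hypothetical VTD geoid
    crs="epsg:4326",
)
# When a geometry column is present, the method reprojects to EPSG:4269 before upload.
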

def __validate_geos(
self,
df: Union[pd.DataFrame, gpd.GeoDataFrame],
locality: Union[str, Locality],
layer: Union[str, GeoLayer],
):
"""
A private method called by the `load_dataframe` method to validate that the passed
geometry paths exist in the database.

All of the geometry paths in the dataframe must exist in a single locality and in a
single layer. If they do not, this method will raise an error.

Args:
df: The dataframe containing the geometries to be added.
locality: The locality to which the geometries belong.
layer: The layer to which the geometries belong.

Raises:
ValueError: If the locality or layer is not provided.
ValueError: If the paths in the index of the dataframe do not match any of the paths in
the database.
ValueError: If there are paths missing from the dataframe compared to the paths for
the given locality and layer in the database. All geometries must be updated
at the same time to avoid unintentional null values.
"""
if locality is None or layer is None:
raise ValueError(
"Locality and layer must be provided if create_geo is False."
)

locality_path = ""
layer_path = ""

if isinstance(locality, Locality):
locality_path = locality.canonical_path
else:
locality_path = locality
if isinstance(layer, GeoLayer):
layer_path = layer.path
else:
layer_path = layer

known_paths = set(self.db.geo.all_paths(locality_path, layer_path))
df_paths = set(df.index)

if df_paths - known_paths == df_paths:
raise ValueError(
f"The index of the dataframe does not appear to match any geographies in the namespace "
f"which have the following geoid format: '{list(known_paths)[0] if len(known_paths) > 0 else None}'. "
f"Please ensure that the index of the dataframe matches the format of the geoid."
)

if df_paths - known_paths != set():
raise ValueError(
f"Failure in load_dataframe. Tried to import geographies for layer "
f"'{layer_path}' and locality '{locality_path}', but the following geographies "
f"do not exist in the namespace "
f"'{self.db.namespace}': {df_paths - known_paths}"
)

if known_paths - df_paths != set():
raise ValueError(
f"Failure in load_dataframe. Tried to import geographies for layer "
f"'{layer_path}' and locality '{locality_path}', but the passed dataframe "
f"does not contain the following geographies: "
f"{known_paths - df_paths}. "
f"Please provide values for these geographies in the dataframe."
)
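
In other words, the dataframe index must match the layer's known paths exactly in both directions. A small illustration of the two set checks above, with made-up paths:

known_paths = {"42001000101", "42001000102", "42001000103"}
df_paths = {"42001000101", "42001000102"}

print(df_paths - known_paths)  # set(): no unknown geographies, so the first two checks pass
print(known_paths - df_paths)  # {'42001000103'}: a missing geography, so the last check raises
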

def __validate_columns(self, columns):
"""
Private method called by the `load_dataframe` method to validate the columns
passed to the method.

This method makes sure that the columns passed to the method have the permissible
data types and that they exist in the database before we attempt to load values for
those columns.

Args:
columns: The columns to be loaded.

Raises:
ValueError: If the columns parameter is not a list, a pandas Index, or a dictionary.
ValueError: If the columns parameter is a dictionary and contains a value that is
not a Column object.
ValueError: If some of the columns in `columns` do not exist in the database.
This also looks for close matches to the columns in the database
and prints them out for the user.
"""

if not (
isinstance(columns, list)
or isinstance(columns, pdIndex)
or isinstance(columns, dict)
):
raise ValueError(
f"The columns parameter must be a list of paths, a pandas.core.indexes.base.Index, "
f"or a dictionary of paths to Column objects. "
f"Received type {type(columns)}."
)

if isinstance(columns, list) or isinstance(columns, pdIndex):
column_paths = []
for col in self.db.columns.all():
column_paths.append(col.canonical_path)
column_paths.extend(col.aliases)

column_paths = set(column_paths)
cur_columns = set([normalize_path(col) for col in columns])

for col in cur_columns:
if "/" in col:
raise ValueError(
f"Column paths passed to the `load_dataframe` function "
f"cannot contain '/'. Column '{col}' is invalid."
)

else:
for item in columns.values():
if not isinstance(item, Column):
raise ValueError(
f"The columns parameter must be a list of paths, a pandas.core.indexes.base.Index, "
f"or a dictionary of paths to Column objects. "
f"Found a dictionary with a value of type {type(item)}."
)

column_paths = {col.canonical_path for col in self.db.columns.all()}
cur_columns = set([v.canonical_path for v in columns.values()])

missing_cols = cur_columns - column_paths

if missing_cols != set():
for path in missing_cols:
best_matches = process.extract(path, column_paths, limit=5)
print(
f"Could not find column corresponding to '{path}', the best matches "
f"are: {[match[0] for match in best_matches]}"
)
raise ValueError(
f"Some of the columns in the dataframe do not exist in the database. "
f"Please create the missing columns first using the `db.columns.create` method."
)

def load_dataframe(
self,
df: Union[pd.DataFrame, gpd.GeoDataFrame],
columns: dict[str, Column],
columns: Union[pdIndex, list[str], dict[str, Column]],
*,
create_geo: bool = False,
namespace: Optional[str] = None,
@@ -343,7 +554,8 @@ def load_dataframe(
batch_size: int = 5000,
max_conns: int = 1,
) -> None:
"""Imports a DataFrame to GerryDB.
"""
Imports a DataFrame to GerryDB.

Plain DataFrames do not include rich column metadata, so the columns used
in the DataFrame must be defined before import.
@@ -366,8 +578,8 @@ rows in `df`.
rows in `df`.

Args:
db: GerryDB client instance.
df: DataFrame to import column values and geographies from.
df: DataFrame to import column values and geographies from. The df MUST be indexed
by the geoid or the import will fail.
columns: Mapping between column names in `df` and GerryDB column metadata.
Only columns included in the mapping will be imported.
create_geo: Determines whether to create geographies from the DataFrame.
@@ -382,41 +594,30 @@
raise ValueError("No namespace available.")

if create_geo:
if "geometry" in df.columns:
df = df.to_crs("epsg:4269") # import as lat/long
geos = dict(df.geometry)
else:
geos = {key: None for key in df.index}

# Augment geographies with internal points if available.
if "internal_point" in df.columns:
internal_points = dict(df.internal_point)
geos = {
path: (geo, internal_points[path]) for path, geo in geos.items()
}
self.__create_geos(
df=df,
namespace=namespace,
locality=locality,
layer=layer,
batch_size=batch_size,
max_conns=max_conns,
)

try:
asyncio.run(
_load_geos(self.geo, geos, namespace, batch_size, max_conns)
)
if not create_geo:
self.__validate_geos(df=df, locality=locality, layer=layer)

except Exception as e:
if str(e) == "Cannot create geographies that already exist.":
# TODO: Make this error more specific maybe?
raise e
raise e
self.__validate_columns(columns)

# TODO: Check to see if grabbing all of the columns and then filtering
# is significantly different from a data transfer perspective in the
# average case.
if not isinstance(columns, dict):
columns = {c: self.columns.get(c) for c in df.columns}

asyncio.run(
_load_column_values(self.columns, df, columns, batch_size, max_conns)
)

if create_geo and locality is not None and layer is not None:
self.geo_layers.map_locality(
layer=layer,
locality=locality,
geographies=[f"/{namespace}/{key}" for key in df.index],
)
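
Putting the pieces together, an end-to-end call might look like the sketch below. The class name, constructor arguments, namespace, input file, and column paths are all placeholders, and it assumes the WriteContext returned by `context()` is itself used as a context manager.

import geopandas as gpd
from gerrydb import GerryDB  # assumed import path and class name

gdf = gpd.read_file("pa_vtds.shp").set_index("GEOID20")  # index must be the geoid

db = GerryDB(namespace="census.pa_2020")  # hypothetical constructor arguments
with db.context(notes="initial VTD import") as ctx:
    ctx.load_dataframe(
        gdf,
        columns=["total_pop", "total_vap"],  # list, pandas Index, or dict of Column objects
        create_geo=True,                     # create the geographies as part of the import
        locality="pennsylvania",
        layer="vtd",
    )
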


# based on https://stackoverflow.com/a/61478547
async def gather_batch(coros, n):
@@ -466,13 +667,20 @@ async def _load_column_values(

val_batches: list[tuple[Column, dict[str, Any]]] = []
for col_name, col_meta in columns.items():
# This only works because the df is indexed by geography path.
col_vals = list(dict(df[col_name]).items())
for idx in range(0, len(df), batch_size):
val_batches.append((col_meta, dict(col_vals[idx : idx + batch_size])))

async with httpx.AsyncClient(**params) as client:
tasks = [
repo.async_set_values(col, col.namespace, values=batch, client=client)
repo.async_set_values(
path=col.path,
namespace=col.namespace,
col=col,
values=batch,
client=client,
)
for col, batch in val_batches
]
results = await gather_batch(tasks, max_conns)
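
The batching above splits each column's values into dicts of at most `batch_size` entries, each sent as a separate request. A quick illustration of that slicing:

col_vals = list({"g1": 1, "g2": 2, "g3": 3}.items())
batch_size = 2
batches = [
    dict(col_vals[idx : idx + batch_size])
    for idx in range(0, len(col_vals), batch_size)
]
print(batches)  # [{'g1': 1, 'g2': 2}, {'g3': 3}]
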