
Update load_dataframe #1

Merged
merged 10 commits on Oct 10, 2024
63 changes: 57 additions & 6 deletions gerrydb/client.py
@@ -11,8 +11,6 @@
import httpx
import pandas as pd
import tomlkit
from shapely import Point
from shapely.geometry.base import BaseGeometry

from gerrydb.cache import GerryCache
from gerrydb.exceptions import ConfigError
@@ -368,6 +366,8 @@ def load_dataframe(
Args:
db: GerryDB client instance.
df: DataFrame to import column values and geographies from.
The df MUST be indexed by the geoid or the import will not
work correctly.
columns: Mapping between column names in `df` and GerryDB column metadata.
Only columns included in the mapping will be imported.
create_geo: Determines whether to create geographies from the DataFrame.
@@ -405,10 +405,50 @@ def load_dataframe(
# TODO: Make this error more specific maybe?
raise e
raise e
else:
if locality is None or layer is None:
raise ValueError(
"Locality and layer must be provided if create_geo is False."
)

asyncio.run(
_load_column_values(self.columns, df, columns, batch_size, max_conns)
)
locality_path = ""
layer_path = ""

if isinstance(locality, Locality):
locality_path = locality.canonical_path
else:
locality_path = locality
if isinstance(layer, GeoLayer):
layer_path = layer.path
else:
layer_path = layer

known_paths = set(self.db.geo.all_paths(locality_path, layer_path))
df_paths = set(df.index)

if df_paths - known_paths == df_paths:
raise ValueError(
f"The index of the dataframe does not appear to match any geographies in the namespace "
f"which have the following geoid format: '{list(known_paths)[0] if len(known_paths) > 0 else None}'. "
f"Please ensure that the index of the dataframe matches the format of the geoid."
)

if df_paths - known_paths != set():
raise ValueError(
f"Failure in load_dataframe. Tried to import geographies for layer "
f"'{layer_path}' and locality '{locality_path}', but the following geographies "
f"do not exist in the namespace "
f"'{self.db.namespace}': {df_paths - known_paths}"
)

if known_paths - df_paths != set():
raise ValueError(
f"Failure in load_dataframe. Tried to import geographies for layer "
f"'{layer_path}' and locality '{locality_path}', but the passed dataframe "
f"does not contain the following geographies: "
f"{known_paths - df_paths}. "
f"Please provide values for these geographies in the dataframe."
)

if create_geo and locality is not None and layer is not None:
self.geo_layers.map_locality(
@@ -417,6 +457,10 @@ def load_dataframe(
geographies=[f"/{namespace}/{key}" for key in df.index],
)

asyncio.run(
_load_column_values(self.columns, df, columns, batch_size, max_conns)
)
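
For context, a hedged usage sketch of the path this hunk adds: with `create_geo=False`, both `locality` and `layer` must be supplied, and the DataFrame index must match the geography paths already known to the namespace. The names below (`ctx`, `pop_col`, the namespace/locality/layer strings, and the geoid values) are illustrative assumptions, not part of this PR.

```python
import pandas as pd

# The DataFrame MUST be indexed by geoid (values here are placeholders).
df = pd.DataFrame({"total_pop": [120, 340]}, index=["25001", "25003"])

# `ctx` stands in for whatever write context exposes load_dataframe;
# `pop_col` is a previously created Column metadata object.
ctx.load_dataframe(
    df=df,
    columns={"total_pop": pop_col},
    create_geo=False,
    namespace="census.2020",
    locality="massachusetts",  # or a Locality object
    layer="county",            # or a GeoLayer object
)
```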


# based on https://stackoverflow.com/a/61478547
async def gather_batch(coros, n):
@@ -466,13 +510,20 @@ async def _load_column_values(

val_batches: list[tuple[Column, dict[str, Any]]] = []
for col_name, col_meta in columns.items():
# This only works because the df is indexed by geography path.
col_vals = list(dict(df[col_name]).items())
for idx in range(0, len(df), batch_size):
val_batches.append((col_meta, dict(col_vals[idx : idx + batch_size])))

async with httpx.AsyncClient(**params) as client:
tasks = [
repo.async_set_values(col, col.namespace, values=batch, client=client)
repo.async_set_values(
path=col.path,
namespace=col.namespace,
col=col,
values=batch,
client=client,
)
for col, batch in val_batches
]
results = await gather_batch(tasks, max_conns)
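For reference, a minimal sketch of the batched-gather pattern referenced by the StackOverflow link above: coroutines are awaited in chunks of at most `n` at a time so that no more than `max_conns` requests are in flight. This is an illustration only; the helper in this PR follows the linked answer and may differ in detail.

```python
import asyncio

async def gather_batch_sketch(coros, n):
    # Await the coroutines in chunks of at most `n`, preserving order and
    # collecting exceptions instead of raising immediately.
    results = []
    for i in range(0, len(coros), n):
        results.extend(
            await asyncio.gather(*coros[i : i + n], return_exceptions=True)
        )
    return results
```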
56 changes: 38 additions & 18 deletions gerrydb/repos/column.py
@@ -120,15 +120,18 @@ def update(
@online
def set_values(
self,
path_or_col: Union[Column, str],
path: Optional[str],
namespace: Optional[str] = None,
*,
col: Optional[Column] = None,
values: dict[Union[str, Geography], Any],
) -> None:
"""Sets the values of a column on a collection of geographies.

Args:
path_or_col: Short identifier for the column or a `Column` metadata object.
path: Short identifier for the column. Only this or `col` should be provided.

Review comment: Can you update this comment about the behavior of the function when both path and col are passed?

col: `Column` metadata object. If the `path` is not provided, the column's
path will be used.
namespace: Namespace of the column (used when a raw `path` is provided).
values:
A mapping from geography paths or `Geography` metadata objects
Expand All @@ -137,7 +140,14 @@ def set_values(
Raises:
RequestError: If the values cannot be set on the server side.
"""
path = path_or_col.path if isinstance(path_or_col, Column) else path_or_col

assert path is None or isinstance(path, str)
assert col is None or isinstance(col, Column)

if path is None and col is None:
raise ValueError("Either `path` or `col` must be provided.")

path = col.path if col is not None else path

response = self.ctx.client.put(
f"{self.base_url}/{namespace}/{path}",
@@ -163,16 +173,19 @@ def set_values(
@online
async def async_set_values(
self,
path_or_col: Union[Column, str],
path: Optional[str],
namespace: Optional[str] = None,
*,
col: Optional[Column] = None,
values: dict[Union[str, Geography], Any],
client: Optional[httpx.AsyncClient] = None,
) -> None:
"""Asynchronously sets the values of a column on a collection of geographies.

Args:
path_or_col: Short identifier for the column or a `Column` metadata object.
path: Short identifier for the column. Only this or `col` should be provided.

Review comment: same as above; please clarify the behavior for the different combinations of `path` and `col` provided.

col: `Column` metadata object. If the `path` is not provided, the column's
path will be used.
namespace: Namespace of the column (used when a raw `path` is provided).
values:
A mapping from geography paths or `Geography` metadata objects
@@ -183,31 +196,38 @@ async def async_set_values(
Raises:
RequestError: If the values cannot be set on the server side.
"""
path = path_or_col.path if isinstance(path_or_col, Column) else path_or_col
assert path is None or isinstance(path, str)
assert col is None or isinstance(col, Column)

if path is None and col is None:
raise ValueError("Either `path` or `col` must be provided.")

path = col.path if col is not None else path

ephemeral_client = client is None
if ephemeral_client:
params = self.ctx.client_params.copy()
params["transport"] = httpx.AsyncHTTPTransport(retries=1)
client = httpx.AsyncClient(**params)

json = [
ColumnValue(
path=(
f"/{geo.namespace}/{geo.path}"
if isinstance(geo, Geography)
else geo
),
value=_coerce(value),
).dict()
for geo, value in values.items()
]
response = await client.put(
f"{self.base_url}/{namespace}/{path}",
json=[
ColumnValue(
path=(
f"/{geo.namespace}/{geo.path}"
if isinstance(geo, Geography)
else geo
),
value=_coerce(value),
).dict()
for geo, value in values.items()
],
json=json,
)

if response.status_code != 204:
log.debug(f"For {path_or_col} returned {response}")
log.debug(f"For path = {path} and col = {col} returned {response}")

response.raise_for_status()

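To illustrate the reworked signature, a hedged sketch of the two call styles it supports; `repo`, `pop_col`, the namespace, and the geography keys are assumptions for the example, and the exact key format expected by the server is not shown in this diff.

```python
# Raw column path plus namespace...
repo.set_values(
    path="total_pop",
    namespace="census.2020",
    values={"25001": 120, "25003": 340},
)

# ...or a Column object, whose path is used when `path` is None.
repo.set_values(
    path=None,
    namespace="census.2020",
    col=pop_col,
    values={"25001": 120, "25003": 340},
)
```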
12 changes: 12 additions & 0 deletions gerrydb/repos/geography.py
@@ -17,6 +17,7 @@
err,
online,
write_context,
namespaced,
)
from gerrydb.schemas import Geography, GeographyCreate, GeoImport

@@ -257,3 +258,14 @@ def async_bulk(
return AsyncGeoImporter(repo=self, namespace=namespace, max_conns=max_conns)

# TODO: get()
@namespaced
@online
@err("Failed to load geographies")
def all_paths(self, fips: str, layer_name: str) -> list[str]:
response = self.session.client.get(
f"/__list_geo/{self.session.namespace}/{fips}/{layer_name}"
)
response.raise_for_status()
response_json = response.json()

return response_json
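
For context, a hedged sketch of how the new endpoint is consumed by `load_dataframe` above; the locality and layer strings and the `db`/`df` objects are illustrative.

```python
# List the geography paths known for a locality/layer pair, then check a
# geoid-indexed DataFrame against them (mirrors the validation in client.py).
known_paths = set(db.geo.all_paths("massachusetts", "county"))
missing = set(df.index) - known_paths
if missing:
    raise ValueError(f"Geographies not in namespace: {missing}")
```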
12 changes: 6 additions & 6 deletions pyproject.toml
@@ -6,19 +6,19 @@ authors = ["Parker J. Rule <[email protected]>"]

[tool.poetry.dependencies]
python = "^3.9"
tomlkit = "^0.11.6"
msgpack = "^1.0.4"
httpx = "^0.23.3"
tomlkit = "^0.13.0
msgpack = "^1.1.0"
httpx = "^0.27.2"
pydantic = "^1.10.4"
orjson = "^3.8.6"
orjson = "^3.10.0"
shapely = "^2.0.1"
python-dateutil = "^2.8.2"
geopandas = "^0.12.2"
geopandas = "^1.0.1"
networkx = "^3.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.2.1"
black = "^23.1.0"
black = "^24.8.0"
pytest-vcr = "^1.0.2"

[tool.isort]