Skip to content

Commit

Permalink
feat: allow passing extra param to execQuery (for direct URL access) (#…
Browse files Browse the repository at this point in the history
…14)

* refactor: update the json config parsing for clarity

* fix: handle data extract zip entirely in memory

* feat: allow passing extra params to queryExec, return URL if fgb

* test: for generating data extracts, zip and fgb formats

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: add extra error handling for extract respone json

* test: add conftest to init logger

* refactor: update error handling for data extract download

* feat: add optional support for auth token with remote query

* refactor: extra debug logging for raw data api polling

* fix: handle raw-data-api status PENDING and STARTED

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
spwoodcock and pre-commit-ci[bot] authored Feb 8, 2024
1 parent 9888975 commit 0785381
Show file tree
Hide file tree
Showing 5 changed files with 323 additions and 165 deletions.
85 changes: 46 additions & 39 deletions osm_rawdata/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,15 @@ def _yaml_parse_select_and_keep(self, data):
self.config["select"][table].append({tag: []})

def parseJson(self, config: Union[str, BytesIO]):
"""Parse the JSON format config file used by the raw-data-api
and export tool.
"""Parse the JSON format config file using the Underpass schema.
Args:
config (str, BytesIO): the file or BytesIO object to read.
Returns:
config (dict): the config data
"""
# Check the type of config and load data accordingly
if isinstance(config, str):
with open(config, "r") as config_file:
data = json.load(config_file)
Expand All @@ -203,51 +203,58 @@ def parseJson(self, config: Union[str, BytesIO]):
log.error(f"Unsupported config format: {config}")
raise ValueError(f"Invalid config {config}")

# Get the geometry
# Helper function to convert geometry names
def convert_geometry(geom):
if geom == "point":
return "nodes"
elif geom == "line":
return "ways_line"
elif geom == "polygon":
return "ways_poly"
return geom

# Extract geometry
self.geometry = shape(data["geometry"])

# Iterate through each key-value pair in the flattened dictionary
for key, value in flatdict.FlatDict(data).items():
keys = key.split(":")
# print(keys)
# print(f"\t{value}")
# We already have the geometry
if key[:8] == "geometry":
# Skip the keys related to geometry
if key.startswith("geometry"):
continue
# If it's a top-level key, directly update self.config
if len(keys) == 1:
self.config.update({key: value})
self.config[key] = value
continue
# keys[0] is currently always 'filters'
# keys[1] is currently 'tags' for the WHERE clause,
# of attributes for the SELECT
geom = keys[2]
# tag = keys[4]
# Get the geometry
if geom == "point":
geom = "nodes"
elif geom == "line":
geom = "ways_line"
elif geom == "polygon":
geom = "ways_poly"
if keys[1] == "attributes":
for v1 in value:
if geom == "all_geometry":
self.config["select"]["nodes"].append({v1: {}})
self.config["select"]["ways_line"].append({v1: {}})
self.config["select"]["ways_poly"].append({v1: {}})
self.config["tables"].append("nodes")
self.config["tables"].append("ways_poly")
self.config["tables"].append("ways_line")

# Extract meaningful parts from the key
section, subsection = keys[:2]
geom_type = keys[2] if len(keys) > 2 else None
tag_type = keys[3] if len(keys) > 3 else None
tag_name = keys[4] if len(keys) > 4 else None

# Convert geometry type to meaningful names
geom_type = convert_geometry(geom_type)

if subsection == "attributes":
# For attributes, update select fields and tables
for attribute_name in value:
if geom_type == "all_geometry":
for geometry_type in ["nodes", "ways_line", "ways_poly"]:
self.config["select"][geometry_type].append({attribute_name: {}})
self.config["tables"].append(geometry_type)
else:
self.config["tables"].append(geom)
self.config["select"][geom].append({v1: {}})
if keys[1] == "tags":
newtag = {keys[4]: value}
newtag["op"] = keys[3][5:]
if geom == "all_geometry":
self.config["where"]["nodes"].append(newtag)
self.config["where"]["ways_poly"].append(newtag)
self.config["where"]["ways_line"].append(newtag)
self.config["select"][geom_type].append({attribute_name: {}})
self.config["tables"].append(geom_type)
elif subsection == "tags":
# For tags, update where fields
option = tag_type[5:] if tag_type else None
new_tag = {tag_name: value, "op": option}
if geom_type == "all_geometry":
for geometry_type in ["nodes", "ways_line", "ways_poly"]:
self.config["where"][geometry_type].append(new_tag)
else:
self.config["where"][geom].append(newtag)
self.config["where"][geom_type].append(new_tag)

return self.config

Expand Down
65 changes: 34 additions & 31 deletions osm_rawdata/pgasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# <[email protected]>

import argparse
import asyncio
import json
import logging
import os
Expand All @@ -28,13 +29,11 @@
import zipfile
from io import BytesIO
from pathlib import Path
from sys import argv
from urllib.parse import urlparse

import asyncpg
import geojson
import requests
import asyncio
import asyncpg
from asyncpg import exceptions
from geojson import Feature, FeatureCollection, Polygon
from shapely import wkt
from shapely.geometry import Polygon, shape
Expand All @@ -48,39 +47,41 @@
# Instantiate logger
log = logging.getLogger(__name__)


class DatabaseAccess(object):
def __init__(self):
"""This is a class to setup a database connection."""
self.pg = None
self.dburi = None
self.qc = None

async def connect(self,
dburi: str,
):
async def connect(
self,
dburi: str,
):
self.dburi = dict()
uri = urlparse(dburi)
if not uri.username:
self.dburi['dbuser'] = os.getenv("PGUSER", default=None)
if not self.dburi['dbuser']:
log.error(f"You must specify the user name in the database URI, or set PGUSER")
self.dburi["dbuser"] = os.getenv("PGUSER", default=None)
if not self.dburi["dbuser"]:
log.error("You must specify the user name in the database URI, or set PGUSER")
else:
self.dburi['dbuser'] = uri.username
self.dburi["dbuser"] = uri.username
if not uri.password:
self.dburi['dbpass'] = os.getenv("PGPASSWORD", default=None)
if not self.dburi['dbpass']:
log.error(f"You must specify the user password in the database URI, or set PGPASSWORD")
self.dburi["dbpass"] = os.getenv("PGPASSWORD", default=None)
if not self.dburi["dbpass"]:
log.error("You must specify the user password in the database URI, or set PGPASSWORD")
else:
self.dburi['dbpass'] = uri.password
self.dburi["dbpass"] = uri.password
if not uri.hostname:
self.dburi['dbhost'] = os.getenv("PGHOST", default="localhost")
self.dburi["dbhost"] = os.getenv("PGHOST", default="localhost")
else:
self.dburi['dbhost'] = uri.hostname
self.dburi["dbhost"] = uri.hostname

slash = uri.path.find('/')
self.dburi['dbname'] = uri.path[slash + 1:]
slash = uri.path.find("/")
self.dburi["dbname"] = uri.path[slash + 1 :]
connect = f"postgres://{self.dburi['dbuser']}:{ self.dburi['dbpass']}@{self.dburi['dbhost']}/{self.dburi['dbname']}"

if self.dburi["dbname"] == "underpass":
# Authentication data
# self.auth = HTTPBasicAuth(self.user, self.passwd)
Expand Down Expand Up @@ -292,11 +293,11 @@ async def createTable(

return True

async def execute(self,
sql: str,
):
"""
Execute a raw SQL query and return the results.
async def execute(
self,
sql: str,
):
"""Execute a raw SQL query and return the results.
Args:
sql (str): The SQL to execute
Expand Down Expand Up @@ -441,17 +442,18 @@ def __init__(
# output: str = None
):
"""This is a client for a postgres database.
Returns:
(PostgresClient): An instance of this class
"""
super().__init__()
self.qc = None

async def loadConfig(self,
config: str,
):
"""
Load the JSON or YAML config file that defines the SQL query
async def loadConfig(
self,
config: str,
):
"""Load the JSON or YAML config file that defines the SQL query
Args:
config (str): The filespec for the query config file
Expand Down Expand Up @@ -534,6 +536,7 @@ async def execQuery(
collection = await self.queryRemote(request)
return collection


async def main():
"""This main function lets this class be run standalone by a bash script."""
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -601,9 +604,9 @@ async def main():

log.debug(f"Wrote {args.outfile}")


if __name__ == "__main__":
"""This is just a hook so this file can be run standalone during development."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(main())

Loading

0 comments on commit 0785381

Please sign in to comment.