Merge pull request #6 from hotosm/pre-commit-ci-update-config
[pre-commit.ci] pre-commit autoupdate
Showing 3 changed files with 109 additions and 107 deletions.
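pre-commit.ci's autoupdate job bumps the hook revisions pinned in .pre-commit-config.yaml; when the updated hooks run, the formatters rewrite whatever no longer matches their style, which is where nearly all of the churn below comes from (quote normalization, comment spacing, import cleanup). As a rough illustration of the quote rule at work, a sketch assuming the black package is installed (this snippet is not from the commit):

import black

# black normalizes string quotes to double quotes, which is exactly
# the kind of mechanical change that dominates this diff.
src = "log = logging.getLogger('osm-rawdata')\n"
formatted = black.format_str(src, mode=black.Mode())
print(formatted)  # -> log = logging.getLogger("osm-rawdata")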
@@ -20,121 +20,116 @@
 # <[email protected]>
 
 import argparse
+import concurrent.futures
 import logging
 import subprocess
 import sys
-import os
-import concurrent.futures
-import geojson
-from geojson import Feature, FeatureCollection
-from sys import argv
 from pathlib import Path
-from cpuinfo import get_cpu_info
-from shapely.geometry import shape
+from sys import argv
 
 # from geoalchemy2 import shape
 import geoalchemy2
-import shapely
 
-from pandas import DataFrame
 import pyarrow.parquet as pq
+import geojson
 from codetiming import Timer
-from osm_rawdata.postgres import uriParser
-from progress.spinner import PixelSpinner
+from cpuinfo import get_cpu_info
+from pandas import DataFrame
+from shapely import wkb
+from shapely.geometry import shape
 from sqlalchemy import MetaData, cast, column, create_engine, select, table, text
 from sqlalchemy.dialects.postgresql import JSONB, insert
+from sqlalchemy.engine.base import Connection
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy_utils import create_database, database_exists
-from sqlalchemy.engine.base import Connection
-from shapely.geometry import Point, LineString, Polygon
-from shapely import wkt, wkb
 
 # Find the other files for this project
 import osm_rawdata as rw
 import osm_rawdata.db_models
 from osm_rawdata.db_models import Base
 from osm_rawdata.overture import Overture
+from osm_rawdata.postgres import uriParser
 
 rootdir = rw.__path__[0]
 
 # Instantiate logger
-log = logging.getLogger('osm-rawdata')
+log = logging.getLogger("osm-rawdata")
 
 # The number of threads is based on the CPU cores
 info = get_cpu_info()
-cores = info['count']
+cores = info["count"]
 
 
 def importThread(
-        data: list,
-        db: Connection,
-    ):
+    data: list,
+    db: Connection,
+):
     """Thread to handle importing
     Args:
         data (list): The list of tiles to download
         db (Connection): A database connection
     """
     # log.debug(f"In importThread()")
-    #timer = Timer(text="importThread() took {seconds:.0f}s")
-    #timer.start()
+    # timer = Timer(text="importThread() took {seconds:.0f}s")
+    # timer.start()
     ways = table(
         "ways_poly",
         column("id"),
         column("user"),
         column("geom"),
         column("tags"),
-        )
+    )
 
     nodes = table(
         "nodes",
         column("id"),
         column("user"),
         column("geom"),
         column("tags"),
-        )
+    )
 
     nodes = table(
         "ways_line",
         column("id"),
         column("user"),
         column("geom"),
         column("tags"),
-        )
+    )
 
     index = 0
 
     for feature in data:
         # log.debug(feature)
         index -= 1
         entry = dict()
-        tags = feature['properties']
-        tags['building'] = 'yes'
-        entry['id'] = index
+        tags = feature["properties"]
+        tags["building"] = "yes"
+        entry["id"] = index
         ewkt = shape(feature["geometry"])
         geom = wkb.dumps(ewkt)
         type = ewkt.geom_type
         scalar = select(cast(tags, JSONB))
 
-        if type == 'Polygon':
+        if type == "Polygon":
             sql = insert(ways).values(
                 # id = entry['id'],
                 geom=geom,
                 tags=scalar,
-                )
-        elif type == 'Point':
+            )
+        elif type == "Point":
             sql = insert(nodes).values(
                 # id = entry['id'],
                 geom=geom,
                 tags=scalar,
-                )
+            )
 
         db.execute(sql)
         # db.commit()
 
 
 def parquetThread(
     data: DataFrame,
     db: Connection,
-    ):
+):
     """Thread to handle importing
     Args:
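One pre-existing oddity the formatting pass leaves untouched, visible in the context lines of the hunk above: importThread assigns both its second and third table() definitions to nodes, so the "ways_line" table silently overwrites the "nodes" handle and Point features end up inserted into ways_line. A corrected sketch (hypothetical, not part of this commit) binds each table to its own name, as parquetThread in the next hunk does with lines:

from sqlalchemy import column, table

# One handle per target table, so no definition shadows another
ways = table("ways_poly", column("id"), column("user"), column("geom"), column("tags"))
nodes = table("nodes", column("id"), column("user"), column("geom"), column("tags"))
lines = table("ways_line", column("id"), column("user"), column("geom"), column("tags"))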
@@ -149,61 +144,61 @@ def parquetThread(
         column("user"),
         column("geom"),
         column("tags"),
-        )
+    )
 
     nodes = table(
         "nodes",
         column("id"),
         column("user"),
         column("geom"),
         column("tags"),
-        )
+    )
 
     lines = table(
         "ways_line",
         column("id"),
         column("user"),
         column("geom"),
         column("tags"),
-        )
+    )
 
     index = -1
     log.debug(f"There are {len(data)} entries in the data")
-    if len(data) == 0:
+    if len(data) == 0:
         return
 
     overture = Overture()
     for index in data.index:
         feature = data.loc[index]
-        dataset = feature['sources'][0]['dataset']
-        if dataset == 'OpenStreetMap' or dataset == 'Microsoft ML Buildings':
+        dataset = feature["sources"][0]["dataset"]
+        if dataset == "OpenStreetMap" or dataset == "Microsoft ML Buildings":
             continue
         tags = overture.parse(feature)
-        geom = feature['geometry']
+        geom = feature["geometry"]
         hex = wkb.loads(geom, hex=True)
         gdata = geoalchemy2.shape.from_shape(hex, srid=4326, extended=True)
         # geom_type = wkb.loads(geom).geom_type
-        scalar = select(cast(tags['properties'], JSONB))
+        scalar = select(cast(tags["properties"], JSONB))
         sql = None
-        if hex.geom_type == 'Polygon':
+        if hex.geom_type == "Polygon":
             sql = insert(ways).values(
                 # osm_id = entry['osm_id'],
                 geom=bytes(gdata.data),
                 tags=scalar,
             )
-        elif hex.geom_type == 'MultiPolygon':
+        elif hex.geom_type == "MultiPolygon":
             gdata = geoalchemy2.shape.from_shape(hex.convex_hull, srid=4326, extended=True)
             sql = insert(ways).values(
                 geom=bytes(gdata.data),
                 tags=scalar,
             )
-        elif hex.geom_type == 'Point':
+        elif hex.geom_type == "Point":
             sql = insert(nodes).values(
                 # osm_id = entry['osm_id'],
                 geom=bytes(gdata.data),
                 tags=scalar,
             )
-        elif hex.geom_type == 'LineString':
+        elif hex.geom_type == "LineString":
             sql = insert(lines).values(
                 # osm_id = entry['osm_id'],
                 geom=bytes(gdata.data),
@@ -219,6 +214,7 @@ def parquetThread(
             # print(f"FIXME2: {entry}")
     timer.stop()
 
+
 class MapImporter(object):
     def __init__(
         self,
@@ -246,14 +242,14 @@ def __init__(
             meta = MetaData()
             meta.create_all(engine)
 
-        # if dburi:
-        #     self.uri = uriParser(dburi)
-        #     engine = create_engine(f"postgresql://{self.dburi}", echo=True)
-        #     if not database_exists(engine.url):
-        #         create_database(engine.url)
-        #     self.db = engine.connect()
+            # if dburi:
+            #     self.uri = uriParser(dburi)
+            #     engine = create_engine(f"postgresql://{self.dburi}", echo=True)
+            #     if not database_exists(engine.url):
+            #         create_database(engine.url)
+            #     self.db = engine.connect()
 
-        # Add the extension we need to process the data
+            # Add the extension we need to process the data
             sql = text(
                 "CREATE EXTENSION IF NOT EXISTS postgis; CREATE EXTENSION IF NOT EXISTS hstore;CREATE EXTENSION IF NOT EXISTS dblink;"
             )
@@ -364,8 +360,8 @@ def importGeoJson(
         """
         # load the GeoJson file
         file = open(infile, "r")
-        #size = os.path.getsize(infile)
-        #for line in file.readlines():
+        # size = os.path.getsize(infile)
+        # for line in file.readlines():
         #    print(line)
         data = geojson.load(file)
 
@@ -378,27 +374,28 @@ def importGeoJson(
         timer.start()
 
         # A chunk is a group of threads
-        entries = len(data['features'])
+        entries = len(data["features"])
         chunk = round(entries / cores)
 
         # For small files we only need one thread
         if entries <= chunk:
-            result = importThread(data['features'], self.connections[0])
+            result = importThread(data["features"], self.connections[0])
             timer.stop()
             return True
 
         with concurrent.futures.ThreadPoolExecutor(max_workers=cores) as executor:
             block = 0
             while block <= entries:
                 log.debug("Dispatching Block %d:%d" % (block, block + chunk))
-                result = executor.submit(importThread, data['features'][block : block + chunk], self.connections[index])
+                result = executor.submit(importThread, data["features"][block : block + chunk], self.connections[index])
                 block += chunk
                 index += 1
             executor.shutdown()
         timer.stop()
 
         return True
 
+
 def main():
     """This main function lets this class be run standalone by a bash script."""
     parser = argparse.ArgumentParser(
@@ -445,6 +442,7 @@ def main():
         mi.importParquet(args.infile)
     log.info(f"Imported {args.infile} into {args.uri}")
 
+
 if __name__ == "__main__":
     """This is just a hook so this file can be run standalone during development."""
     main()
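For reference, the dispatch pattern in the importGeoJson hunks above splits the features into one chunk per CPU core and hands each chunk to a worker thread with its own database connection. A minimal, self-contained sketch of that pattern (hypothetical names throughout; the real code submits importThread with GeoJSON features and pooled SQLAlchemy connections):

import concurrent.futures

def import_chunk(chunk: list, worker_id: int) -> int:
    """Stand-in for importThread(): process one slice of the features."""
    return len(chunk)

features = list(range(100))  # stand-in for data["features"]
cores = 4  # the real code sizes this from cpuinfo's core count
chunk = round(len(features) / cores)

with concurrent.futures.ThreadPoolExecutor(max_workers=cores) as executor:
    futures = []
    for index, block in enumerate(range(0, len(features), chunk)):
        # One slice per worker, mirroring one DB connection per thread
        futures.append(executor.submit(import_chunk, features[block : block + chunk], index))
    for future in concurrent.futures.as_completed(futures):
        print(future.result())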