Commit 682794c
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Oct 30, 2023
1 parent 4cffd35 commit 682794c
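
The fixes applied here (double-quoted strings, sorted and pruned imports, normalized blank lines) are characteristic of formatter and linter hooks. As a hypothetical sketch only, not the repository's actual configuration, a .pre-commit-config.yaml driving fixes like these might look as follows:

# Hypothetical .pre-commit-config.yaml; the real hook list is not shown on this page
repos:
  - repo: https://github.com/psf/black
    rev: 23.10.1
    hooks:
      - id: black
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.1.3
    hooks:
      - id: ruff
        args: [--fix]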
Showing 2 changed files with 106 additions and 104 deletions.
114 changes: 56 additions & 58 deletions osm_rawdata/importer.py
@@ -20,121 +20,116 @@
# <[email protected]>

import argparse
import concurrent.futures
import logging
import subprocess
import sys
from pathlib import Path
from sys import argv

# from geoalchemy2 import shape
import geoalchemy2
import geojson
import pyarrow.parquet as pq
from codetiming import Timer
from cpuinfo import get_cpu_info
from pandas import DataFrame
from progress.spinner import PixelSpinner
from shapely import wkb
from shapely.geometry import shape
from sqlalchemy import MetaData, cast, column, create_engine, select, table, text
from sqlalchemy.dialects.postgresql import JSONB, insert
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import create_database, database_exists

# Find the other files for this project
import osm_rawdata as rw
import osm_rawdata.db_models
from osm_rawdata.db_models import Base
from osm_rawdata.overture import Overture
from osm_rawdata.postgres import uriParser

rootdir = rw.__path__[0]

# Instantiate logger
log = logging.getLogger("osm-rawdata")

# The number of threads is based on the CPU cores
info = get_cpu_info()
cores = info["count"]


def importThread(
    data: list,
    db: Connection,
):
    """Thread to handle importing

    Args:
        data (list): The list of features to import
        db (Connection): A database connection
    """
    # log.debug(f"In importThread()")
    # timer = Timer(text="importThread() took {seconds:.0f}s")
    # timer.start()
    ways = table(
        "ways_poly",
        column("id"),
        column("user"),
        column("geom"),
        column("tags"),
    )

    nodes = table(
        "nodes",
        column("id"),
        column("user"),
        column("geom"),
        column("tags"),
    )

    # Linear features go in the ways_line table
    lines = table(
        "ways_line",
        column("id"),
        column("user"),
        column("geom"),
        column("tags"),
    )

    index = 0

    for feature in data:
        # log.debug(feature)
        index -= 1
        entry = dict()
        tags = feature["properties"]
        tags["building"] = "yes"
        entry["id"] = index
        ewkt = shape(feature["geometry"])
        geom = wkb.dumps(ewkt)
        geom_type = ewkt.geom_type
        scalar = select(cast(tags, JSONB))

        if geom_type == "Polygon":
            sql = insert(ways).values(
                # id = entry['id'],
                geom=geom,
                tags=scalar,
            )
        elif geom_type == "Point":
            sql = insert(nodes).values(
                # id = entry['id'],
                geom=geom,
                tags=scalar,
            )

        db.execute(sql)
        # db.commit()
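
For orientation, a minimal sketch of how importThread() might be driven follows. This is illustrative only and not part of the commit: the database URI and filename are placeholders, and it assumes a PostGIS database that already contains the ways_poly and nodes tables this module inserts into.

# Hypothetical driver for importThread(); URI and filename are placeholders
import geojson
from sqlalchemy import create_engine

engine = create_engine("postgresql://localhost/underpass")
with engine.connect() as conn:
    with open("buildings.geojson", "r") as file:
        data = geojson.load(file)
    importThread(data["features"], conn)
    conn.commit()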


def parquetThread(
    data: DataFrame,
    db: Connection,
):
    """Thread to handle importing

    Args:
@@ -149,61 +144,61 @@ def parquetThread(
column("user"),
column("geom"),
column("tags"),
)
)

nodes = table(
"nodes",
column("id"),
column("user"),
column("geom"),
column("tags"),
)
)

lines = table(
"ways_line",
column("id"),
column("user"),
column("geom"),
column("tags"),
)
)

index = -1
log.debug(f"There are {len(data)} entries in the data")
if len(data) == 0:
if len(data) == 0:
return

overture = Overture()
for index in data.index:
feature = data.loc[index]
        dataset = feature["sources"][0]["dataset"]
        if dataset == "OpenStreetMap" or dataset == "Microsoft ML Buildings":
            continue
        tags = overture.parse(feature)
        geom = feature["geometry"]
        hex = wkb.loads(geom, hex=True)
        gdata = geoalchemy2.shape.from_shape(hex, srid=4326, extended=True)
        # geom_type = wkb.loads(geom).geom_type
        scalar = select(cast(tags["properties"], JSONB))
        sql = None
        if hex.geom_type == "Polygon":
            sql = insert(ways).values(
                # osm_id = entry['osm_id'],
                geom=bytes(gdata.data),
                tags=scalar,
            )
        elif hex.geom_type == "MultiPolygon":
            gdata = geoalchemy2.shape.from_shape(hex.convex_hull, srid=4326, extended=True)
            sql = insert(ways).values(
                geom=bytes(gdata.data),
                tags=scalar,
            )
        elif hex.geom_type == "Point":
            sql = insert(nodes).values(
                # osm_id = entry['osm_id'],
                geom=bytes(gdata.data),
                tags=scalar,
            )
        elif hex.geom_type == "LineString":
            sql = insert(lines).values(
                # osm_id = entry['osm_id'],
                geom=bytes(gdata.data),
@@ -219,6 +214,7 @@ def parquetThread(
# print(f"FIXME2: {entry}")
timer.stop()
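
importParquet() itself is not shown in this diff, so the following is only a rough sketch, under stated assumptions, of how parquetThread() could be fed: one parquet row group at a time, converted to a pandas DataFrame to match its signature. The URI and filename are placeholders.

# Hypothetical driver for parquetThread(); URI and filename are placeholders
import pyarrow.parquet as pq
from sqlalchemy import create_engine

engine = create_engine("postgresql://localhost/underpass")
pfile = pq.ParquetFile("buildings.parquet")
with engine.connect() as conn:
    for group in range(pfile.num_row_groups):
        # Each row group becomes the DataFrame that parquetThread() expects
        df = pfile.read_row_group(group).to_pandas()
        parquetThread(df, conn)
    conn.commit()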


class MapImporter(object):
def __init__(
self,
@@ -246,14 +242,14 @@ def __init__(
meta = MetaData()
meta.create_all(engine)

        # if dburi:
        #     self.uri = uriParser(dburi)
        #     engine = create_engine(f"postgresql://{self.dburi}", echo=True)
        #     if not database_exists(engine.url):
        #         create_database(engine.url)
        #     self.db = engine.connect()

        # Add the extension we need to process the data
sql = text(
"CREATE EXTENSION IF NOT EXISTS postgis; CREATE EXTENSION IF NOT EXISTS hstore;CREATE EXTENSION IF NOT EXISTS dblink;"
)
@@ -364,8 +360,8 @@ def importGeoJson(
"""
# load the GeoJson file
file = open(infile, "r")
        # size = os.path.getsize(infile)
        # for line in file.readlines():
# print(line)
data = geojson.load(file)

@@ -378,27 +374,28 @@ def importGeoJson(
timer.start()

        # A chunk is the slice of features that one thread processes
        entries = len(data["features"])
        chunk = round(entries / cores)

        # For small files we only need one thread
        if entries <= chunk:
            result = importThread(data["features"], self.connections[0])
timer.stop()
return True

with concurrent.futures.ThreadPoolExecutor(max_workers=cores) as executor:
block = 0
while block <= entries:
log.debug("Dispatching Block %d:%d" % (block, block + chunk))
                result = executor.submit(importThread, data["features"][block : block + chunk], self.connections[index])
block += chunk
index += 1
executor.shutdown()
timer.stop()

return True
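
To make the dispatch arithmetic concrete (hypothetical numbers, not from the source): 10,000 features on an 8-core machine give chunk = round(10000 / 8) = 1250, so the loop above dispatches slices 0:1250, 1250:2500, and so on.

# Hypothetical illustration of the block arithmetic in importGeoJson()
entries = 10000
cores = 8
chunk = round(entries / cores)  # 1250
blocks = [(block, block + chunk) for block in range(0, entries + 1, chunk)]
print(blocks[:3])  # [(0, 1250), (1250, 2500), (2500, 3750)]
print(blocks[-1])  # (10000, 11250): the <= test yields one final, empty slice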


def main():
"""This main function lets this class be run standalone by a bash script."""
parser = argparse.ArgumentParser(
@@ -445,6 +442,7 @@ def main():
mi.importParquet(args.infile)
log.info(f"Imported {args.infile} into {args.uri}")


if __name__ == "__main__":
"""This is just a hook so this file can be run standalone during development."""
main()