Ls s2 fc poly drill #29

Draft · wants to merge 8 commits into base: master
35 changes: 30 additions & 5 deletions datacube-wps-config.yaml
@@ -30,7 +30,6 @@ processes:
fuse_func: datacube_wps.processes.wofls_fuser

style:
csv: None
table:
columns:
Wet:
@@ -107,7 +106,6 @@ processes:
resolution: [30, -30]
output_crs: EPSG:3577
style:
csv: None
table:
columns:
Bare Soil:
@@ -146,7 +144,6 @@ processes:
resolution: [30, -30]

style:
csv: None
table:
columns:
Woodland:
@@ -206,5 +203,33 @@ processes:
- product: ga_ls_wo_3
measurements: [water]
style:
csv: None
table: None


- process: datacube_wps.processes.ls_s2_fc_drill.LS_S2_FC_Drill

about:
identifier: LS S2 FC Drill
version: '0.1'
title: Landsat/Sentinel-2 Fractional Cover Drill
abstract: Performs Landsat/Sentinel-2 Fractional Cover Drill
store_supported: False
status_supported: True
geometry_type: polygon

input:
reproject:
output_crs: EPSG:3577
resolution: [-30, 30]
resampling: nearest
input:
product: ls_s2_fc_c3
measurements: [tc]

style:
csv: True
table:
columns:
Total Cover %:
units: "#"
chartLineColor: "#3B7F00"
active: True
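
The process class referenced by this new config entry (datacube_wps.processes.ls_s2_fc_drill.LS_S2_FC_Drill) is added elsewhere in this PR and is not shown in this diff. Below is a minimal sketch of what such a drill could look like, assuming the polygon-drill base class in datacube_wps/processes/__init__.py is named PolygonDrill and that negative tc values mark nodata (both assumptions, not taken from the PR):

```python
# Hedged sketch only; the real ls_s2_fc_drill.py in this PR may differ.
import altair
import xarray

from . import PolygonDrill, log_call


class LS_S2_FC_Drill(PolygonDrill):
    @log_call
    def process_data(self, data: xarray.Dataset, parameters: dict):
        # Reduce the polygon-masked 'tc' (total cover) band to a spatial mean
        # per timestep.  Treating negative values as nodata is an assumption.
        total_cover = data["tc"].where(data["tc"] >= 0).mean(dim=["x", "y"])
        df = total_cover.to_dataframe(name="Total Cover %").reset_index()
        return df[["time", "Total Cover %"]]

    def render_chart(self, df):
        # Static timeseries used when a 'table' style is configured.
        line_color = self.style["table"]["columns"]["Total Cover %"]["chartLineColor"]
        return (
            altair.Chart(df)
            .mark_line(color=line_color)
            .encode(
                x="time:T",
                y=altair.Y(field="Total Cover %", type="quantitative"),
            )
        )
```
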
159 changes: 61 additions & 98 deletions datacube_wps/processes/__init__.py
@@ -6,7 +6,6 @@
from collections import Counter

import altair
# import altair_saver
import boto3
import botocore
import datacube
@@ -18,7 +17,7 @@
import rasterio.features
import xarray
from botocore.client import Config
from dask.distributed import Client, worker_client
from dask.distributed import Client
from datacube.utils.geometry import CRS, Geometry
from datacube.utils.rio import configure_s3_access
from datacube.virtual.impl import Product, Juxtapose
@@ -77,26 +76,16 @@ def log_wrapper(*args, **kwargs):

@log_call
def _uploadToS3(filename, data, mimetype):
# AWS_S3_CREDS = {
# "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
# "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
# }
# s3 = session.client("s3", **AWS_S3_CREDS)
session = boto3.Session(profile_name="default")
bucket = config.get_config_value("s3", "bucket")
s3 = session.client("s3")

# bucket = s3.Bucket('test-wps')

s3.upload_fileobj(
data,
bucket,
filename,
ExtraArgs={"ACL": "public-read", "ContentType": mimetype},
)

print('Made it to before the presigned url generation')
bucket = config.get_config_value("s3", "bucket")
# Create unsigned s3 client for determining public s3 url
s3 = session.client("s3", config=Config(signature_version=botocore.UNSIGNED))
return s3.generate_presigned_url(
@@ -108,14 +97,14 @@ def _uploadToS3(filename, data, mimetype):

def upload_chart_html_to_S3(chart: altair.Chart, process_id: str):
html_io = io.StringIO()
chart.save(html_io, format="html", engine="vl-convert")
chart.save(html_io, format="html")#, engine="vl-convert")
html_bytes = io.BytesIO(html_io.getvalue().encode())
return _uploadToS3(process_id + "/chart.html", html_bytes, "text/html")


def upload_chart_svg_to_S3(chart: altair.Chart, process_id: str):
img_io = io.StringIO()
chart.save(img_io, format="svg", engine="vl-convert")
chart.save(img_io, format="svg")#, engine="vl-convert")
img_bytes = io.BytesIO(img_io.getvalue().encode())
return _uploadToS3(process_id + "/chart.svg", img_bytes, "image/svg+xml")
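
A hedged usage sketch of the two upload helpers above (not part of the diff). The DataFrame contents, chart encoding and process id are placeholders, and the call assumes the "s3" bucket section of the WPS config plus AWS credentials are available:

```python
import altair
import pandas as pd

from datacube_wps.processes import upload_chart_html_to_S3

# Placeholder timeseries; a real drill produces this from process_data().
df = pd.DataFrame({
    "time": pd.date_range("2020-01-01", periods=3, freq="MS"),
    "total_cover": [41.2, 38.7, 44.0],
})

chart = (
    altair.Chart(df)
    .mark_line()
    .encode(x="time:T", y=altair.Y("total_cover:Q", title="Total Cover %"))
)

# Writes <process_id>/chart.html to the configured bucket and returns a public URL.
url = upload_chart_html_to_S3(chart, "example-process-id")
```
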

Expand Down Expand Up @@ -193,7 +182,6 @@ def _guard_rail(input, box):
byte_count *= x
byte_count *= sum(np.dtype(m.dtype).itemsize for m in measurement_dicts.values())

print("byte count for query: ", byte_count)
if byte_count > MAX_BYTES_IN_GB * GB:
raise ProcessError(
("requested area requires {}GB data to load - " "maximum is {}GB").format(
@@ -203,7 +191,6 @@

grouped = box.box

print("grouped shape", grouped.shape)
assert len(grouped.shape) == 1

if grouped.shape[0] == 0:
@@ -364,13 +351,7 @@ def _render_outputs(


def _populate_response(response, outputs):
print('before response is populated')
print(response.outputs)
print('------------')
for ident, output_value in outputs.items():
print('TESTING')
print(ident)
print(output_value)
if ident in response.outputs:
if "data" in output_value:
response.outputs[ident].data = output_value["data"]
@@ -407,16 +388,6 @@ def __init__(self, about, input, style):
self.style = style
self.json_version = "v8"

# self.dask_client = dask_client = Client(
# n_workers=num_dask_workers(), processes=True, threads_per_worker=1
# )

self.dask_enabled = True

if self.dask_enabled:
# get the Dask Client associated with the current Gunicorn worker
self.dask_client = worker_client()

def input_formats(self):
return [
ComplexInput(
@@ -443,29 +414,35 @@ def request_handler(self, request, response):
parameters = _get_parameters(request)

result = self.query_handler(time, feature, parameters=parameters)
if self.style['csv']:

if 'csv' in self.style:
outputs = self.render_outputs(result["data"], None)

elif self.style['table']:
elif 'table' in self.style:
outputs = self.render_outputs(result["data"], result["chart"])

else:
raise ProcessError('No output style configured for process!')

_populate_response(response, outputs)
return response

@log_call
def query_handler(self, time, feature, parameters=None):
def query_handler(self, time, feature, dask_client=None, parameters=None):
if parameters is None:
parameters = {}

configure_s3_access(
# aws_unsigned=True,
region_name=os.getenv("AWS_DEFAULT_REGION", "auto"),
client=self.dask_client,
)
if dask_client is None:
dask_client = Client(
n_workers=1, processes=False, threads_per_worker=num_workers()
)

with dask_client:
configure_s3_access(
aws_unsigned=True,
region_name=os.getenv("AWS_DEFAULT_REGION", "auto"),
client=dask_client,
)

with datacube.Datacube() as dc:
data = self.input_data(dc, time, feature)
with datacube.Datacube() as dc:
data = self.input_data(dc, time, feature)

df = self.process_data(data, {"time": time, "feature": feature, **parameters})
chart = self.render_chart(df)
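
The membership checks in request_handler above ('csv' in self.style) and the YAML changes earlier in the PR go together: the old code indexed self.style['csv'] directly, so every process entry had to carry a csv key (set to None as a placeholder), whereas the new code only tests for the key's presence, so the placeholder csv: None entries are dropped from datacube-wps-config.yaml and each process declares only the outputs it supports. A standalone sketch of the dispatch, with placeholder style mappings (ValueError stands in for ProcessError here):

```python
# Hedged sketch of the presence-based output dispatch.
def pick_output(style: dict) -> str:
    if "csv" in style:
        return "csv"
    elif "table" in style:
        return "chart"
    raise ValueError("No output style configured for process!")

assert pick_output({"csv": True}) == "csv"
assert pick_output({"table": {"columns": {}}}) == "chart"
# A config that still carried a placeholder csv key would now pick "csv",
# which is why those keys were removed from the config file.
assert pick_output({"csv": None, "table": {"columns": {}}}) == "csv"
```
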
@@ -493,11 +470,8 @@ def input_data(self, dc, time, feature):
lonlat = feature.coords[0]
measurements = self.input.output_measurements(bag.product_definitions)

if self.dask_enabled:
data = self.input.fetch(box, dask_chunks={"time": 1})
data = data.compute()
else:
data = self.input.fetch(box)
data = self.input.fetch(box, dask_chunks={"time": 1})
data = data.compute()

coords = {
"longitude": np.array([lonlat[0]]),
@@ -567,15 +541,6 @@ def __init__(self, about, input, style):
self.mask_all_touched = False
self.json_version = "v8"

# self.dask_client = dask_client = Client(
# n_workers=num_dask_workers(), processes=True, threads_per_worker=1
# )
self.dask_enabled = True

if self.dask_enabled:
# get the Dask Client associated with the current Gunicorn worker
self.dask_client = worker_client()

def input_formats(self):
return [
ComplexInput(
@@ -604,90 +569,88 @@ def request_handler(self, request, response):

result = self.query_handler(time, feature, parameters=parameters)

if self.style['csv']:
if 'csv' in self.style:
outputs = self.render_outputs(result["data"], None)

elif self.style['table']:
elif 'table' in self.style:
outputs = self.render_outputs(result["data"], result["chart"])

else:
raise ProcessError('No output style configured for process!')

_populate_response(response, outputs)
return response

@log_call
def query_handler(self, time, feature, parameters=None):
def query_handler(self, time, feature, dask_client=None, parameters=None):
if parameters is None:
parameters = {}

configure_s3_access(
# aws_unsigned=True,
region_name=os.getenv("AWS_DEFAULT_REGION", "auto"),
client=self.dask_client,
)
if dask_client is None:
dask_client = Client(
n_workers=num_workers(), processes=True, threads_per_worker=1
)

with dask_client:
configure_s3_access(
aws_unsigned=True,
region_name=os.getenv("AWS_DEFAULT_REGION", "auto"),
client=dask_client,
)

with datacube.Datacube() as dc:
data = self.input_data(dc, time, feature)
with datacube.Datacube() as dc:
data = self.input_data(dc, time, feature)

df = self.process_data(data, {"time": time, "feature": feature, **parameters})

# If csv specified, return timeseries in csv form
if self.style['csv']:
if 'csv' in self.style:
return {"data": df}

# If table style specified in config, return chart (static timeseries)
elif self.style['table'] is not None:
elif 'table' in self.style:
chart = self.render_chart(df)
return {"data": df, "chart": chart}


else:
return {}

def input_data(self, dc, time, feature):
if time is None:
bag = self.input.query(dc, geopolygon=feature)
else:
bag = self.input.query(dc, time=time, geopolygon=feature)

output_crs = self.input.get('output_crs')
resolution = self.input.get('resolution')
align = self.input.get('align')

if not (output_crs and resolution):
print('parameters for Geobox not found in inputs')
if type(self.input) in (Product,):
print('Checking grid_spec in product')
if bag.product_definitions[self.input._product].grid_spec:
print('grid_spec exists - do nothing')
else:
if not bag.product_definitions[self.input._product].grid_spec:
output_crs = mostcommon_crs(list(bag.bag))

elif type(self.input) in (Juxtapose,):
print('Checking grid_spec of each product')
print(list(bag.product_definitions.values()))

grid_specs = [product_definition.grid_spec for product_definition in list(bag.product_definitions.values()) if getattr(product_definition, 'grid_spec', None)]
if len(set(grid_specs)) == 1:
print('grid_spec exists for all products and are all the same - do nothing')

elif len(set(grid_specs)) > 1:
if len(set(grid_specs)) > 1:
raise ValueError('Multiple grid_spec detected across all products - override target output_crs, resolution in config')

else:
if not resolution:
raise ValueError('add target resolution to config')

elif not output_crs:
output_crs = mostcommon_crs(bag.contained_datasets())

box = self.input.group(bag, output_crs=output_crs, resolution=resolution, align=align)

if self.about.get("guard_rail", True):
# HACK: Get around issue where VirtualDatasetBox has a geobox but thinks it doesn't because load_natively flag is True.
# Need load_natively to be False to be able to call box.shape() inside guard_rail check function.
# Don't have time to understand how VirtualDatasets work and why this is happening in any more detail - just need the drill to work :)
run_hack = box.load_natively and box.geobox is not None
if run_hack:
load_natively = box.load_natively
box.load_natively = False
_guard_rail(self.input, box)
if run_hack:
box.load_natively = load_natively

# TODO customize the number of processes
if self.dask_enabled:
data = self.input.fetch(box, dask_chunks={"time": 1})
else:
data = self.input.fetch(box)

data = self.input.fetch(box, dask_chunks={"time": 1})
mask = geometry_mask(
feature, data.geobox, all_touched=self.mask_all_touched, invert=True
)
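
With this change query_handler no longer takes the Dask client associated with the Gunicorn worker; it accepts an optional dask_client and otherwise spins up a local cluster per call. A hedged usage sketch (the GeoJSON polygon, date range and scheduler address are placeholders, and the process instance is assumed to be an already-configured drill):

```python
import json

from dask.distributed import Client
from datacube.utils.geometry import CRS, Geometry

feature = Geometry(
    json.loads(
        '{"type": "Polygon", "coordinates": [[[149.0, -35.3], [149.1, -35.3],'
        ' [149.1, -35.2], [149.0, -35.2], [149.0, -35.3]]]}'
    ),
    crs=CRS("EPSG:4326"),
)
time = ("2020-01-01", "2020-12-31")

process = ...  # an already-configured drill instance, e.g. the new LS_S2_FC_Drill

# Option 1: no client passed - a local Dask cluster is created and closed internally.
result = process.query_handler(time, feature)

# Option 2: hand in a client connected to an external scheduler.  Note that
# query_handler uses it as a context manager, so it is closed on return.
client = Client("tcp://dask-scheduler:8786")
result = process.query_handler(time, feature, dask_client=client)
```
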