Skip to content

Commit

Permalink
change vector rasterization to gdal
Browse files Browse the repository at this point in the history
  • Loading branch information
Emma Ai committed Nov 20, 2024
1 parent 4435edd commit 3d41c9e
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 84 deletions.
50 changes: 34 additions & 16 deletions odc/stats/plugins/_utils.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,42 @@
import dask
import fiona
from rasterio import features


def rasterize_vector_mask(shape_file, transform, dst_shape, threshold=None):
with fiona.open(shape_file) as source_ds:
geoms = [s["geometry"] for s in source_ds]

mask = features.rasterize(
geoms,
transform=transform,
out_shape=dst_shape[1:],
all_touched=False,
fill=0,
default_value=1,
dtype="uint8",
from osgeo import gdal, ogr, osr


def rasterize_vector_mask(
shape_file, transform, dst_shape, filter_expression=None, threshold=None
):
source_ds = ogr.Open(shape_file)
source_layer = source_ds.GetLayer()

if filter_expression is not None:
source_layer.SetAttributeFilter(filter_expression)

yt, xt = dst_shape[1:]
no_data = 0
albers = osr.SpatialReference()
albers.ImportFromEPSG(3577)

geotransform = (
transform.c,
transform.a,
transform.b,
transform.f,
transform.d,
transform.e,
)
target_ds = gdal.GetDriverByName("MEM").Create("", xt, yt, gdal.GDT_Byte)
target_ds.SetGeoTransform(geotransform)
target_ds.SetProjection(albers.ExportToWkt())
mask = target_ds.GetRasterBand(1)
mask.SetNoDataValue(no_data)
gdal.RasterizeLayer(target_ds, [1], source_layer, burn_values=[1])

mask = mask.ReadAsArray()

# used by landcover level3 urban
# if valid area >= threshold
# then the whole tile is valid

if threshold is not None:
if mask.sum() > mask.size * threshold:
return dask.array.ones(dst_shape, name=False)
Expand Down
8 changes: 8 additions & 0 deletions odc/stats/plugins/lc_level34.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class StatsLccsLevel4(StatsPluginInterface):
def __init__(
self,
urban_mask: str = None,
filter_expression: str = None,
mask_threshold: Optional[float] = None,
veg_threshold: Optional[List] = None,
bare_threshold: Optional[List] = None,
Expand All @@ -51,7 +52,12 @@ def __init__(
raise ValueError("Missing urban mask shapefile")
if not os.path.exists(urban_mask):
raise FileNotFoundError(f"{urban_mask} not found")

if filter_expression is None:
raise ValueError("Missing urban mask filter")

self.urban_mask = urban_mask
self.filter_expression = filter_expression
self.mask_threshold = mask_threshold

self.veg_threshold = (
Expand Down Expand Up @@ -90,8 +96,10 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
self.urban_mask,
xx.geobox.transform,
xx.artificial_surface.shape,
filter_expression=self.filter_expression,
threshold=self.mask_threshold,
)

level3 = lc_level3.lc_level3(xx, urban_mask)

# Vegetation cover
Expand Down
67 changes: 67 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
import boto3
from moto import mock_aws
from odc.stats.plugins import register
import json
import tempfile
import os
import fiona
from fiona.crs import CRS

from . import DummyPlugin

TEST_DIR = pathlib.Path(__file__).parent.absolute()
Expand Down Expand Up @@ -148,3 +154,64 @@ def usgs_ls8_sr_definition():
],
}
return definition


@pytest.fixture
def urban_shape():
    """Write a single-polygon GeoJSON file and yield its path.

    The file contains one square polygon with corners (0, 0) and (100, 100),
    written with CRS EPSG:3577 and two properties: ``name`` ("mock") and
    ``value`` (10). It lives in a temporary directory that is deleted when
    the fixture is torn down, so repeated test runs do not leave files behind.
    """
    data = """
    {
        "type": "FeatureCollection",
        "features": [
            {
                "geometry": {
                    "type": "Polygon",
                    "coordinates": [
                        [[0, 0], [0, 100], [100, 100], [100, 0], [0, 0]]
                    ]
                },
                "type": "Feature",
                "properties": {"name": "mock", "value": 10}
            }
        ]
    }
    """
    feature = json.loads(data)["features"][0]

    # TemporaryDirectory instead of mkdtemp: mkdtemp never removes what it
    # creates, which leaked one directory per test that used this fixture.
    with tempfile.TemporaryDirectory() as tmpdir:
        filename = os.path.join(tmpdir, "test.json")
        with fiona.open(
            filename,
            "w",
            driver="GeoJSON",
            crs=CRS.from_epsg(3577),
            schema={
                "geometry": "Polygon",
                "properties": {"name": "str", "value": "int"},
            },
        ) as dst:
            dst.write(feature)
        # Yield inside the context so the file exists for the whole test.
        yield filename
65 changes: 4 additions & 61 deletions tests/test_lc_l34.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@
import pandas as pd
import xarray as xr
import dask.array as da
import json
import tempfile
import os
import fiona
from fiona.crs import CRS
from datacube.utils.geometry import GeoBox
from affine import Affine

Expand All @@ -18,61 +13,6 @@
NODATA = 255


@pytest.fixture(scope="module")
def urban_shape():
    """Write a single-polygon GeoJSON file and yield its path.

    The file contains one square polygon with corners (0, 0) and (100, 100),
    written with CRS EPSG:3577 and no attribute properties. It lives in a
    temporary directory that is deleted once the module's tests have
    finished with the fixture.
    """
    data = """
    {
        "type": "FeatureCollection",
        "features": [
            {
                "geometry": {
                    "type": "Polygon",
                    "coordinates": [
                        [[0, 0], [0, 100], [100, 100], [100, 0], [0, 0]]
                    ]
                },
                "type": "Feature"
            }
        ]
    }
    """
    feature = json.loads(data)["features"][0]

    # TemporaryDirectory instead of mkdtemp: mkdtemp never removes what it
    # creates, so every test session left a stray directory behind.
    with tempfile.TemporaryDirectory() as tmpdir:
        filename = os.path.join(tmpdir, "test.json")
        with fiona.open(
            filename,
            "w",
            driver="GeoJSON",
            crs=CRS.from_epsg(3577),
            schema={
                "geometry": "Polygon",
            },
        ) as dst:
            dst.write(feature)
        # Yield inside the context so the file exists while tests run.
        yield filename


@pytest.fixture(scope="module")
def image_groups():
l34 = np.array(
Expand Down Expand Up @@ -220,7 +160,10 @@ def test_l4_classes(image_groups, urban_shape):

expected_l4 = [[95, 97, 93], [97, 96, 96], [93, 93, 93], [93, 93, 93]]
stats_l4 = StatsLccsLevel4(
measurements=["level3", "level4"], urban_mask=urban_shape, mask_threshold=0.3
measurements=["level3", "level4"],
urban_mask=urban_shape,
filter_expression="mock > 9",
mask_threshold=0.3,
)
ds = stats_l4.reduce(image_groups)

Expand Down
30 changes: 23 additions & 7 deletions tests/test_lc_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import dask.array as da

from odc.stats.plugins.l34_utils import lc_level3
from odc.stats.plugins._utils import rasterize_vector_mask
from datacube.utils.geometry import GeoBox
from affine import Affine

import pytest

NODATA = 255
Expand Down Expand Up @@ -58,18 +62,22 @@ def image_groups():
(np.datetime64("2000-01-01T00"), np.datetime64("2000-01-01")),
]
index = pd.MultiIndex.from_tuples(tuples, names=["time", "solar_day"])
coords = {
"x": np.linspace(10, 20, l34.shape[2]),
"y": np.linspace(0, 5, l34.shape[1]),
}

affine = Affine.translation(10, 0) * Affine.scale(
(20 - 10) / l34.shape[2], (5 - 0) / l34.shape[1]
)
geobox = GeoBox(
crs="epsg:3577", affine=affine, width=l34.shape[2], height=l34.shape[1]
)
coords = geobox.xr_coords()

data_vars = {
"classes_l3_l4": xr.DataArray(
da.from_array(l34, chunks=(1, -1, -1)),
dims=("spec", "y", "x"),
attrs={"nodata": 255},
),
"urban_classes": xr.DataArray(
"artificial_surface": xr.DataArray(
da.from_array(urban, chunks=(1, -1, -1)),
dims=("spec", "y", "x"),
attrs={"nodata": 255},
Expand All @@ -85,7 +93,15 @@ def image_groups():
return xx


def test_l3_classes(image_groups, urban_shape):
    """Check level-3 classes computed with a rasterized urban mask."""
    # The urban_shape feature carries properties name="mock" and value=10.
    # The filter must name the numeric field ("value"); "mock" is only the
    # string stored in "name" and is not a field OGR can filter on.
    filter_expression = "value > 9"
    urban_mask = rasterize_vector_mask(
        urban_shape,
        image_groups.geobox.transform,
        image_groups.artificial_surface.shape,
        filter_expression=filter_expression,
        threshold=0.3,
    )

    level3_classes = lc_level3.lc_level3(image_groups, urban_mask)
    assert (level3_classes == expected_l3_classes).all()

0 comments on commit 3d41c9e

Please sign in to comment.