Skip to content

Commit

Permalink
Apply urban mask on landcover level3 (#172)
Browse files Browse the repository at this point in the history
* add urban mask for landcover

* change vector rasterization to gdal

* correct plugin name

* remote vector file validation

* deconvolute landcover l34 tests

* fix the urban mask condition in level3

---------

Co-authored-by: Emma Ai <[email protected]>
  • Loading branch information
emmaai and Emma Ai authored Nov 25, 2024
1 parent 37fd7fa commit 50fdef1
Show file tree
Hide file tree
Showing 13 changed files with 267 additions and 106 deletions.
44 changes: 44 additions & 0 deletions odc/stats/plugins/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import dask
from osgeo import gdal, ogr, osr


def rasterize_vector_mask(
    shape_file, transform, dst_shape, filter_expression=None, threshold=None
):
    """
    Burn the features of a vector file into a 0/1 raster mask.

    :param shape_file: path/URI of the vector dataset, opened with ``ogr.Open``
    :param transform: affine transform of the target grid; its ``a``..``f``
        coefficients are reordered into a GDAL geotransform
    :param dst_shape: shape of the returned array; its last two entries are
        taken as ``(rows, columns)`` of the raster
    :param filter_expression: optional OGR attribute filter applied to the
        layer before rasterizing
    :param threshold: optional fraction of the tile; when the number of
        burnt pixels is strictly greater than ``threshold * pixel count``
        the whole tile is reported as valid (all ones)
    :return: ``uint8`` dask array of shape ``dst_shape`` with 1 inside the
        features and 0 elsewhere
    :raises RuntimeError: if the vector dataset cannot be opened
    """
    source_ds = ogr.Open(shape_file)
    if source_ds is None:
        # Without GDAL exceptions enabled ogr.Open signals failure by
        # returning None; fail loudly here rather than with an
        # AttributeError on the next line.
        raise RuntimeError(f"Unable to open vector file {shape_file}")
    source_layer = source_ds.GetLayer()

    if filter_expression is not None:
        source_layer.SetAttributeFilter(filter_expression)

    rows, cols = dst_shape[-2:]
    no_data = 0
    # The target raster is always declared as EPSG:3577.
    albers = osr.SpatialReference()
    albers.ImportFromEPSG(3577)

    # GDAL order: (x origin, x pixel size, row rotation,
    #              y origin, column rotation, y pixel size)
    geotransform = (
        transform.c,
        transform.a,
        transform.b,
        transform.f,
        transform.d,
        transform.e,
    )
    target_ds = gdal.GetDriverByName("MEM").Create("", cols, rows, gdal.GDT_Byte)
    target_ds.SetGeoTransform(geotransform)
    target_ds.SetProjection(albers.ExportToWkt())
    band = target_ds.GetRasterBand(1)
    band.SetNoDataValue(no_data)
    gdal.RasterizeLayer(target_ds, [1], source_layer, burn_values=[1])

    mask = band.ReadAsArray()

    # used by landcover level3 urban:
    # if the valid area exceeds the threshold fraction of the tile
    # then the whole tile is valid
    if threshold is not None and mask.sum() > mask.size * threshold:
        # keep the dtype identical to the rasterized (GDT_Byte) mask
        return dask.array.ones(dst_shape, dtype="uint8", name=False)

    return dask.array.from_array(mask.reshape(dst_shape), name=False)
File renamed without changes.
19 changes: 16 additions & 3 deletions odc/stats/plugins/l34_utils/lc_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
NODATA = 255


def lc_level3(xx: xr.Dataset):
def lc_level3(xx: xr.Dataset, urban_mask):

# Cultivated pipeline applies a mask which feeds only terrestrial veg (110) to the model
# Just exclude no data (255 or nan) and apply the cultivated results
Expand All @@ -23,16 +23,29 @@ def lc_level3(xx: xr.Dataset):
# Mask urban results with bare sfc (210)

res = expr_eval(
"where(a==_u, b, a)",
"where((a==_u), b, a)",
{
"a": res,
"b": xx.urban_classes.data,
"b": xx.artificial_surface.data,
},
name="mark_urban",
dtype="float32",
**{"_u": 210},
)

# Enforce non-urban mask area to be n/artificial (216)

res = expr_eval(
"where((b<=0)&(a==_u), _nu, a)",
{
"a": res,
"b": urban_mask,
},
name="mask_non_urban",
dtype="float32",
**{"_u": 215, "_nu": 216},
)

# Mark nodata to 255 in case any nan
res = expr_eval(
"where(a==a, a, nodata)",
Expand Down
31 changes: 29 additions & 2 deletions odc/stats/plugins/lc_level34.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import xarray as xr

from ._registry import StatsPluginInterface, register
from ._utils import rasterize_vector_mask
from osgeo import gdal

from .l34_utils import (
l4_water_persistence,
Expand All @@ -33,12 +35,28 @@ class StatsLccsLevel4(StatsPluginInterface):

def __init__(
self,
urban_mask: str = None,
filter_expression: str = None,
mask_threshold: Optional[float] = None,
veg_threshold: Optional[List] = None,
bare_threshold: Optional[List] = None,
watper_threshold: Optional[List] = None,
**kwargs,
):
super().__init__(**kwargs)
if urban_mask is None:
raise ValueError("Missing urban mask shapefile")

file_metadata = gdal.VSIStatL(urban_mask)
if file_metadata is None:
raise FileNotFoundError(f"{urban_mask} not found")

if filter_expression is None:
raise ValueError("Missing urban mask filter")

self.urban_mask = urban_mask
self.filter_expression = filter_expression
self.mask_threshold = mask_threshold

self.veg_threshold = (
veg_threshold if veg_threshold is not None else [1, 4, 15, 40, 65, 100]
Expand All @@ -51,6 +69,7 @@ def __init__(
def fuser(self, xx):
    """
    Return the input unchanged; this plugin does no fusing.
    """
    return xx

# pylint: disable=too-many-locals
def reduce(self, xx: xr.Dataset) -> xr.Dataset:

# Water persistence
Expand All @@ -62,7 +81,15 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
l4 = l4_water.water_classification(xx, water_persistence)

# Generate Level3 classes
level3 = lc_level3.lc_level3(xx)
urban_mask = rasterize_vector_mask(
self.urban_mask,
xx.geobox.transform,
xx.artificial_surface.shape,
filter_expression=self.filter_expression,
threshold=self.mask_threshold,
)

level3 = lc_level3.lc_level3(xx, urban_mask)

# Vegetation cover
veg_cover = l4_veg_cover.canopyco_veg_con(xx, self.veg_threshold)
Expand Down Expand Up @@ -98,4 +125,4 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
return leve34


register("lc_l3_l4", StatsLccsLevel4)
register("lccs_level34", StatsLccsLevel4)
21 changes: 2 additions & 19 deletions odc/stats/plugins/mangroves.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
import dask
import os
from odc.algo import keep_good_only, erase_bad
import fiona
from rasterio import features

from ._registry import StatsPluginInterface, register
from ._utils import rasterize_vector_mask

NODATA = 255

Expand Down Expand Up @@ -45,22 +44,6 @@ def measurements(self) -> Tuple[str, ...]:
_measurements = ["canopy_cover_class"]
return _measurements

def rasterize_mangroves_extent(self, shape_file, transform, dst_shape):
    """
    Rasterize every geometry of ``shape_file`` into a uint8 extent mask.

    Pixels covered by a geometry are set to 1 and all others to 0; the
    result is reshaped to ``dst_shape`` and wrapped in a dask array.
    """
    geometries = []
    with fiona.open(shape_file) as vector_src:
        for record in vector_src:
            geometries.append(record["geometry"])

    # The raster grid is given by the trailing dimensions of dst_shape.
    raster_shape = dst_shape[1:]
    burnt = features.rasterize(
        geometries,
        out_shape=raster_shape,
        transform=transform,
        fill=0,
        default_value=1,
        all_touched=False,
        dtype="uint8",
    )

    extent = burnt.reshape(dst_shape)
    return dask.array.from_array(extent, name=False)

def fuser(self, xx):
"""
no fuse required for mangroves since group by none
Expand All @@ -73,7 +56,7 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
mangroves computation here
it is not a 'reduce' though
"""
extent_mask = self.rasterize_mangroves_extent(
extent_mask = rasterize_vector_mask(
self.mangroves_extent, xx.geobox.transform, xx.pv_pc_10.shape
)
good_data = extent_mask == 1
Expand Down
77 changes: 77 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
import boto3
from moto import mock_aws
from odc.stats.plugins import register
import json
import tempfile
import os
import fiona
from fiona.crs import CRS

from . import DummyPlugin

TEST_DIR = pathlib.Path(__file__).parent.absolute()
Expand Down Expand Up @@ -148,3 +154,74 @@ def usgs_ls8_sr_definition():
],
}
return definition


@pytest.fixture
def urban_shape():
    """
    Yield the path of a temporary GeoJSON file holding one square polygon.

    The polygon spans (0, 0)-(100, 100) in EPSG:3577 and carries the
    properties ``name="mock"`` and ``value=10``.  The file lives in a
    temporary directory that is removed once the requesting test finishes.
    """
    collection = json.loads(
        """
        {
            "type": "FeatureCollection",
            "features": [
                {
                    "type": "Feature",
                    "geometry": {
                        "type": "Polygon",
                        "coordinates": [
                            [[0, 0], [0, 100], [100, 100], [100, 0], [0, 0]]
                        ]
                    },
                    "properties": {"name": "mock", "value": 10}
                }
            ]
        }
        """
    )
    feature = collection["features"][0]

    # TemporaryDirectory cleans up after the test, unlike a bare mkdtemp().
    with tempfile.TemporaryDirectory() as tmpdir:
        filename = os.path.join(tmpdir, "test.json")
        with fiona.open(
            filename,
            "w",
            driver="GeoJSON",
            crs=CRS.from_epsg(3577),
            schema={
                "geometry": "Polygon",
                "properties": {"name": "str", "value": "int"},
            },
        ) as dst:
            dst.write(feature)
        yield filename


@pytest.fixture()
def veg_threshold():
    """Vegetation cover thresholds passed to the level 3/4 tests."""
    thresholds = [1, 4, 15, 40, 65, 100]
    return thresholds


@pytest.fixture()
def watper_threshold():
    """Water persistence thresholds passed to the level 3/4 tests."""
    thresholds = [1, 4, 7, 10]
    return thresholds
26 changes: 19 additions & 7 deletions tests/test_lc_l34.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import pandas as pd
import xarray as xr
import dask.array as da
from datacube.utils.geometry import GeoBox
from affine import Affine


import pytest

Expand Down Expand Up @@ -112,18 +115,22 @@ def image_groups():
(np.datetime64("2000-01-01T00"), np.datetime64("2000-01-01")),
]
index = pd.MultiIndex.from_tuples(tuples, names=["time", "solar_day"])
coords = {
"x": np.linspace(10, 20, l34.shape[2]),
"y": np.linspace(0, 5, l34.shape[1]),
}

affine = Affine.translation(10, 0) * Affine.scale(
(20 - 10) / l34.shape[2], (5 - 0) / l34.shape[1]
)
geobox = GeoBox(
crs="epsg:3577", affine=affine, width=l34.shape[2], height=l34.shape[1]
)
coords = geobox.xr_coords()

data_vars = {
"level_3_4": xr.DataArray(
da.from_array(l34, chunks=(1, -1, -1)),
dims=("spec", "y", "x"),
attrs={"nodata": 255},
),
"urban_classes": xr.DataArray(
"artificial_surface": xr.DataArray(
da.from_array(urban, chunks=(1, -1, -1)),
dims=("spec", "y", "x"),
attrs={"nodata": 255},
Expand Down Expand Up @@ -165,11 +172,16 @@ def image_groups():
return xx


def test_l4_classes(image_groups):
def test_l4_classes(image_groups, urban_shape):
expected_l3 = [[216, 216, 215], [216, 216, 216], [220, 215, 215], [220, 220, 220]]

expected_l4 = [[95, 97, 93], [97, 96, 96], [100, 93, 93], [101, 101, 101]]
stats_l4 = StatsLccsLevel4(measurements=["level3", "level4"])
stats_l4 = StatsLccsLevel4(
measurements=["level3", "level4"],
urban_mask=urban_shape,
filter_expression="mock > 9",
mask_threshold=0.3,
)
ds = stats_l4.reduce(image_groups)

assert (ds.level3.compute() == expected_l3).all()
Expand Down
Loading

0 comments on commit 50fdef1

Please sign in to comment.