From f1f2771f3fd9de1b68a087740deb5247ba190e79 Mon Sep 17 00:00:00 2001 From: Emma Ai Date: Thu, 22 Aug 2024 00:11:16 +0000 Subject: [PATCH] combine dask ops --- docker/requirements.txt | 2 +- odc/stats/plugins/lc_treelite_cultivated.py | 23 +++++++-------------- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/docker/requirements.txt b/docker/requirements.txt index e74eb8fd..757e188c 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -11,7 +11,7 @@ odc-dscache>=0.2.3 odc-stac @ git+https://github.com/opendatacube/odc-stac@69bdf64 # odc-stac is in PyPI -odc-stats[ows] @ git+https://github.com/opendatacube/odc-stats@d7307d9 +odc-stats[ows] @ git+https://github.com/opendatacube/odc-stats@a9bdd82 # For ML tflite-runtime diff --git a/odc/stats/plugins/lc_treelite_cultivated.py b/odc/stats/plugins/lc_treelite_cultivated.py index 955daca6..76ceeb28 100644 --- a/odc/stats/plugins/lc_treelite_cultivated.py +++ b/odc/stats/plugins/lc_treelite_cultivated.py @@ -227,6 +227,12 @@ def generate_features(input_block, bands_indices): return output_block +def cultivated_predict(input_block, bands_indices): + feature_block = generate_features(input_block, bands_indices) + cc = mask_and_predict(feature_block, ptype="categorical", nodata=NODATA) + return cc + + class StatsCultivatedClass(StatsMLTree): NAME = "ga_ls_cultivated" SHORT_NAME = NAME @@ -240,23 +246,10 @@ def measurements(self) -> Tuple[str, ...]: def predict(self, input_array): bands_indices = dict(zip(self.input_bands, np.arange(len(self.input_bands)))) - input_features = da.map_blocks( - generate_features, + cc = da.map_blocks( + cultivated_predict, input_array, bands_indices, - chunks=( - input_array.chunks[0], - input_array.chunks[1], - 15 + len(bands_indices) - bands_indices["bcdev"] - 1, - ), - dtype="float32", - name="generate_features", - ) - cc = da.map_blocks( - mask_and_predict, - input_features, - ptype="categorical", - nodata=NODATA, drop_axis=-1, dtype="float32", name="cultivated_predict",