diff --git a/odc/stats/plugins/_registry.py b/odc/stats/plugins/_registry.py index 3f428d8..f9d8c64 100644 --- a/odc/stats/plugins/_registry.py +++ b/odc/stats/plugins/_registry.py @@ -40,7 +40,6 @@ def import_all(): # TODO: make that more automatic modules = [ "odc.stats.plugins.lc_treelite_cultivated", - "odc.stats.plugins.lc_level3", "odc.stats.plugins.lc_treelite_woody", "odc.stats.plugins.lc_tf_urban", "odc.stats.plugins.lc_level34", diff --git a/odc/stats/plugins/_utils.py b/odc/stats/plugins/_utils.py index 06f218a..b4deab9 100644 --- a/odc/stats/plugins/_utils.py +++ b/odc/stats/plugins/_utils.py @@ -1,3 +1,5 @@ +import re +import operator import dask from osgeo import gdal, ogr, osr @@ -42,3 +44,99 @@ def rasterize_vector_mask( return dask.array.ones(dst_shape, name=False) return dask.array.from_array(mask.reshape(dst_shape), name=False) + + +OPERATORS = { + ">": operator.gt, + ">=": operator.ge, + "<": operator.lt, + "<=": operator.le, + "==": operator.eq, + "!=": operator.ne, +} + +BRACKETS = { + "[": operator.ge, # Inclusive lower bound + "(": operator.gt, # Exclusive lower bound + "]": operator.le, # Inclusive upper bound + ")": operator.lt, # Exclusive upper bound +} + + +def parse_rule(rule): + """ + Parse a single condition or range condition. + Supports range notations like '[]', '[)', '(]', and '()', + and treats standalone numbers as '=='. + """ + # Special case for 255 (rule doesn't apply) + if (rule == "255") | (rule == "nan"): + return None + + # Check for range conditions like '[a, b)' or '(a, b]' + range_pattern = r"([\[\(])(-?\d+\.?\d*),\s*(-?\d+\.?\d*)([\]\)])" + match = re.match(range_pattern, rule) + if match: + # Extract the bounds and the bracket types + lower_bracket, lower_value, upper_value, upper_bracket = match.groups() + return [ + (BRACKETS[lower_bracket], float(lower_value)), + (BRACKETS[upper_bracket], float(upper_value)), + ] + + ordered_operators = sorted(OPERATORS.items(), key=lambda x: -len(x[0])) + + # Single condition (no range notation, no explicit operator) + for op_str, op_func in ordered_operators: + if op_str in rule: + value = float(rule.replace(op_str, "").strip()) + return [(op_func, value)] + + # Default to equality (==) if no operator is found + return [(operator.eq, int(rule.strip()))] + + +def generate_numexpr_expressions(rules_df, final_class_column, previous): + """ + Generate a list of numexpr-compatible expressions for classification rules. + :param rules_df: DataFrame containing the classification rules + :param final_class_column: Name of the column containing the final class values + :return: List of expressions (one for each rule) + """ + expressions = [] + + for _, rules in rules_df.iterrows(): + conditions = [] + + for col in rules.index: + if col == final_class_column: + continue + subconditions = parse_rule(rules[col]) + if subconditions is None: # Skip rule if it's None + continue + for op_func, value in subconditions: + if op_func is operator.eq: + conditions.append(f"({col}=={value})") + elif op_func is operator.gt: + conditions.append(f"({col}>{value})") + elif op_func is operator.ge: + conditions.append(f"({col}>={value})") + elif op_func is operator.lt: + conditions.append(f"({col}<{value})") + elif op_func is operator.le: + conditions.append(f"({col}<={value})") + elif op_func is operator.ne: + conditions.append(f"({col}!={value})") + + if not conditions: + continue + + condition = "&".join(conditions) + + final_class = rules[final_class_column] + expressions.append(f"where({condition}, {final_class}, {previous})") + + expressions = list(set(expressions)) + expressions = sorted(expressions, key=len) + + return expressions diff --git a/odc/stats/plugins/l34_utils/__init__.py b/odc/stats/plugins/l34_utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/odc/stats/plugins/l34_utils/l4_bare_gradation.py b/odc/stats/plugins/l34_utils/l4_bare_gradation.py deleted file mode 100644 index 45e0880..0000000 --- a/odc/stats/plugins/l34_utils/l4_bare_gradation.py +++ /dev/null @@ -1,46 +0,0 @@ -import xarray as xr -from odc.stats._algebra import expr_eval - - -NODATA = 255 - - -def bare_gradation(xx: xr.Dataset, bare_threshold, veg_cover): - - # Address nodata - bs_pc_50 = expr_eval( - "where((a!=a), nodata, a)", - {"a": xx.bs_pc_50.data}, - name="mark_bare_gradation_nodata", - dtype="float32", - **{"nodata": NODATA}, - ) - - # 60% <= data --> 15 - bs_mask = expr_eval( - "where((a>=m)&(a!=nodata), 15, a)", - {"a": bs_pc_50}, - name="mark_bare", - dtype="uint8", - **{"m": bare_threshold[1], "nodata": NODATA}, - ) - - # 20% <= data < 60% --> 12 - bs_mask = expr_eval( - "where((a>=m)&(a 10 - bs_mask = expr_eval( - "where(a woody - # everything else -> herbaceous - - water_seasonality = expr_eval( - "where((a==a), a, nodata)", - { - "a": water_season, - }, - name="mark_water_season", - dtype="float32", - **{"nodata": NODATA}, - ) - - res = expr_eval( - "where((a==124), 56, a)", - { - "a": l4, - }, - name="mark_woody", - dtype="uint8", - ) - - res = expr_eval( - "where((a==125), 57, a)", - { - "a": res, - }, - name="mark_herbaceous", - dtype="uint8", - ) - - # res = expr_eval( - # "where((a!=124)|(a!=125), 255, a)", - # { - # "a": res, - # }, - # name="mark_nodata", - # dtype="uint8", - # ) - - # mark water season - # use some value not used in final class - res = expr_eval( - "where((a==56)&(b==1), 254, a)", - { - "a": res, - "b": water_seasonality, - }, - name="mark_water_season", - dtype="uint8", - ) - - res = expr_eval( - "where((a==56)&(b==2), 253, a)", - { - "a": res, - "b": water_seasonality, - }, - name="mark_water_season", - dtype="uint8", - ) - - res = expr_eval( - "where((a==57)&(b==1), 252, a)", - { - "a": res, - "b": water_seasonality, - }, - name="mark_water_season", - dtype="uint8", - ) - - res = expr_eval( - "where((a==57)&(b==2), 251, a)", - { - "a": res, - "b": water_seasonality, - }, - name="mark_water_season", - dtype="uint8", - ) - - # mark final - - res = expr_eval( - "where((a==254)&(b==10), 64, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - - res = expr_eval( - "where((a==253)&(b==10), 65, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - - res = expr_eval( - "where((a==252)&(b==10), 79, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - - res = expr_eval( - "where((a==251)&(b==10), 80, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - ######################################### - res = expr_eval( - "where((a==254)&(b==12), 67, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - - res = expr_eval( - "where((a==253)&(b==12), 68, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - - res = expr_eval( - "where((a==252)&(b==12), 82, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - - res = expr_eval( - "where((a==251)&(b==12), 83, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - ########################################## - res = expr_eval( - "where((a==254)&(b==13), 70, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - - res = expr_eval( - "where((a==253)&(b==13), 71, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - - res = expr_eval( - "where((a==252)&(b==13), 85, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - - res = expr_eval( - "where((a==251)&(b==13), 86, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - ######################################### - - res = expr_eval( - "where((a==254)&(b==15), 73, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - - res = expr_eval( - "where((a==253)&(b==15), 74, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - - res = expr_eval( - "where((a==252)&(b==15), 88, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - - res = expr_eval( - "where((a==251)&(b==15), 89, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - ########################################## - res = expr_eval( - "where((a==254)&(b==16), 76, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - - res = expr_eval( - "where((a==253)&(b==16), 77, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - - res = expr_eval( - "where((a==252)&(b==16), 91, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - - res = expr_eval( - "where((a==251)&(b==16), 92, a)", - { - "a": res, - "b": veg_cover, - }, - name="mark_final", - dtype="uint8", - ) - - # There are cases where a tile falls over water. - # In these cases, the PC will have no data so we map back 251-254 to their corresponding classes - res = expr_eval( - "where((a>=251)&(a<=252), 57, a)", - { - "a": res, - }, - name="mark_herbaceous", - dtype="uint8", - ) - res = expr_eval( - "where((a>=253)&(a<=254), 56, a)", - { - "a": res, - }, - name="mark_woody", - dtype="uint8", - ) - - return res diff --git a/odc/stats/plugins/l34_utils/l4_natural_veg.py b/odc/stats/plugins/l34_utils/l4_natural_veg.py deleted file mode 100644 index 8a03999..0000000 --- a/odc/stats/plugins/l34_utils/l4_natural_veg.py +++ /dev/null @@ -1,133 +0,0 @@ -from odc.stats._algebra import expr_eval - -NODATA = 255 - - -def lc_l4_natural_veg(l4, l3, woody, veg_cover): - - woody = expr_eval( - "where((a!=a), nodata, a)", - {"a": woody.data}, - name="mask_woody_nodata", - dtype="float32", - **{"nodata": NODATA}, - ) - - l4 = expr_eval( - "where((b==nodata), nodata, a)", - {"a": l4, "b": l3}, - name="mark_cultivated", - dtype="uint8", - **{"nodata": NODATA}, - ) - - l4 = expr_eval( - "where((a==112)&(b==113), 20, d)", - {"a": l3, "b": woody, "d": l4}, - name="mark_cultivated", - dtype="uint8", - ) - - l4 = expr_eval( - "where((a==112)&(b==114), 21, d)", - {"a": l3, "b": woody, "d": l4}, - name="mark_cultivated", - dtype="uint8", - ) - - l4 = expr_eval( - "where((a==112)&(c==10), 22, d)", - {"a": l3, "c": veg_cover, "d": l4}, - name="mark_cultivated", - dtype="uint8", - ) - l4 = expr_eval( - "where((a==112)&(c==12), 23, d)", - {"a": l3, "c": veg_cover, "d": l4}, - name="mark_cultivated", - dtype="uint8", - ) - l4 = expr_eval( - "where((a==112)&(c==13), 24, d)", - {"a": l3, "c": veg_cover, "d": l4}, - name="mark_cultivated", - dtype="uint8", - ) - l4 = expr_eval( - "where((a==112)&(c==15), 25, d)", - {"a": l3, "c": veg_cover, "d": l4}, - name="mark_cultivated", - dtype="uint8", - ) - l4 = expr_eval( - "where((a==112)&(c==16), 26, d)", - {"a": l3, "c": veg_cover, "d": l4}, - name="mark_cultivated", - dtype="uint8", - ) - - l4 = expr_eval( - "where((a==112)&(c==10)&(b==113), 27, d)", - {"a": l3, "b": woody, "c": veg_cover, "d": l4}, - name="mark_cultivated", - dtype="uint8", - ) - l4 = expr_eval( - "where((a==112)&(c==12)&(b==113), 28, d)", - {"a": l3, "b": woody, "c": veg_cover, "d": l4}, - name="mark_cultivated", - dtype="uint8", - ) - l4 = expr_eval( - "where((a==112)&(c==13)&(b==113), 29, d)", - {"a": l3, "b": woody, "c": veg_cover, "d": l4}, - name="mark_cultivated", - dtype="uint8", - ) - l4 = expr_eval( - "where((a==112)&(c==15)&(b==113), 30, d)", - {"a": l3, "b": woody, "c": veg_cover, "d": l4}, - name="mark_cultivated", - dtype="uint8", - ) - - l4 = expr_eval( - "where((a==112)&(c==16)&(b==113), 31, d)", - {"a": l3, "b": woody, "c": veg_cover, "d": l4}, - name="mark_cultivated", - dtype="uint8", - ) - - l4 = expr_eval( - "where((a==112)&(c==10)&(b==114), 32, d)", - {"a": l3, "b": woody, "c": veg_cover, "d": l4}, - name="mark_cultivated", - dtype="uint8", - ) - l4 = expr_eval( - "where((a==112)&(c==12)&(b==114), 33, d)", - {"a": l3, "b": woody, "c": veg_cover, "d": l4}, - name="mark_cultivated", - dtype="uint8", - ) - l4 = expr_eval( - "where((a==112)&(c==13)&(b==114), 34, d)", - {"a": l3, "b": woody, "c": veg_cover, "d": l4}, - name="mark_cultivated", - dtype="uint8", - ) - l4 = expr_eval( - "where((a==112)&(c==15)&(b==114), 35, d)", - {"a": l3, "b": woody, "c": veg_cover, "d": l4}, - name="mark_cultivated", - dtype="uint8", - ) - - l4 = expr_eval( - "where((a==112)&(c==16)&(b==114), 36, d)", - {"a": l3, "b": woody, "c": veg_cover, "d": l4}, - name="mark_cultivated", - dtype="uint8", - ) - - return l4 diff --git a/odc/stats/plugins/l34_utils/l4_surface.py b/odc/stats/plugins/l34_utils/l4_surface.py deleted file mode 100644 index fcf640b..0000000 --- a/odc/stats/plugins/l34_utils/l4_surface.py +++ /dev/null @@ -1,37 +0,0 @@ -from odc.stats._algebra import expr_eval - - -def lc_l4_surface(l4, level3, bare_gradation): - - l4 = expr_eval( - "where((c==215), 93, a)", - {"a": l4, "c": level3}, - name="mark_surface", - dtype="uint8", - ) - l4 = expr_eval( - "where((c==216), 94, a)", - {"a": l4, "c": level3}, - name="mark_surface", - dtype="uint8", - ) - l4 = expr_eval( - "where((b==10)&(c==216), 95, a)", - {"a": l4, "b": bare_gradation, "c": level3}, - name="mark_surface", - dtype="uint8", - ) - l4 = expr_eval( - "where((b==12)&(c==216), 96, a)", - {"a": l4, "b": bare_gradation, "c": level3}, - name="mark_surface", - dtype="uint8", - ) - l4 = expr_eval( - "where((b==15)&(c==216), 97, a)", - {"a": l4, "b": bare_gradation, "c": level3}, - name="mark_surface", - dtype="uint8", - ) - - return l4 diff --git a/odc/stats/plugins/l34_utils/l4_veg_cover.py b/odc/stats/plugins/l34_utils/l4_veg_cover.py deleted file mode 100644 index 60bc820..0000000 --- a/odc/stats/plugins/l34_utils/l4_veg_cover.py +++ /dev/null @@ -1,90 +0,0 @@ -# from typing import Tuple, Optional, Dict, List -import xarray as xr -from odc.stats._algebra import expr_eval - -NODATA = 255 - - -def canopyco_veg_con(xx: xr.Dataset, veg_threshold): - - # Mask NODATA - pv_pc_50 = expr_eval( - "where(a==a, a, nodata)", - {"a": xx.pv_pc_50.data}, - name="mark_nodata", - dtype="float32", - **{"nodata": NODATA}, - ) - - # data < 1 ---> 0 - veg_mask = expr_eval( - "where(a 16 - veg_mask = expr_eval( - "where((a>=m)&(a 15 - veg_mask = expr_eval( - "where((a>=m)&(a 13 - veg_mask = expr_eval( - "where((a>=m)&(a 12 - veg_mask = expr_eval( - "where((a>=m)&(a 10 - veg_mask = expr_eval( - "where((a>=m)&(a<=n), 10, b)", - { - "a": pv_pc_50, - "b": veg_mask, - }, - name="mark_veg", - dtype="uint8", - **{"m": veg_threshold[4], "n": veg_threshold[5]}, - ) - - return veg_mask diff --git a/odc/stats/plugins/l34_utils/l4_water.py b/odc/stats/plugins/l34_utils/l4_water.py deleted file mode 100644 index c93ecce..0000000 --- a/odc/stats/plugins/l34_utils/l4_water.py +++ /dev/null @@ -1,65 +0,0 @@ -from odc.stats._algebra import expr_eval - -NODATA = 255 - - -def water_classification(xx, water_persistence): - - # Replace nan with nodata - l4 = expr_eval( - "where((a==a), a, nodata)", - {"a": xx.level_3_4.data}, - name="mark_water", - dtype="uint8", - **{"nodata": NODATA}, - ) - - l4 = expr_eval( - "where((a==223)|(a==221), 98, a)", {"a": l4}, name="mark_water", dtype="uint8" - ) - - l4 = expr_eval( - "where((a==98)&(b!=_u), 99, a)", - {"a": l4, "b": xx.level_3_4.data}, - name="mark_water", - dtype="uint8", - **{"_u": 223}, - ) - - l4 = expr_eval( - "where((a==98)&(b==_u), 100, a)", - {"a": l4, "b": xx.level_3_4.data}, - name="mark_water", - dtype="uint8", - **{"_u": 223}, - ) - - l4 = expr_eval( - "where((a==99)&(b==1), 101, a)", - {"a": l4, "b": water_persistence}, - name="mark_water", - dtype="uint8", - ) - - l4 = expr_eval( - "where((a==99)&(b==7), 102, a)", - {"a": l4, "b": water_persistence}, - name="mark_water", - dtype="uint8", - ) - - l4 = expr_eval( - "where((a==99)&(b==8), 103, a)", - {"a": l4, "b": water_persistence}, - name="mark_water", - dtype="uint8", - ) - - l4 = expr_eval( - "where((a==99)&(b==9), 104, a)", - {"a": l4, "b": water_persistence}, - name="mark_water", - dtype="uint8", - ) - - return l4 diff --git a/odc/stats/plugins/l34_utils/l4_water_persistence.py b/odc/stats/plugins/l34_utils/l4_water_persistence.py deleted file mode 100644 index fdab934..0000000 --- a/odc/stats/plugins/l34_utils/l4_water_persistence.py +++ /dev/null @@ -1,56 +0,0 @@ -import xarray as xr - -from odc.stats._algebra import expr_eval - -NODATA = 255 -WATER_FREQ_NODATA = -999 - - -def water_persistence(xx: xr.Dataset, watper_threshold): - - # Address nan - water_frequency = expr_eval( - "where((a!=a), nodata, a)", - {"a": xx.water_frequency.data}, - name="mark_water", - dtype="float32", - **{"nodata": NODATA}, - ) - - # 10 <= water_frequency < 1 --> 1 - water_mask = expr_eval( - "where((a>=m)&(a!=nodata), 1, a)", - {"a": water_frequency}, - name="mark_water", - dtype="uint8", - **{"m": watper_threshold[3], "nodata": NODATA}, - ) - - # 7 <= water_frequency < 10 --> 7 - water_mask = expr_eval( - "where((a>=m)&(a 8 - water_mask = expr_eval( - "where((a>=m)&(a 9 - water_mask = expr_eval( - "where((a>=m)&(a=nodata), b, a)", - {"a": xx.cultivated.data, "b": xx.level_3_4.data}, - name="mask_cultivated", - dtype="float32", - **{"nodata": xx.cultivated.attrs.get("nodata")}, - ) - - # Mask urban results with bare sfc (210) - - res = expr_eval( - "where((a==_u), b, a)", - { - "a": res, - "b": xx.artificial_surface.data, - }, - name="mark_urban", - dtype="float32", - **{"_u": 210}, - ) - - # Enforce non-urban mask area to be n/artificial (216) - - res = expr_eval( - "where((b<=0)&(a==_u), _nu, a)", - { - "a": res, - "b": urban_mask, - }, - name="mask_non_urban", - dtype="float32", - **{"_u": 215, "_nu": 216}, - ) - - # Mark nodata to 255 in case any nan - res = expr_eval( - "where(a==a, a, nodata)", - { - "a": res, - }, - name="mark_nodata", - dtype="uint8", - **{"nodata": NODATA}, - ) - # Add intertidal as water - res = expr_eval( - "where((a==223)|(a==221), 220, b)", - {"a": xx.level_3_4.data, "b": res}, - name="mark_urban", - dtype="uint8", - ) - - # Combine woody and herbaceous aquatic vegetation - res = expr_eval( - "where((a==124)|(a==125), 124, b)", - {"a": xx.level_3_4.data, "b": res}, - name="mark_aquatic_veg", - dtype="uint8", - ) - - return res diff --git a/odc/stats/plugins/lc_level34.py b/odc/stats/plugins/lc_level34.py index 00f4472..edbdf98 100644 --- a/odc/stats/plugins/lc_level34.py +++ b/odc/stats/plugins/lc_level34.py @@ -2,28 +2,19 @@ Plugin of Module A3 in LandCover PipeLine """ -from typing import Optional, List +from typing import Optional, Dict -import numpy as np import xarray as xr +import s3fs +import os +import pandas as pd +import dask.array as da from ._registry import StatsPluginInterface, register -from ._utils import rasterize_vector_mask +from ._utils import rasterize_vector_mask, generate_numexpr_expressions +from odc.stats._algebra import expr_eval from osgeo import gdal -from .l34_utils import ( - l4_water_persistence, - l4_veg_cover, - lc_level3, - l4_cultivated, - l4_natural_veg, - l4_natural_aquatic, - l4_surface, - l4_bare_gradation, - l4_water, -) - - NODATA = 255 @@ -35,15 +26,23 @@ class StatsLccsLevel4(StatsPluginInterface): def __init__( self, + class_def_path: str = None, urban_mask: str = None, filter_expression: str = None, mask_threshold: Optional[float] = None, - veg_threshold: Optional[List] = None, - bare_threshold: Optional[List] = None, - watper_threshold: Optional[List] = None, + data_var_condition: Optional[Dict] = None, **kwargs, ): super().__init__(**kwargs) + if class_def_path is None: + raise ValueError("Missing level34 class definition csv") + + if class_def_path.startswith("s3://"): + if not s3fs.S3FileSystem().exists(class_def_path): + raise FileNotFoundError(f"{class_def_path} not found") + elif not os.path.exists(class_def_path): + raise FileNotFoundError(f"{class_def_path} not found") + if urban_mask is None: raise ValueError("Missing urban mask shapefile") @@ -54,33 +53,57 @@ def __init__( if filter_expression is None: raise ValueError("Missing urban mask filter") + self.class_def = pd.read_csv(class_def_path) + cols = list(self.class_def.columns[:6]) + list(self.class_def.columns[9:-6]) + self.class_def = self.class_def[cols].astype(str).fillna("nan") + self.urban_mask = urban_mask self.filter_expression = filter_expression self.mask_threshold = mask_threshold - - self.veg_threshold = ( - veg_threshold if veg_threshold is not None else [1, 4, 15, 40, 65, 100] - ) - self.bare_threshold = bare_threshold if bare_threshold is not None else [20, 60] - self.watper_threshold = ( - watper_threshold if watper_threshold is not None else [1, 4, 7, 10] + self.data_var_condition = ( + {} if data_var_condition is None else data_var_condition ) def fuser(self, xx): return xx - # pylint: disable=too-many-locals - def reduce(self, xx: xr.Dataset) -> xr.Dataset: - - # Water persistence - water_persistence = l4_water_persistence.water_persistence( - xx, self.watper_threshold + def classification(self, xx, class_def, con_cols, class_col): + expressions = generate_numexpr_expressions( + class_def[con_cols + [class_col]], class_col, "res" + ) + local_dict = { + key: xx[self.data_var_condition.get(key, key)].data for key in con_cols + } + res = da.full(xx.level_3_4.shape, 0, dtype="uint8") + + for expression in expressions: + local_dict.update({"res": res}) + res = expr_eval( + expression, + local_dict, + name="apply_rule", + dtype="uint8", + ) + + # This seems redundant while res can be init to NODATA above, + # but it's a point to sanity check no class is missed + res = expr_eval( + "where((a!=a)|(a>=_n), _n, b)", + {"a": xx.level_3_4.data, "b": res}, + name="mark_nodata", + dtype="uint8", + **{"_n": NODATA}, ) - # #TODO WATER (99-104) - l4 = l4_water.water_classification(xx, water_persistence) + return res + + def reduce(self, xx: xr.Dataset) -> xr.Dataset: + con_cols = ["level1", "artificial_surface", "cultivated"] + class_col = "level3" + level3 = self.classification(xx, self.class_def, con_cols, class_col) - # Generate Level3 classes + # apply urban mask + # 215 -> 216 if urban_mask == 0 urban_mask = rasterize_vector_mask( self.urban_mask, xx.geobox.transform, @@ -89,38 +112,39 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset: threshold=self.mask_threshold, ) - level3 = lc_level3.lc_level3(xx, urban_mask) - - # Vegetation cover - veg_cover = l4_veg_cover.canopyco_veg_con(xx, self.veg_threshold) - - # Apply cultivated Level-4 classes (1-18) - l4 = l4_cultivated.lc_l4_cultivated(l4, level3, xx.woody, veg_cover) - - # Apply terrestrial vegetation classes [19-36] - l4 = l4_natural_veg.lc_l4_natural_veg(l4, level3, xx.woody, veg_cover) - - # Bare gradation - bare_gradation = l4_bare_gradation.bare_gradation( - xx, self.bare_threshold, veg_cover + level3 = expr_eval( + "where((a==215)&(b<1), 216, a)", + {"a": level3, "b": urban_mask}, + name="mask_non_urban", + dtype="uint8", ) - l4 = l4_natural_aquatic.natural_auquatic_veg(l4, veg_cover, xx.water_season) - - level4 = l4_surface.lc_l4_surface(l4, level3, bare_gradation) - - level3 = level3.astype(np.uint8) - level4 = level4.astype(np.uint8) - attrs = xx.attrs.copy() attrs["nodata"] = NODATA dims = xx.level_3_4.dims[1:] + coords = dict((dim, xx.coords[dim]) for dim in dims) + xx["level3"] = xr.DataArray( + level3.squeeze(), dims=dims, attrs=attrs, coords=coords + ) + + con_cols = [ + "level1", + "level3", + "woody", + "water_season", + "water_frequency", + "pv_pc_50", + "bs_pc_50", + ] + class_col = "level4" + + level4 = self.classification(xx, self.class_def, con_cols, class_col) + data_vars = { k: xr.DataArray(v, dims=dims, attrs=attrs) for k, v in zip(self.measurements, [level3.squeeze(), level4.squeeze()]) } - coords = dict((dim, xx.coords[dim]) for dim in dims) leve34 = xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs) return leve34