From 59e23582ccd0cce2fd92bd93a67d63388dcd564b Mon Sep 17 00:00:00 2001 From: Emma Ai Date: Tue, 5 Nov 2024 04:25:44 +0000 Subject: [PATCH] correct time dim and format --- odc/stats/plugins/lc_ml_treelite.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/odc/stats/plugins/lc_ml_treelite.py b/odc/stats/plugins/lc_ml_treelite.py index 39542e37..6070350b 100644 --- a/odc/stats/plugins/lc_ml_treelite.py +++ b/odc/stats/plugins/lc_ml_treelite.py @@ -110,7 +110,7 @@ def input_data( (self.chunks["x"], self.chunks["y"], -1, -1), dtype="float32", name=ds.type.name + "_yxbt", - ) + ).squeeze("spec") data_vars[ds.type.name] = input_array else: for var in xx.data_vars: @@ -118,7 +118,7 @@ def input_data( xx[var].astype("uint8"), (self.chunks["x"], self.chunks["y"], -1), name=ds.type.name + "_yxt", - ) + ).squeeze("spec") coords = dict((dim, input_array.coords[dim]) for dim in input_array.dims) return xr.Dataset(data_vars=data_vars, coords=coords) @@ -130,10 +130,10 @@ def impute_missing_values(self, xx: xr.Dataset, image): continue nodata = xx[var].attrs.get("nodata", -999) imputed = expr_eval( - "where((a==a)|(b<=nodata), a, b)", + "where((a==a)|(b<=nodata)|(b!=b), a, b)", { "a": image, - "b": xx[var].squeeze("spec", drop=True).data, + "b": xx[var].data, }, name="impute_missing", dtype="float32", @@ -150,7 +150,7 @@ def convert_dtype(var): image = expr_eval( "where((a<=nodata), _nan, a)", { - "a": xx[var].squeeze("spec", drop=True).data, + "a": xx[var].data, }, name="convert_dtype", dtype="float32", @@ -166,7 +166,7 @@ def convert_dtype(var): DateTimeRange(v) for v in self.temporal_coverage.get(var) ] for tr in temporal_range: - if xx.solar_day in tr: + if xx.solar_day.data.astype("M8[ms]") in tr: self._log.info("Impute missing values of %s", var) image = convert_dtype(var) images += [ @@ -177,11 +177,10 @@ def convert_dtype(var): # use data from all sensors image = convert_dtype(var) images += [image] - else: veg_mask = expr_eval( "where(a==_v, 1, 0)", - {"a": xx[var].squeeze("spec", drop=True).data}, + {"a": xx[var].data}, name="make_mask", dtype="float32", **{"_v": int(self.mask_bands[var])},