Skip to content

Commit

Permalink
correct time dim and format
Browse files Browse the repository at this point in the history
  • Loading branch information
Emma Ai committed Nov 6, 2024
1 parent 1b45ca2 commit 59e2358
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions odc/stats/plugins/lc_ml_treelite.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,15 @@ def input_data(
(self.chunks["x"], self.chunks["y"], -1, -1),
dtype="float32",
name=ds.type.name + "_yxbt",
)
).squeeze("spec")
data_vars[ds.type.name] = input_array
else:
for var in xx.data_vars:
data_vars[var] = yxt_sink(
xx[var].astype("uint8"),
(self.chunks["x"], self.chunks["y"], -1),
name=ds.type.name + "_yxt",
)
).squeeze("spec")

coords = dict((dim, input_array.coords[dim]) for dim in input_array.dims)
return xr.Dataset(data_vars=data_vars, coords=coords)
Expand All @@ -130,10 +130,10 @@ def impute_missing_values(self, xx: xr.Dataset, image):
continue
nodata = xx[var].attrs.get("nodata", -999)
imputed = expr_eval(
"where((a==a)|(b<=nodata), a, b)",
"where((a==a)|(b<=nodata)|(b!=b), a, b)",
{
"a": image,
"b": xx[var].squeeze("spec", drop=True).data,
"b": xx[var].data,
},
name="impute_missing",
dtype="float32",
Expand All @@ -150,7 +150,7 @@ def convert_dtype(var):
image = expr_eval(
"where((a<=nodata), _nan, a)",
{
"a": xx[var].squeeze("spec", drop=True).data,
"a": xx[var].data,
},
name="convert_dtype",
dtype="float32",
Expand All @@ -166,7 +166,7 @@ def convert_dtype(var):
DateTimeRange(v) for v in self.temporal_coverage.get(var)
]
for tr in temporal_range:
if xx.solar_day in tr:
if xx.solar_day.data.astype("M8[ms]") in tr:
self._log.info("Impute missing values of %s", var)
image = convert_dtype(var)
images += [
Expand All @@ -177,11 +177,10 @@ def convert_dtype(var):
# use data from all sensors
image = convert_dtype(var)
images += [image]

else:
veg_mask = expr_eval(
"where(a==_v, 1, 0)",
{"a": xx[var].squeeze("spec", drop=True).data},
{"a": xx[var].data},
name="make_mask",
dtype="float32",
**{"_v": int(self.mask_bands[var])},
Expand Down

0 comments on commit 59e2358

Please sign in to comment.