Skip to content

Commit

Permalink
aggregate results when multiple inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
Emma Ai committed Jul 19, 2024
1 parent 8e245c9 commit ab56f58
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 31 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ repos:
args: ['--django']
- id: requirements-txt-fixer
- id: check-added-large-files
args: ["--maxkb=5000"]
- id: check-merge-conflict
- repo: https://github.com/PyCQA/flake8
rev: '7.1.0'
Expand Down
58 changes: 43 additions & 15 deletions odc/stats/plugins/lc_tf_urban.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,27 +164,55 @@ def impute_missing_values_from_group(self, xx):
images += [image]
return images

def aggregate_results_from_group(self, urban_masks):
# if there are >= 2 images
# any is urban -> final class is urban
# any is valid -> final class is valid
# for each pixel
m_size = len(urban_masks)
if m_size > 1:
urban_masks = da.stack(urban_masks).sum(axis=0)
else:
urban_masks = urban_masks[0]

urban_masks = expr_eval(
"where((a/nodata)>=_l, nodata, a%nodata)",
{"a": urban_masks},
name="mark_nodata",
dtype="float32",
**{"_l": m_size, "nodata": NODATA},
)

urban_masks = expr_eval(
"where((a>0)&(a<nodata), _u, a)",
{"a": urban_masks},
name="output_classes_artificial",
dtype="float32",
**{
"_u": self.output_classes["artificial"],
"nodata": NODATA,
},
)

urban_masks = expr_eval(
"where(a<=0, _nu, a)",
{"a": urban_masks},
name="output_classes_natrual",
dtype="uint8",
**{
"_nu": self.output_classes["natural"],
},
)

return urban_masks.rechunk(-1, -1)

def reduce(self, xx: xr.Dataset) -> xr.Dataset:
urban_masks = []
images = self.impute_missing_values_from_group(xx)
for image in images:
urban_masks += [self.urban_class(image)]

# if there are >= 2 images
# any is urban -> final class is urban for each pixel
um = urban_masks[0]
for _um in urban_masks[1:]:
um = expr_eval(
"where((a==1)|(b==1), _u, _nu)",
{"a": um, "b": _um},
name="merge_masks",
dtype="uint8",
**{
"_u": self.output_classes["artificial"],
"_nu": self.output_classes["natural"],
},
)
um = um.rechunk(-1, -1)
um = self.aggregate_results_from_group(urban_masks)

attrs = xx.attrs.copy()
attrs["nodata"] = int(NODATA)
Expand Down
Binary file removed tests/data/expected_img.npy
Binary file not shown.
Binary file removed tests/data/img_1.npy
Binary file not shown.
Binary file removed tests/data/img_2.npy
Binary file not shown.
112 changes: 97 additions & 15 deletions tests/test_urban_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,77 @@ def output_classes():
return {"artificial": 215, "natural": 216}


@pytest.fixture(scope="module")
def urban_masks():
return [
da.array([[0, 0, 0], [1, 255, 0], [0, 0, 1]], dtype="uint8"),
da.array([[255, 0, 0], [0, 255, 1], [1, 0, 0]], dtype="uint8"),
]


@pytest.fixture(scope="module")
def image_groups():

img_1 = np.load(f"{data_dir}/img_1.npy")
img_1 = np.array(
[
[
[-999, -999, -999, -999, -999, -999],
[491, 878, 315, 324, 820, 610],
[134, 178, 458, 55, 832, 684],
[896, 345, 392, 755, 742, 752],
],
[
[707, 980, 767, 665, 101, 229],
[352, 410, 176, 400, 72, 722],
[0, 858, 629, 121, 662, 477],
[891, 934, 766, 929, 626, 561],
],
[
[19, 586, 496, 964, 869, 389],
[447, 325, 609, 366, 490, 457],
[706, 156, 950, 171, 848, 994],
[474, 100, 985, 277, 579, 289],
],
[
[186, 365, 275, 109, 800, 927],
[365, 509, 872, 288, 390, 262],
[200, 503, 323, 566, 861, 659],
[796, 117, 4, 814, 631, 789],
],
],
dtype="int16",
)

img_1 = da.from_array(img_1, chunks=(-1, -1, -1))
img_2 = np.load(f"{data_dir}/img_2.npy")
img_2 = np.array(
[
[
[772, 115, 814, 44, 951, 824],
[8, 602, 170, 331, 117, 483],
[121, 112, 124, 172, 704, 388],
[741, 588, 289, 665, 320, 303],
],
[
[846, 126, 357, 805, 192, 380],
[875, 880, 446, 458, 116, 828],
[672, 290, 795, 727, 746, 967],
[170, 813, 471, 885, 919, 944],
],
[
[63, 101, 718, 772, 313, 637],
[618, 576, 254, 541, 138, 13],
[403, 248, 891, 169, 164, 132],
[830, 66, 87, 129, 703, 476],
],
[
[532, 620, 928, 617, 630, 666],
[275, 253, 586, 604, 662, 948],
[532, 807, 59, 505, 210, 149],
[-999, -999, -999, -999, -999, -999],
],
],
dtype="int16",
)
img_2 = da.from_array(img_2, chunks=(-1, -1, -1))

coords = {
Expand All @@ -61,27 +126,44 @@ def image_groups():
return xx


def test_impute_missing_values(tflite_model_path, image_groups):
def test_impute_missing_values(output_classes, tflite_model_path, image_groups):
stats_urban = StatsUrbanClass(output_classes, tflite_model_path)
res = stats_urban.impute_missing_values_from_group(image_groups)
expect_res = np.load(f"{data_dir}/expected_img.npy")
assert res[0].dtype == "float32"
assert res[1].dtype == "float32"
assert (res[0][5:, 5:, :] == expect_res[0, 5:, 5:, :]).all()
assert (res[0][:5, :5, :] == expect_res[1, :5, :5, :]).all()
assert (res[1][10:15, 10:15, :] == expect_res[0, 10:15, 10:15, :]).all()
assert (res[1][:10, :10, :] == expect_res[1, :10, :10, :]).all()
assert (res[1][15:, 15:, :] == expect_res[1, 15:, 15:, :]).all()
assert (res[0][0, 0, :] == res[1][0, 0, :]).all()
assert (res[0][3, 3, :] == res[1][3, 3, :]).all()
assert (res[0][1:, 1:, :] == image_groups["ga_ls7"][1:, 1:, :]).all()
assert (res[1][:3, :3, :] == image_groups["ga_ls8"][:3, :3, :]).all()


def test_urban_class(tflite_model_path, image_groups):
def test_urban_class(output_classes, tflite_model_path, image_groups):
# test better than random for a prediction
# check correctness in integration test
stats_urban = StatsUrbanClass(output_classes, tflite_model_path)
client.register_plugin(stats_urban.dask_worker_plugin)
input_img = np.load(f"{data_dir}/expected_img.npy")
urban_mask = []
input_img = stats_urban.impute_missing_values_from_group(image_groups)
input_img[0][1, 1, :] = np.nan
input_img[1][1, 1, :] = np.nan
for img in input_img:
img = da.from_array(img, chunks=(-1, -1, -1))
urban_mask += [stats_urban.urban_class(img)]
assert (np.array(urban_mask) == 0).all()
urban_mask = stats_urban.urban_class(img)
urban_mask = urban_mask.compute()
assert (urban_mask[1, 1] == 255).all()
assert (
urban_mask[np.where(urban_mask < 255)[0], np.where(urban_mask < 255)[1]]
== 0
).all()


def test_aggregate_results_from_group(output_classes, tflite_model_path, urban_masks):
stats_urban = StatsUrbanClass(output_classes, tflite_model_path)
res = stats_urban.aggregate_results_from_group([urban_masks[0]])
expected_res = np.array(
[[216, 216, 216], [215, 255, 216], [216, 216, 215]], dtype="uint8"
)
assert (res == expected_res).all()
res = stats_urban.aggregate_results_from_group(urban_masks)
expected_res = np.array(
[[216, 216, 216], [215, 255, 215], [215, 216, 215]], dtype="uint8"
)
assert (res == expected_res).all()

0 comments on commit ab56f58

Please sign in to comment.