Skip to content

Commit

Permalink
Bias correction fix (#1969)
Browse files Browse the repository at this point in the history
* Combine masks explicitly when calculating forecast error.

* Update acc-tests to capture masked domains in calculate_forecast_bias

* Simplify mask handling.
  • Loading branch information
benowen-bom authored Nov 27, 2023
1 parent 3eee0a9 commit b8e8d46
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 13 deletions.
9 changes: 8 additions & 1 deletion improver/calibration/simple_bias_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ def evaluate_additive_error(
An array containing the mean additive forecast error values.
"""
forecast_errors = forecasts - truths
# Set the masks explicitly to inherit the masks from both cubes
if isinstance(forecasts.data, ma.MaskedArray) or isinstance(
truths.data, ma.MaskedArray
):
forecast_errors.data.mask = ma.mask_or(
ma.asarray(forecasts.data).mask, ma.asarray(truths.data).mask
)
if collapse_dim in get_dim_coord_names(forecast_errors):
mean_forecast_error = collapsed(
forecast_errors, collapse_dim, iris.analysis.MEAN
Expand Down Expand Up @@ -289,7 +296,7 @@ def _get_mean_bias(self, bias_values: CubeList) -> Cube:
Cubelist containing the input bias cube(s).
Returns:
Cube containing the mean bias evaulated from set of bias_values.
Cube containing the mean bias evaluated from set of bias_values.
"""
# Currently only support for cases where the input bias_values are defined
# over a single forecast_reference_time.
Expand Down
8 changes: 8 additions & 0 deletions improver_tests/acceptance/SHA256SUMS
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,16 @@ ac609404e0083db721ded42d983f3a616f8ce7ca220ef3bbfc2af9ccfd83c02f ./calculate-fo
ec3c422222124570bae70a4add38888a33eb8a4d38a8b5bdabafd8362ed9b0b1 ./calculate-forecast-bias/inputs/20220812T0300Z-PT0003H00M-wind_speed_at_10m.nc
a37a4829a3d45295e38d454c9983f499cca11f14b21707e4f36f1bb3da32de63 ./calculate-forecast-bias/inputs/20220813T0300Z-PT0000H00M-wind_speed_at_10m.nc
be26d2ea432284b08ef68a2764f7b18463840d57b9dc495961c1bd90c1bf73c5 ./calculate-forecast-bias/inputs/20220813T0300Z-PT0003H00M-wind_speed_at_10m.nc
39b87f497a89aed63538ce48d3f65fe11b931fb315908e46f187dd832c58bf7a ./calculate-forecast-bias/inputs/masked/20220811T0300Z-PT0000H00M-wind_speed_at_10m.nc
c353a8fa343361719de535102a231cee74b90fda412591127ae511400d20b18d ./calculate-forecast-bias/inputs/masked/20220811T0300Z-PT0003H00M-wind_speed_at_10m.nc
345a05b9e852c87608edf82f397c5f80dac2123af24c0d5caacc2b8a01f9fbe6 ./calculate-forecast-bias/inputs/masked/20220812T0300Z-PT0000H00M-wind_speed_at_10m.nc
776eb8ef27fbd40742c2217ab89df5a423635ffeff2e6ee485bf78c90964125a ./calculate-forecast-bias/inputs/masked/20220812T0300Z-PT0003H00M-wind_speed_at_10m.nc
4e2308fb14ce761ac8a151dab427708d9f18842dbbb8bf323b5e20bd80633ba6 ./calculate-forecast-bias/inputs/masked/20220813T0300Z-PT0000H00M-wind_speed_at_10m.nc
c918097aa9985d02cb2cc6d89f85cfa026c7e1832d6b13f8777df0914b07a6b4 ./calculate-forecast-bias/inputs/masked/20220813T0300Z-PT0003H00M-wind_speed_at_10m.nc
3220caf063928ee4ff7905c0d30cdb48a856a5e513e8d7f5cfd6cd162dbf170d ./calculate-forecast-bias/multiple_frt/kgo.nc
41b16c2a8feed29f28c92480a64238242bba5945312ac27715017b2f887ca2f5 ./calculate-forecast-bias/multiple_frt_masked_inputs/kgo.nc
ec49997369aab4bb84049efcbf7720d15ca7a371d366bfcb707a2ca5a0473cfd ./calculate-forecast-bias/single_frt/kgo.nc
fcf348a7b41d561a16ca5d1713292b196679da7041977b9e871420c995a9cf89 ./calculate-forecast-bias/single_frt_masked_inputs/kgo.nc
f4e13dec400ec945ba5bd03a286780b183665c22d707c038c0990ef3491e888e ./categorical-modes/blend_mismatch_inputs/20201209T0700Z-weather_symbols-PT01H.nc
ee9557cf229b1099e64b36760a1b2dd82aec663490d75b6be2894bfb7a9103ea ./categorical-modes/blend_mismatch_inputs/20201209T0800Z-weather_symbols-PT01H.nc
601d331f490953f3ac36ac334b705848b8d4bc9024bfe001f4cd2d8d8b6645a2 ./categorical-modes/blend_mismatch_inputs/20201209T0900Z-weather_symbols-PT01H.nc
Expand Down
38 changes: 38 additions & 0 deletions improver_tests/acceptance/test_calculate_forecast_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,41 @@ def test_multiple_frt(tmp_path):
]
run_cli(args)
acc.compare(output_path, kgo_path)


def test_single_frt_masked_inputs(tmp_path):
    """
    Test case where a single historical forecast value is provided, using
    the input files held under the "inputs/masked" directory.
    """
    bias_root = acc.kgo_root() / "calculate-forecast-bias"
    # Every masked input whose filename matches the 20220811T0300Z cycle.
    input_files = (bias_root / "inputs/masked").glob("20220811T0300Z-PT00*.nc")
    result_file = tmp_path / "output.nc"
    expected_file = bias_root / "single_frt_masked_inputs" / "kgo.nc"
    run_cli(
        [
            *input_files,
            "--truth-attribute",
            "mosg__model_configuration=msas_det",
            "--output",
            result_file,
        ]
    )
    # The CLI output must match the stored known-good output for this case.
    acc.compare(result_file, expected_file)


def test_multiple_frt_masked_inputs(tmp_path):
    """
    Test case where multiple historical forecast values are provided, using
    the input files held under the "inputs/masked" directory.
    """
    bias_root = acc.kgo_root() / "calculate-forecast-bias"
    # Every masked input whose filename matches any 202208* T0300Z cycle.
    input_files = (bias_root / "inputs/masked").glob("202208*T0300Z-PT00*.nc")
    result_file = tmp_path / "output.nc"
    expected_file = bias_root / "multiple_frt_masked_inputs" / "kgo.nc"
    run_cli(
        [
            *input_files,
            "--truth-attribute",
            "mosg__model_configuration=msas_det",
            "--output",
            result_file,
        ]
    )
    # The CLI output must match the stored known-good output for this case.
    acc.compare(result_file, expected_file)
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@


def generate_dataset(
num_frts: int = 1, truth_dataset: bool = False, data: ndarray = None
num_frts: int = 1,
truth_dataset: bool = False,
data: ndarray = None,
masked: bool = False,
) -> Cube:
"""Generate sample input datasets.
Expand Down Expand Up @@ -90,7 +93,15 @@ def generate_dataset(
else:
data_shape = data.shape
# Construct the cubes.
if masked:
data = np.ma.masked_array(data)
data.mask = np.zeros(shape=data_shape, dtype=bool)
if truth_dataset:
data.mask[:, -1] = True
else:
data.mask[0, :] = True
ref_forecast_cubes = CubeList()
data_mask = data.mask if isinstance(data, np.ma.MaskedArray) else False
for time in times:
if (num_frts > 1) and (not truth_dataset):
noise = rng.normal(0.0, 0.1, data_shape).astype(np.float32)
Expand All @@ -108,11 +119,13 @@ def generate_dataset(
)
ref_forecast_cube = ref_forecast_cubes.merge_cube()

return ref_forecast_cube
return ref_forecast_cube, data_mask


@pytest.mark.parametrize("num_frt", (1, 30))
def test_evaluate_additive_error(num_frt):
@pytest.mark.parametrize("mask_truth", (False, True))
@pytest.mark.parametrize("mask_forecast", (False, True)) # , True))
def test_evaluate_additive_error(num_frt, mask_truth, mask_forecast):
"""test additive error evaluation gives expected value (within tolerance)."""
data = 273.0 + np.array(
[[1.0, 2.0, 2.0], [2.0, 1.0, 3.0], [3.0, 3.0, 3.0]], dtype=np.float32
Expand All @@ -122,19 +135,26 @@ def test_evaluate_additive_error(num_frt):
)
truth_data = data - diff

historic_forecasts = generate_dataset(num_frt, data=data)
truths = generate_dataset(num_frt, truth_dataset=True, data=truth_data)
historic_forecasts, forecasts_mask = generate_dataset(
num_frt, data=data, masked=mask_forecast
)
truths, truths_mask = generate_dataset(
num_frt, truth_dataset=True, data=truth_data, masked=mask_truth
)
truths.remove_coord("forecast_reference_time")

result = evaluate_additive_error(historic_forecasts, truths, collapse_dim="time")

assert np.allclose(result, diff, atol=0.05)
if mask_forecast or mask_truth:
assert np.all(result.mask == np.ma.mask_or(truths_mask, forecasts_mask))


# Test case where we have a single or multiple reference forecasts.
@pytest.mark.parametrize("num_frt", (1, 4))
def test__define_metadata(num_frt):
"""Test the resultant metadata is as expected."""
reference_forecast_cube = generate_dataset(num_frt)
reference_forecast_cube, _ = generate_dataset(num_frt)

expected = ATTRIBUTES.copy()
expected["title"] = "Forecast bias data"
Expand All @@ -150,7 +170,7 @@ def test__define_metadata(num_frt):
@pytest.mark.parametrize("num_frt", (1, 4))
def test__create_bias_cube(num_frt):
"""Test that the bias cube has the expected structure."""
reference_forecast_cube = generate_dataset(num_frt)
reference_forecast_cube, _ = generate_dataset(num_frt)
result = CalculateForecastBias()._create_bias_cube(reference_forecast_cube)

# Check all but the time dim coords are consistent
Expand Down Expand Up @@ -193,11 +213,17 @@ def test__create_bias_cube(num_frt):
# truth values including case where num_truth_frt != num_fcst_frt.
@pytest.mark.parametrize("num_fcst_frt", (1, 50))
@pytest.mark.parametrize("num_truth_frt", (1, 48, 50))
def test_process(num_fcst_frt, num_truth_frt):
@pytest.mark.parametrize("mask_truth", (False, True))
@pytest.mark.parametrize("mask_forecast", (False, True)) # , True))
def test_process(num_fcst_frt, num_truth_frt, mask_truth, mask_forecast):
"""Test process function over variations in the number of historical forecasts and
truth values passed in."""
reference_forecast_cube = generate_dataset(num_fcst_frt)
truth_cube = generate_dataset(num_truth_frt, truth_dataset=True)
reference_forecast_cube, forecasts_mask = generate_dataset(
num_fcst_frt, masked=mask_forecast
)
truth_cube, truth_mask = generate_dataset(
num_truth_frt, truth_dataset=True, masked=mask_truth
)

result = CalculateForecastBias().process(reference_forecast_cube, truth_cube)
# Check that the values used in calculate mean bias are expected based on
Expand Down Expand Up @@ -225,6 +251,8 @@ def test_process(num_fcst_frt, num_truth_frt):
# from expected value, so here we use a larger tolerance.
expected_tol = 0.2 if (num_truth_frt == 1 and num_fcst_frt > 1) else 0.05
assert np.allclose(result.data, 0.0, atol=expected_tol)
if mask_forecast or mask_truth:
assert np.all(result.data.mask == np.ma.mask_or(truth_mask, forecasts_mask))


@pytest.mark.parametrize("num_fcst_frt", (1, 5))
Expand All @@ -236,7 +264,7 @@ def test_ensure_single_valued_forecast(num_fcst_frt, single_value_as_dim_coord):
# length > 1.
data = np.ones(shape=(4, 3, 3), dtype=np.float32)
# Test realization data
realization_cube = generate_dataset(num_frts=num_fcst_frt, data=data)
realization_cube, _ = generate_dataset(num_frts=num_fcst_frt, data=data)
with pytest.raises(ValueError, match="Multiple realization values"):
CalculateForecastBias()._ensure_single_valued_forecast(realization_cube)
# Test percentile data
Expand Down Expand Up @@ -287,7 +315,7 @@ def test_ensure_single_valued_forecast(num_fcst_frt, single_value_as_dim_coord):
assert result == expected_percentile

# Test the case where the input data does not have an associated ensemble coord.
cube_without_ens_coord = generate_dataset(num_frts=num_fcst_frt)
cube_without_ens_coord, _ = generate_dataset(num_frts=num_fcst_frt)
result = CalculateForecastBias()._ensure_single_valued_forecast(
cube_without_ens_coord
)
Expand Down

0 comments on commit b8e8d46

Please sign in to comment.