Skip to content

Commit 9e8c3e3

Browse files
committed
Add methods for summary statistics for Leaderboard
This commit adds functionality to find the best single model, worst single model, and median though `Leaderboard.find_best_single_model`, `Leaderboard.find_worst_single_model`, `Leaderboard.median`. For handling NaNs, all NaNs are converted to positive Inf or negative Inf as appopriate for finding the best single model or worst single model. For finding the median, the NaNs are filtered out.
1 parent 300b0b9 commit 9e8c3e3

File tree

5 files changed

+153
-5
lines changed

5 files changed

+153
-5
lines changed

NEWS.md

+15
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,21 @@ ClimaAnalysis.add_unit!(rmse_var, "CliMA", "K")
368368
ClimaAnalysis.add_unit!(rmse_var, Dict("CliMA" => "K")) # for adding multiple units
369369
```
370370
371+
#### Summary statistics
372+
Comparsion between models can be done using `find_best_single_model`,
373+
`find_worst_single_model`, and `median`. The functions `find_best_single_model` and
374+
`find_worst_single_model` default to the category "ANN" (corresponding to the annual mean),
375+
but any category be considered using the parameter `category_name`. Furthermore, the model's
376+
root mean squared errors (RMSEs) and name is returned. The function `median` only return the
377+
model's RMSEs. Any `NaN` that appear in the data is ignored when computing the summary
378+
statistics. See the example below on how to use this functionality.
379+
380+
```julia rmse_var
381+
ClimaAnalysis.find_best_single_model(rmse_var, category_name = "DJF")
382+
ClimaAnalysis.find_worst_single_model(rmse_var, category_name = "DJF")
383+
ClimaAnalysis.median(rmse_var)
384+
```
385+
371386
## Bug fixes
372387

373388
- Increased the default value for `warp_string` to 72.

docs/src/api.md

+3
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ Base.setindex!(rmse_var::RMSEVariable, rmse, model_name::String)
8888
Leaderboard.add_category
8989
Leaderboard.add_model
9090
Leaderboard.add_unit!
91+
Leaderboard.find_best_single_model
92+
Leaderboard.find_worst_single_model
93+
Leaderboard.median
9194
```
9295

9396
## Utilities

docs/src/rmse_var.md

+18-4
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,23 @@ ClimaAnalysis.add_unit!(rmse_var, "CliMA", "K")
129129
ClimaAnalysis.add_unit!(rmse_var, Dict("CliMA" => "K")) # for adding multiple units
130130
```
131131

132+
## Summary statistics
133+
134+
`ClimaAnalysis` provides several functions to compute summary statistics. As of now,
135+
`ClimaAnalysis` provides methods for find the best single model, the worst single model,
136+
and the median model.
137+
138+
The functions `find_best_single_model` and `find_worst_single_model` default to the category
139+
"ANN" (corresponding to the annual mean), but any category can be considered using the
140+
parameter `category_name`. Furthermore, the model's root mean squared errors (RMSEs) and the
141+
model's name are returned. The function `median` only returns the median model's RMSEs.
142+
143+
Any `NaN` that appears in the data is ignored when computing the summary statistics.
144+
145+
See the example below using this functionality.
146+
132147
```@repl rmse_var
133-
ClimaAnalysis.category_names(rmse_var2)
134-
ClimaAnalysis.model_names(rmse_var)
135-
ClimaAnalysis.rmse_units(rmse_var)
136-
rmse_var[:,:]
148+
ClimaAnalysis.find_best_single_model(rmse_var, category_name = "DJF")
149+
ClimaAnalysis.find_worst_single_model(rmse_var, category_name = "DJF")
150+
ClimaAnalysis.median(rmse_var)
137151
```

src/Leaderboard.jl

+75-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ export RMSEVariable,
1212
setindex!,
1313
add_category,
1414
add_model,
15-
add_unit!
15+
add_unit!,
16+
find_best_single_model,
17+
find_worst_single_model,
18+
median
1619

1720
"""
1821
Holding root mean squared errors over multiple categories and models for a single
@@ -518,4 +521,75 @@ function add_unit!(rmse_var::RMSEVariable, model_name2unit::Dict)
518521
return nothing
519522
end
520523

524+
"""
525+
_unit_check(rmse_var::RMSEVariable)
526+
527+
Return nothing if units are not missing and units are the same across all models. Otherwise,
528+
return an error.
529+
"""
530+
function _unit_check(rmse_var::RMSEVariable)
531+
units = values(rmse_var.units) |> collect
532+
unit_equal = all(unit -> unit == first(units), units)
533+
(!unit_equal || first(units) == "") &&
534+
error("Units are not the same across all models or units are missing")
535+
return nothing
536+
end
537+
538+
"""
539+
find_best_single_model(rmse_var::RMSEVariable; category_name = "ANN")
540+
541+
Return a tuple of the best single model and the name of the model. Find the best single
542+
model using the root mean squared errors of the category `category_name`.
543+
"""
544+
function find_best_single_model(rmse_var::RMSEVariable; category_name = "ANN")
545+
_unit_check(rmse_var)
546+
categ_names = category_names(rmse_var)
547+
ann_idx = categ_names |> (x -> findfirst(y -> (y == category_name), x))
548+
isnothing(ann_idx) &&
549+
error("The category $category_name does not exist in $categ_names")
550+
rmse_vec = rmse_var[:, ann_idx] |> copy
551+
# Replace all NaN with Inf so that we do not get NaN as a result
552+
# We do this instead of filtering because if we filter, then we need to keep track of
553+
# original indices
554+
replace!(rmse_vec, NaN => Inf)
555+
_, model_idx = findmin(rmse_vec)
556+
mdl_names = model_names(rmse_var)
557+
return rmse_var[model_idx, :], mdl_names[model_idx]
558+
end
559+
560+
"""
561+
find_worst_single_model(rmse_var::RMSEVariable; category_name = "ANN")
562+
563+
Return a tuple of the worst single model and the name of the model. Find the worst single
564+
model using the root mean squared errors of the category `category_name`.
565+
"""
566+
function find_worst_single_model(rmse_var::RMSEVariable; category_name = "ANN")
567+
_unit_check(rmse_var)
568+
categ_names = category_names(rmse_var)
569+
ann_idx = categ_names |> (x -> findfirst(y -> (y == category_name), x))
570+
isnothing(ann_idx) && error("Annual does not exist in $categ_names")
571+
rmse_vec = rmse_var[:, ann_idx] |> copy
572+
# Replace all NaN with Inf so that we do not get NaN as a result
573+
# We do this instead of filtering because if we filter, then we need to keep track of
574+
# original indices
575+
replace!(rmse_vec, NaN => -Inf)
576+
_, model_idx = findmax(rmse_vec)
577+
mdl_names = model_names(rmse_var)
578+
return rmse_var[model_idx, :], mdl_names[model_idx]
579+
end
580+
581+
"""
582+
median(rmse_var::RMSEVariable)
583+
584+
Find the median using the root mean squared errors across all categories.
585+
586+
Any `NaN` is ignored in computing the median.
587+
"""
588+
function median(rmse_var::RMSEVariable)
589+
_unit_check(rmse_var)
590+
# Drop dimension so that size is (n,) instead of (1,n) so that it is consistent with the
591+
# size of the arrays returned from find_worst_single_model and find_best_single_model
592+
return dropdims(nanmedian(rmse_var.RMSEs, dims = 1), dims = 1)
593+
end
594+
521595
end

test/test_Leaderboard.jl

+42
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,45 @@ end
263263
@test ClimaAnalysis.rmse_units(rmse_var)["hello1"] == "units1"
264264
@test ClimaAnalysis.rmse_units(rmse_var)["hello2"] == "units2"
265265
end
266+
267+
@testset "Finding best, worst, and median model" begin
268+
csv_file_path = joinpath(@__DIR__, "sample_data/test_csv.csv")
269+
rmse_var = ClimaAnalysis.read_rmses(csv_file_path, "ta")
270+
rmse_var[:, :] = [[1.0 2.0 3.0 4.0 5.0]; [6.0 7.0 8.0 9.0 10.0]]
271+
ClimaAnalysis.add_unit!(rmse_var, "ACCESS-CM2", "units")
272+
ClimaAnalysis.add_unit!(rmse_var, "ACCESS-ESM1-5", "units")
273+
val, model_name =
274+
ClimaAnalysis.find_best_single_model(rmse_var, category_name = "ANN")
275+
@test model_name == "ACCESS-CM2"
276+
@test val == [1.0, 2.0, 3.0, 4.0, 5.0]
277+
@test val |> size == (5,)
278+
279+
val, model_name =
280+
ClimaAnalysis.find_worst_single_model(rmse_var, category_name = "ANN")
281+
@test model_name == "ACCESS-ESM1-5"
282+
@test val == [6.0, 7.0, 8.0, 9.0, 10.0]
283+
@test val |> size == (5,)
284+
285+
val = ClimaAnalysis.median(rmse_var)
286+
@test val == [7.0, 9.0, 11.0, 13.0, 15.0] ./ 2.0
287+
@test val |> size == (5,)
288+
289+
# Test with NaN in RMSE array
290+
rmse_var = ClimaAnalysis.add_model(rmse_var, "for adding NaN")
291+
ClimaAnalysis.add_unit!(rmse_var, "for adding NaN", "units")
292+
val, model_name =
293+
ClimaAnalysis.find_best_single_model(rmse_var, category_name = "ANN")
294+
@test model_name == "ACCESS-CM2"
295+
@test val == [1.0, 2.0, 3.0, 4.0, 5.0]
296+
@test val |> size == (5,)
297+
298+
val, model_name =
299+
ClimaAnalysis.find_worst_single_model(rmse_var, category_name = "ANN")
300+
@test model_name == "ACCESS-ESM1-5"
301+
@test val == [6.0, 7.0, 8.0, 9.0, 10.0]
302+
@test val |> size == (5,)
303+
304+
val = ClimaAnalysis.median(rmse_var)
305+
@test val == [7.0, 9.0, 11.0, 13.0, 15.0] ./ 2.0
306+
@test val |> size == (5,)
307+
end

0 commit comments

Comments
 (0)