Commit 615ca82

Add Leaderboard.read_rmses
This commit adds the ability to read RMSE values from a CSV file.
1 parent a7748da commit 615ca82
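
As a rough illustration of the CSV layout this commit expects (a header whose first entry is ignored, then one row per model), here is a minimal stand-alone sketch using only Julia's `Base`. The name `parse_rmse_csv` and the named tuple it returns are hypothetical and are not part of ClimaAnalysis; the committed `read_rmses` builds an `RMSEVariable` instead and uses `stack`, which this sketch replaces with `reduce(hcat, ...)` so that it does not depend on Julia 1.9.

```julia
# Hypothetical helper for illustration only; this is NOT the ClimaAnalysis API.
function parse_rmse_csv(io::IO)
    # The first header entry labels the model column and is ignored.
    category_names = String.(split(readline(io), ','))[2:end]
    model_names = String[]
    rows = Vector{Float64}[]
    for (line_num, line) in enumerate(eachline(io))
        fields = split(line, ',')
        # Every row needs a model name plus one RMSE per category.
        length(fields) == length(category_names) + 1 ||
            error("Wrong number of entries on line $(line_num + 1) of the CSV file")
        push!(model_names, String(fields[1]))
        push!(rows, parse.(Float64, fields[2:end]))
    end
    # One row per model, one column per category.
    rmses = permutedims(reduce(hcat, rows))
    return (; model_names, category_names, rmses)
end

# Same contents as the test_csv.csv file added in this commit.
csv_text = """
# Model,DJF,MAM,JJA,SON,ANN
ACCESS-CM2,11.941,10.178,13.279,10.443,8.710
ACCESS-ESM1-5,15.752,12.477,15.955,12.972,NaN
"""
result = parse_rmse_csv(IOBuffer(csv_text))
```

Reading from an `IO` rather than a file path keeps the sketch testable without touching the file system; `NaN` entries parse as `Float64` NaNs, matching the missing `ANN` value in the sample data.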

7 files changed: +154 -1 lines changed

NEWS.md

+20
@@ -323,6 +323,26 @@ rmse_var = ClimaAnalysis.RMSEVariable(
 A `RMSEVariable` can be inspected using `model_names`, `category_names`, and `rmse_units`
 which provide the model names, the category names, and the units respectively.

+#### Reading RMSEs from CSV file
+
+A CSV file containing model names in the first column and root mean squared errors in the
+subsequent columns with a header describing each category (e.g. seasons) can be read into
+a `RMSEVariable`. See the example below on how to use this functionality.
+
+```julia
+rmse_var = ClimaAnalysis.read_rmses("./data/test_csv.csv", "ta")
+rmse_var = ClimaAnalysis.read_rmses(
+    "./data/test_csv.csv",
+    "ta",
+    units = Dict("ACCESS-CM2" => "K", "ACCESS-ESM1-5" => "K"), # passing units as a dictionary
+)
+rmse_var = ClimaAnalysis.read_rmses(
+    "./data/test_csv.csv",
+    "ta",
+    units = "K", # passing units as a string
+)
+```
+
 ## Bug fixes

 - Increased the default value for `warp_string` to 72.

docs/src/api.md

+1
@@ -80,6 +80,7 @@ Leaderboard.RMSEVariable(short_name::String, model_names::Vector{String}, catego
 Leaderboard.model_names
 Leaderboard.category_names
 Leaderboard.rmse_units
+Leaderboard.read_rmses
 ```

 ## Utilities

docs/src/data/test_csv.csv

+3
@@ -0,0 +1,3 @@
+# Model,DJF,MAM,JJA,SON,ANN
+ACCESS-CM2,11.941,10.178,13.279,10.443,8.710
+ACCESS-ESM1-5,15.752,12.477,15.955,12.972,NaN

docs/src/rmse_var.md

+30
@@ -63,4 +63,34 @@ which provide the model names, the category names, and the units respectively.
 ClimaAnalysis.model_names(rmse_var)
 ClimaAnalysis.category_names(rmse_var)
 ClimaAnalysis.rmse_units(rmse_var)
+```
+
+## Reading RMSEs from CSV file
+
+Typically, the root mean squared errors (RMSEs) of different models across different
+categories are stored in a separate file and need to be loaded in. `ClimaAnalysis` can load
+this information from a CSV file and store it in a `RMSEVariable`. The CSV file should have
+a header whose first entry is "model_name" (or any other text, as it is ignored by the
+function) and whose remaining entries are the category names. Each row after the header
+should start with the model name, followed by the root mean squared error for each
+category for that model. The entries of the CSV file should be separated by commas.
+
+See the example below using `read_rmses` where data is loaded from `test_csv.csv` and a
+short name of `ta` is provided. One can also pass `units` as a dictionary mapping model
+names to units, or as a string if the units are the same for all the models.
+
+```@example rmse_var
+rmse_var = ClimaAnalysis.read_rmses("./data/test_csv.csv", "ta")
+rmse_var = ClimaAnalysis.read_rmses(
+    "./data/test_csv.csv",
+    "ta",
+    units = Dict("ACCESS-CM2" => "K", "ACCESS-ESM1-5" => "K"), # passing units as a dictionary
+)
+rmse_var = ClimaAnalysis.read_rmses(
+    "./data/test_csv.csv",
+    "ta",
+    units = "K", # passing units as a string
+)
+
+nothing # hide
 ```

src/Leaderboard.jl

+69-1
@@ -6,7 +6,8 @@ import NaNStatistics: nanmedian
 export RMSEVariable,
     model_names,
     category_names,
-    rmse_units
+    rmse_units,
+    read_rmses

 """
 Holding root mean squared errors over multiple categories and models for a single
@@ -236,4 +237,71 @@ Return all the unit of the models in `rmse_var`.
 """
 rmse_units(rmse_var::RMSEVariable) = rmse_var.units

+"""
+    read_rmses(csv_file::String, short_name::String; units = nothing)
+
+Read a CSV file and create a RMSEVariable with the `short_name` of the variable.
+
+The format of the CSV file should have a header consisting of the entry "model_name" (or any
+other text as it is ignored by the function) and the rest of the entries should be the
+category names. Each row after the header should start with the model name and the root mean
+squared errors for each category for that model. The entries of the CSV file should be
+separated by commas.
+
+The parameter `units` can be a dictionary mapping model name to unit or a string. If `units`
+is a string, then units will be the same across all models. If `units` is `nothing`, then the
+unit is missing for each model which is denoted by an empty string.
+"""
+function read_rmses(csv_file::String, short_name::String; units = nothing)
+    # Initialize variables we need to construct RMSEVariable
+    model_names = Vector{String}()
+    model_rmse_vec = []
+    category_names = nothing
+    open(csv_file, "r") do io
+        header = readline(io)
+        # Get categories (e.g. DJF, MAM, JJA, SON, ANN)
+        category_names = String.(split(header, ','))
+
+        # get rid of the first column name which is the column named "model_name"
+        category_names |> popfirst!
+
+        # Process each line
+        for (line_num, line) in enumerate(eachline(io))
+            # Split the line by comma
+            fields = split(line, ',')
+
+            # Check if any entry is missing in the CSV file
+            length(fields) != (length(category_names) + 1) &&
+                error("Missing RMSEs for line $(line_num + 1) in CSV file")
+
+            # Grab model name
+            model_name = fields[1]
+
+            # the rest of the row is the rmse for each category
+            model_rmse = map(x -> parse(Float64, x), fields[2:end])
+
+            push!(model_names, model_name)
+            push!(model_rmse_vec, model_rmse)
+        end
+    end
+    model_rmses = stack(model_rmse_vec, dims = 1)
+    isnothing(units) && (
+        units = Dict{valtype(model_names), String}([
+            (model_name, "") for model_name in model_names
+        ])
+    )
+    units isa String && (
+        units = Dict{valtype(model_names), String}([
+            model_name => units for model_name in model_names
+        ])
+    )
+    return RMSEVariable(
+        short_name,
+        model_names,
+        category_names,
+        model_rmses,
+        units,
+    )
+end
+
 end

test/sample_data/test_csv.csv

+3
@@ -0,0 +1,3 @@
+# Model,DJF,MAM,JJA,SON,ANN
+ACCESS-CM2,11.941,10.178,13.279,10.443,8.710
+ACCESS-ESM1-5,15.752,12.477,15.955,12.972,NaN

test/test_Leaderboard.jl

+28
@@ -127,3 +127,31 @@ import ClimaAnalysis
         Dict("model1" => ""),
     )
 end
+
+@testset "Reading RMSEs from CSV file" begin
+    # Testing constructor using CSV file
+    csv_file_path = joinpath(@__DIR__, "sample_data/test_csv.csv")
+    rmse_var = ClimaAnalysis.read_rmses(csv_file_path, "ta")
+    @test ClimaAnalysis.model_names(rmse_var) == ["ACCESS-CM2", "ACCESS-ESM1-5"]
+    @test ClimaAnalysis.category_names(rmse_var) ==
+          ["DJF", "MAM", "JJA", "SON", "ANN"]
+    @test ClimaAnalysis.rmse_units(rmse_var) ==
+          Dict("ACCESS-CM2" => "", "ACCESS-ESM1-5" => "")
+    @test rmse_var.short_name == "ta"
+    @test rmse_var.RMSEs[1, 1] == 11.941
+    @test isnan(rmse_var.RMSEs[2, 5])
+
+    # Testing constructor using CSV file with units provided
+    rmse_var = ClimaAnalysis.read_rmses(
+        csv_file_path,
+        "ta",
+        units = Dict("ACCESS-ESM1-5" => "m", "wacky" => "weird"),
+    )
+    @test ClimaAnalysis.rmse_units(rmse_var) ==
+          Dict("ACCESS-CM2" => "", "ACCESS-ESM1-5" => "m")
+
+    # Testing constructor using CSV file with units being a string
+    rmse_var = ClimaAnalysis.read_rmses(csv_file_path, "ta", units = "m")
+    @test ClimaAnalysis.rmse_units(rmse_var) ==
+          Dict("ACCESS-CM2" => "m", "ACCESS-ESM1-5" => "m")
+end
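
The tests in this file suggest how `units` is normalized: a model with no entry in the dictionary gets an empty string, and a key that matches no model (such as "wacky") is dropped. The code that performs this clean-up is not shown in this diff, so the sketch below is only an assumption about the observable behavior; `normalize_units` is a hypothetical name and not part of ClimaAnalysis.

```julia
# Hypothetical helper for illustration only; this is NOT the ClimaAnalysis API.
function normalize_units(model_names::Vector{String}, units)
    # No units given: every model gets an empty string.
    isnothing(units) && return Dict(name => "" for name in model_names)
    # A single string applies to every model.
    units isa AbstractString &&
        return Dict(name => String(units) for name in model_names)
    # A dictionary: keep only known models; models without an entry get "".
    return Dict(name => get(units, name, "") for name in model_names)
end

models = ["ACCESS-CM2", "ACCESS-ESM1-5"]
from_dict = normalize_units(models, Dict("ACCESS-ESM1-5" => "m", "wacky" => "weird"))
from_string = normalize_units(models, "m")
from_nothing = normalize_units(models, nothing)
```

Iterating over `model_names` instead of over the keys of `units` is what drops unknown keys and fills in missing ones in a single pass.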
