Skip to content

Commit 75eb2bd

Browse files
sdmdata and models as namedtuple (#9)
* some small fixes * add sdmdata * use sdmdata in other functions * fix tests * allow rasters 0.11 * fix sdmdata when resampler is a TrainTestPairs * check_data for rasterstack * model_key not model_name * default single-threaded shapley * small data fixes * update docstrings * sdmdata fix * add mljmodels dependency to test * fix tests * model_keys not model_names (2) * prettier docstrings
1 parent a72b70e commit 75eb2bd

File tree

12 files changed

+348
-217
lines changed

12 files changed

+348
-217
lines changed

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,17 @@ SpeciesDistributionModelsMakieExt = "Makie"
4040
[compat]
4141
CategoricalDistributions = "0.1.14"
4242
MLJGLMInterface = "0.3.7"
43-
Rasters = "0.10.1"
43+
Rasters = "0.10.1, 0.11"
4444
StatisticalMeasures = "0.1.5"
4545
StatsModels = "0.7.3"
4646
julia = "1.9"
4747

4848
[extras]
4949
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
50+
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
51+
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
5052
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
5153
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
52-
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
5354

5455
[targets]
55-
test = ["Distributions", "StableRNGs", "Test", "Makie"]
56+
test = ["Distributions", "Makie", "MLJModels", "StableRNGs", "Test"]

ext/SpeciesDistributionModelsMakieExt/SpeciesDistributionModelsMakieExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module SpeciesDistributionModelsMakieExt
22
using Makie, SpeciesDistributionModels
33
import SpeciesDistributionModels as SDM
4-
import SpeciesDistributionModels: model_names, machine_evaluations, sdm_machines, data,
4+
import SpeciesDistributionModels: model_keys, machine_evaluations, sdm_machines, data,
55
_conf_mats_from_thresholds
66
import SpeciesDistributionModels: interactive_evaluation
77
import Statistics, Loess

ext/SpeciesDistributionModelsMakieExt/plotrecipes.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
function _model_controls!(fig, ensemble)
22
toggles = [Makie.Toggle(fig, active = true) for i in 1:Base.length(ensemble)]
3-
labels = [Makie.Label(fig, String(key)) for key in SDM.model_names(ensemble)]
3+
labels = [Makie.Label(fig, String(key)) for key in SDM.model_keys(ensemble)]
44
g = Makie.grid!(hcat(toggles, labels))
55
return g, toggles
66
end
77

88

99
function Makie.boxplot(ev::SDMensembleEvaluation, measure::Symbol)
10-
modelnames = Base.string.(model_names(ev.ensemble))
10+
modelnames = collect(Base.string.(model_keys(ev.ensemble)))
1111
f = Makie.Figure()
1212

1313
for (i, t) in enumerate((:train, :test))
@@ -55,13 +55,13 @@ function SDM.interactive_evaluation(ensemble; thresholds = 0:0.01:1)
5555

5656
idx_by_model = map(enumerate(ensemble)) do (i, e)
5757
fill(i, Base.length(e))
58-
end |> NamedTuple{Tuple(model_names(ensemble))}
58+
end |> NamedTuple{Tuple(model_keys(ensemble))}
5959

6060
n_models = length(idx_by_model)
6161

6262
conf_mats = mapreduce(hcat, ensemble) do gr
6363
map(gr) do sdm_machine
64-
rows = sdm_machine.test_rows
64+
rows = SDM.test_rows(sdm_machine)
6565
y_hat = SDM.MLJBase.predict(sdm_machine.machine; rows)
6666
y = data(sdm_machine).response[rows]
6767
_conf_mats_from_thresholds(SDM.MLJBase.pdf.(y_hat, true), y, thresholds)

src/SpeciesDistributionModels.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ import GLM, PrettyTables, Rasters, EvoTrees, DecisionTree, Shapley, Loess
77
using MLJBase: pdf
88
using Rasters: Raster, RasterStack, Band
99
using ScientificTypesBase: Continuous, OrderedFactor, Multiclass, Count
10-
using ComputationalResources: CPU1, CPUThreads, AbstractCPU
10+
using ComputationalResources: CPU1, CPUThreads, AbstractCPU, CPUProcesses
1111

1212
using ScientificTypesBase: Continuous, OrderedFactor, Multiclass, Count
1313
using StatisticalMeasures: auc, kappa, sensitivity, selectivity, accuracy
14-
import MLJBase: StratifiedCV, CV, Holdout, ResamplingStrategy
14+
import MLJBase: StratifiedCV, CV, Holdout, ResamplingStrategy, Machine, Probabilistic
1515

16-
export SDMensemble, predict, sdm, select, machines, machine_keys,
16+
export SDMensemble, predict, sdm, sdmdata, select, machines, machine_keys,
1717
remove_collinear,
1818
explain, variable_importance, ShapleyValues,
1919
SDMmachineExplanation, SDMgroupExplanation, SDMensembleExplanation,
@@ -23,13 +23,13 @@ export SDMensemble, predict, sdm, select, machines, machine_keys,
2323
export auc, kappa, sensitivity, selectivity, accuracy,
2424
Continuous, OrderedFactor, Multiclass, Count,
2525
StratifiedCV, CV, Holdout, ResamplingStrategy
26-
26+
#include("learningnetwork.jl")
27+
include("models.jl")
2728
include("data_utils.jl")
2829
include("resample.jl")
2930
# export stubs for extensions
3031
export interactive_response_curves, interactive_evaluation
3132

32-
3333
include("collinearity.jl")
3434
include("models.jl")
3535
include("ensemble.jl")

src/data_utils.jl

Lines changed: 115 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,129 @@
1-
### Miscelanious utilities to deal with data issues such as names, missing values
2-
31
# Convert a BitArray to a CategoricalArray. Faster and type-stable version of `categorical`
2+
BooleanCategorical{N} = CategoricalArrays.CategoricalArray{Bool, N, UInt8} where N
43
function boolean_categorical(A::BitArray{N}) where N
5-
CategoricalArrays.CategoricalArray{Bool, N, UInt8}(A, levels=[false, true], ordered=false)
4+
BooleanCategorical{N}(A, levels=[false, true], ordered=false)
65
end
76
boolean_categorical(A::AbstractVector{Bool}) = boolean_categorical(BitArray(A))
87

9-
function _get_predictor_names(p, a)
10-
predictors = Base.intersect(Tables.schema(a).names, Tables.schema(p).names)
11-
predictors = filter!(!=(:geometry), predictors) # geometry is never a variable
12-
length(predictors) > 0 || error("Presence and absence data have no common variable names - can't fit the ensemble.")
13-
return predictors
8+
9+
struct SDMdata{K}
10+
predictor::NamedTuple
11+
response::CategoricalArrays.CategoricalArray
12+
geometry::Union{Nothing, Vector}
13+
traintestpairs::MLJBase.TrainTestPairs
14+
resampler::Union{Nothing, MLJBase.ResamplingStrategy}
15+
16+
function SDMdata(predictor::P, response, geometry, traintestpairs, resampler) where P<:NamedTuple{K} where K
17+
new{K}(predictor, response, geometry, traintestpairs, resampler)
18+
end
1419
end
1520

21+
function Base.show(io::IO, mime::MIME"text/plain", data::SDMdata{K}) where K
22+
y = response(data)
23+
print("SDMdata object with ")
24+
printstyled(sum(y), bold = true)
25+
print(" presence points and ")
26+
printstyled(length(y) - sum(y), bold = true)
27+
print(" absence points. \n \n")
1628

17-
function _predictor_response_from_presence_absence(presences, absences, predictors)
18-
p_columns = Tables.columns(presences)
19-
a_columns = Tables.columns(absences)
20-
n_presence = Tables.rowcount(p_columns)
21-
n_absence = Tables.rowcount(a_columns)
29+
printstyled("Resampling: \n", bold = true)
30+
println("Data is divided into $(nfolds(data)) folds using resampling strategy $(resampler(data)).")
2231

23-
# merge presence and absence data into one namedtuple of vectors
24-
predictor_values = NamedTuple{Tuple(predictors)}([[a_columns[var]; p_columns[var]] for var in predictors])
25-
response_values = boolean_categorical([falses(n_absence); trues(n_presence)])
26-
return predictor_values, response_values
32+
n_presences = length.(getindex.(traintestpairs(data), 1))
33+
n_absences = length.(getindex.(traintestpairs(data), 2))
34+
table_cols = hcat(1:nfolds(data), n_presences, n_absences)
35+
header = (["fold", "presences", "absences"])
36+
PrettyTables.pretty_table(io, table_cols; header = header)
37+
38+
printstyled("Predictor variables: \n", bold = true)
39+
Base.show(io, mime, MLJBase.schema(predictor(data)))
40+
41+
if isnothing(geometry(data))
42+
print("Does not contain geometry data")
43+
else
44+
print("Also contains geometry data")
45+
end
46+
end
47+
48+
49+
_gettrainrows(d::SDMdata, i) = d.traintestpairs[i][1]
50+
_gettestrows(d::SDMdata, i) = d.traintestpairs[i][2]
51+
predictor(d::SDMdata) = d.predictor
52+
predictorkeys(d::SDMdata{K}) where K = K
53+
response(d::SDMdata) = convert(AbstractArray{Bool}, d.response)
54+
geometry(d::SDMdata) = d.geometry
55+
traintestpairs(d::SDMdata) = d.traintestpairs
56+
resampler(d::SDMdata) = d.resampler
57+
nfolds(d::SDMdata) = length(d.traintestpairs)
58+
59+
function _sdmdata(presences, absences, resampler, ::Nothing)
60+
predictorkeys = Tuple(Base.intersect(Tables.schema(presences).names, Tables.schema(absences).names))
61+
length(predictorkeys) > 0 || error("Presence and absence data have no common variable names - can't fit the ensemble.")
62+
_sdmdata(presences, absences, resampler, predictorkeys)
2763
end
2864

65+
function _sdmdata(presences, absences, resampler, predictorkeys::NTuple{<:Any, <:Symbol})
66+
X, y = _predictor_response_from_presence_absence(presences, absences, predictorkeys)
67+
_sdmdata(X, y, resampler, predictorkeys)
68+
end
69+
70+
# in case input is a table
71+
function _sdmdata(X, response::BitVector, resampler, ::Nothing)
72+
columns = Tables.columntable(X)
73+
Tables.rowcount(columns) == length(response) || error("Number of rows in predictors and response do not match")
74+
predictorkeys = Tables.columnnames(columns)
75+
_sdmdata(columns, response, resampler, predictorkeys)
76+
end
77+
78+
_sdmdata(X::Tables.ColumnTable{K}, y::BitVector, resampler, predictorkeys::NTuple{<:Any, <:Symbol}) where K =
79+
if K == predictorkeys
80+
_sdmdata(X, boolean_categorical(y), resampler)
81+
else
82+
_sdmdata(X[predictorkeys], boolean_categorical(y), resampler)
83+
end
84+
85+
function _sdmdata(
86+
X::Tables.ColumnTable,
87+
y::BooleanCategorical,
88+
resampler::CV,
89+
)
90+
shuffled_resampler = CV(; nfolds = resampler.nfolds, rng = resampler.rng, shuffle = true)
91+
traintestpairs = MLJBase.train_test_pairs(shuffled_resampler, eachindex(y), X, y)
92+
_sdmdata(X, y, traintestpairs, shuffled_resampler)
93+
end
94+
function _sdmdata(
95+
X::Tables.ColumnTable,
96+
y::BooleanCategorical,
97+
resampler::MLJBase.ResamplingStrategy,
98+
)
99+
traintestpairs = MLJBase.train_test_pairs(resampler, eachindex(y), X, y)
100+
_sdmdata(X, y, traintestpairs, resampler)
101+
end
102+
103+
function _sdmdata(
104+
X::Tables.ColumnTable,
105+
y::BooleanCategorical,
106+
traintestpairs::MLJBase.TrainTestPairs,
107+
resampler = CustomRows()
108+
)
109+
geometries = :geometry keys(X) ? Tables.getcolumn(X, :geometry) : nothing
110+
X = Base.structdiff(X, NamedTuple{(:geometry,)})
111+
SDMdata(X, y, geometries, traintestpairs, resampler)
112+
end
29113

30114
cpu_backend(threaded) = threaded ? CPUThreads() : CPU1()
31115
_map(::CPU1) = Base.map
32-
_map(::CPUThreads) = ThreadsX.map
116+
_map(::CPUThreads) = ThreadsX.map
117+
118+
119+
function _predictor_response_from_presence_absence(presences, absences, predictorkeys::NTuple{<:Any, <:Symbol})
120+
p_columns = Tables.columns(presences)
121+
a_columns = Tables.columns(absences)
122+
n_presence = Tables.rowcount(p_columns)
123+
n_absence = Tables.rowcount(a_columns)
124+
125+
# merge presence and absence data into one namedtuple of vectors
126+
X = NamedTuple{predictorkeys}([[a_columns[var]; p_columns[var]] for var in predictorkeys])
127+
y = [falses(n_absence); trues(n_presence)]
128+
return (X, y)
129+
end

0 commit comments

Comments
 (0)