Skip to content

Commit

Permalink
fix predict for Rasters 0.13 (#20)
Browse files Browse the repository at this point in the history
* use maplayers

* add tests for predicting to rasters

* require rasters 0.13
  • Loading branch information
tiemvanderdeure authored Jan 2, 2025
1 parent 3e6c0ad commit 50811e3
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Loess = "0.6"
MLJBase = "1.7.0"
Makie = "0.20, 0.21"
PrettyTables = "2"
Rasters = "0.12, 0.13"
Rasters = "0.13"
ScientificTypesBase = "3"
Shapley = "0.1"
StatisticalMeasures = "0.1.5"
Expand Down
2 changes: 1 addition & 1 deletion src/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ function _reformat_and_predict_raster(s::Union{<:SDMensemble, SDMgroup, SDMmachi
rs_preds = rs[predictorkeys(data(s))]
missing_mask = Rasters.boolmask(rs_preds)
d = rs_preds[missing_mask]
if any(map(x -> Missing <: eltype(x), rs_preds))
if any(Rasters.maplayers(x -> Missing <: eltype(x), rs_preds))
# to get rid of Union{Missing, Float64} etc.
layertypes = map(x -> Base.nonmissingtype(eltype(x)), (Rasters.layers(rs_preds)))
nttype = NamedTuple{keys(layertypes), Tuple{values(layertypes)...}}
Expand Down
38 changes: 28 additions & 10 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ using SpeciesDistributionModels, MLJBase, MLJModels, Tables
import SpeciesDistributionModels as SDM
using StableRNGs, Distributions, Test
using Makie
using Rasters

using MLJGLMInterface: LinearBinaryClassifier
using EvoTrees: EvoTreeClassifier
using MLJDecisionTreeInterface: RandomForestClassifier

rng = StableRNG(0)
#using Random; rng = Random.GLOBAL_RNG

# some mock data
n = 100
backgrounddata = (a = rand(rng, n), b = rand(rng, n), c = rand(rng, n))
Expand Down Expand Up @@ -48,6 +50,22 @@ presencedata = (a = rand(rng, n), b = rand(rng, n).^2, c = sqrt.(rand(rng, n)))

machine_aucs = SDM.machine_evaluations(evaluation).test.auc

# explain
expl = explain(ensemble; method = ShapleyValues(10; rng))
varimp = variable_importance(expl)
@test varimp.b > varimp.a
@test varimp.c > varimp.a

# plots
interactive_evaluation(ensemble, thresholds = 0:0.001:1)
interactive_response_curves(expl)
boxplot(evaluation, :auc)
end

data = sdmdata(presencedata, backgrounddata; resampler = CV(nfolds = 5, shuffle = true))
ensemble = sdm(data, (; lm = LinearBinaryClassifier()))

@testset "predict" begin
pr1 = SDM.predict(ensemble, backgrounddata)
pr2 = SDM.predict(ensemble, backgrounddata; reducer = maximum)
pr3 = SDM.predict(ensemble, backgrounddata; reducer = x -> sum(x .> 0.5), by_group = true)
Expand All @@ -61,16 +79,16 @@ presencedata = (a = rand(rng, n), b = rand(rng, n).^2, c = sqrt.(rand(rng, n)))
@test_throws ArgumentError SDM.predict(ensemble, backgrounddata[(:a,)])
@test_throws Exception SDM.predict(ensemble, backgrounddata; by_group = true)

# explain
expl = explain(ensemble; method = ShapleyValues(10; rng))
varimp = variable_importance(expl)
@test varimp.b > varimp.a
@test varimp.c > varimp.a

# plots
interactive_evaluation(ensemble, thresholds = 0:0.001:1)
interactive_response_curves(expl)
boxplot(evaluation, :auc)
## to a Raster
ds = (X(1:100), Y(1:100))
rs = RasterStack((a = rand(ds), b = rand(ds), c = rand(ds)), missingval = 0.0)
# make the first value missing
rs[1] = (a = 0, b = 0, c = 0)
raspr = SDM.predict(ensemble, rs)
@test all(ismissing, raspr[X=1, Y=1])
# all values should be between 0 and 1
extr = extrema(skipmissing(raspr))
@test extr[1] > 0 && extr[2] < 1
end

@testset "collinearity" begin
Expand Down

0 comments on commit 50811e3

Please sign in to comment.