From 50811e3b4ba40404f0ed35bdbb4afea4e47482d7 Mon Sep 17 00:00:00 2001 From: Tiem van der Deure Date: Thu, 2 Jan 2025 15:27:45 +0100 Subject: [PATCH] fix predict for Rasters 0.13 (#20) * use maplayers * add tests for predicting to rasters * require rasters 0.13 --- Project.toml | 2 +- src/predict.jl | 2 +- test/runtests.jl | 38 ++++++++++++++++++++++++++++---------- 3 files changed, 30 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index d1151e0..f0b082e 100644 --- a/Project.toml +++ b/Project.toml @@ -42,7 +42,7 @@ Loess = "0.6" MLJBase = "1.7.0" Makie = "0.20, 0.21" PrettyTables = "2" -Rasters = "0.12, 0.13" +Rasters = "0.13" ScientificTypesBase = "3" Shapley = "0.1" StatisticalMeasures = "0.1.5" diff --git a/src/predict.jl b/src/predict.jl index a81a333..9754835 100644 --- a/src/predict.jl +++ b/src/predict.jl @@ -99,7 +99,7 @@ function _reformat_and_predict_raster(s::Union{<:SDMensemble, SDMgroup, SDMmachi rs_preds = rs[predictorkeys(data(s))] missing_mask = Rasters.boolmask(rs_preds) d = rs_preds[missing_mask] - if any(map(x -> Missing <: eltype(x), rs_preds)) + if any(Rasters.maplayers(x -> Missing <: eltype(x), rs_preds)) # to get rid of Union{Missing, Float64} etc. layertypes = map(x -> Base.nonmissingtype(eltype(x)), (Rasters.layers(rs_preds))) nttype = NamedTuple{keys(layertypes), Tuple{values(layertypes)...}} diff --git a/test/runtests.jl b/test/runtests.jl index 0806cf9..e7cfaa1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using SpeciesDistributionModels, MLJBase, MLJModels, Tables import SpeciesDistributionModels as SDM using StableRNGs, Distributions, Test using Makie +using Rasters using MLJGLMInterface: LinearBinaryClassifier using EvoTrees: EvoTreeClassifier @@ -9,6 +10,7 @@ using MLJDecisionTreeInterface: RandomForestClassifier rng = StableRNG(0) #using Random; rng = Random.GLOBAL_RNG + # some mock data n = 100 backgrounddata = (a = rand(rng, n), b = rand(rng, n), c = rand(rng, n)) @@ -48,6 +50,22 @@ presencedata = (a = rand(rng, n), b = rand(rng, n).^2, c = sqrt.(rand(rng, n))) machine_aucs = SDM.machine_evaluations(evaluation).test.auc + # explain + expl = explain(ensemble; method = ShapleyValues(10; rng)) + varimp = variable_importance(expl) + @test varimp.b > varimp.a + @test varimp.c > varimp.a + + # plots + interactive_evaluation(ensemble, thresholds = 0:0.001:1) + interactive_response_curves(expl) + boxplot(evaluation, :auc) +end + +data = sdmdata(presencedata, backgrounddata; resampler = CV(nfolds = 5, shuffle = true)) +ensemble = sdm(data, (; lm = LinearBinaryClassifier())) + +@testset "predict" begin pr1 = SDM.predict(ensemble, backgrounddata) pr2 = SDM.predict(ensemble, backgrounddata; reducer = maximum) pr3 = SDM.predict(ensemble, backgrounddata; reducer = x -> sum(x .> 0.5), by_group = true) @@ -61,16 +79,16 @@ presencedata = (a = rand(rng, n), b = rand(rng, n).^2, c = sqrt.(rand(rng, n))) @test_throws ArgumentError SDM.predict(ensemble, backgrounddata[(:a,)]) @test_throws Exception SDM.predict(ensemble, backgrounddata; by_group = true) - # explain - expl = explain(ensemble; method = ShapleyValues(10; rng)) - varimp = variable_importance(expl) - @test varimp.b > varimp.a - @test varimp.c > varimp.a - - # plots - interactive_evaluation(ensemble, thresholds = 0:0.001:1) - interactive_response_curves(expl) - boxplot(evaluation, :auc) + ## to a Raster + ds = (X(1:100), Y(1:100)) + rs = RasterStack((a = rand(ds), b = rand(ds), c = rand(ds)), missingval = 0.0) + # make the first value missing + rs[1] = (a = 0, b = 0, c = 0) + raspr = SDM.predict(ensemble, rs) + @test all(ismissing, raspr[X=1, Y=1]) + # all values should be between 0 and 1 + extr = extrema(skipmissing(raspr)) + @test extr[1] > 0 && extr[2] < 1 end @testset "collinearity" begin