From c47d77258668b6359b5dc526ac50bcbee8c2e234 Mon Sep 17 00:00:00 2001 From: Tiem van der Deure Date: Thu, 18 Jan 2024 14:54:48 +0100 Subject: [PATCH] add collinearity methods (#6) * add methods for collinearity detection using vif or cor * add statsmodels requirement * get rid of max size argument * fix tests * add documentation for remove_collinear * better handing of categorical variables --- Project.toml | 5 ++ src/SpeciesDistributionModels.jl | 7 +- src/collinearity.jl | 107 +++++++++++++++++++++++++++++++ test/runtests.jl | 38 ++++++++++- 4 files changed, 152 insertions(+), 5 deletions(-) create mode 100644 src/collinearity.jl diff --git a/Project.toml b/Project.toml index effac33..dc489b5 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,8 @@ DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a" +Lasso = "b4fcebef-c861-5a0f-a7e2-ba9dc32b180a" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Loess = "4345ca2d-374a-55d4-8d30-97f9976e7612" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661" @@ -22,11 +24,14 @@ Shapley = "855ca7ad-a6ef-4de2-9ca8-726fe2a39065" StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541" StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] julia = "1.6" +StatsModels = "0.7.3" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/SpeciesDistributionModels.jl b/src/SpeciesDistributionModels.jl index 1adda08..c04d28d 100644 --- a/src/SpeciesDistributionModels.jl +++ b/src/SpeciesDistributionModels.jl @@ -1,6 +1,6 @@ module SpeciesDistributionModels -import Tables, StatsBase, Statistics +import Tables, StatsBase, Statistics, StatsAPI, StatsModels, LinearAlgebra import MLJBase, StatisticalMeasures, StatisticalMeasuresBase, ScientificTypesBase, CategoricalArrays import GLM, PrettyTables, Rasters, EvoTrees, DecisionTree, Makie, Shapley, Loess @@ -13,8 +13,10 @@ using Makie: Toggle, Label, scatter!, lines!, Axis, Figure, GridLayout, lift using ScientificTypesBase: Continuous, OrderedFactor, Multiclass, Count export SDMensemble, predict, sdm, select, machines, machine_keys, shap, - interactive_evaluation, interactive_response_curves + interactive_evaluation, interactive_response_curves, + remove_collinear +include("collinearity.jl") include("models.jl") include("ensemble.jl") include("predict.jl") @@ -22,5 +24,4 @@ include("explain.jl") include("evaluate.jl") include("plots.jl") - end diff --git a/src/collinearity.jl b/src/collinearity.jl new file mode 100644 index 0000000..02cc10a --- /dev/null +++ b/src/collinearity.jl @@ -0,0 +1,107 @@ +abstract type AbstractCollinearityMethod end + +struct Gvif <: AbstractCollinearityMethod + threshold +end +Gvif(; threshold) = Gvif(threshold) + +struct Vif <: AbstractCollinearityMethod + threshold +end +Vif(; threshold) = Vif(threshold) + +struct Pearson <: AbstractCollinearityMethod + threshold +end +Pearson(; threshold) = Pearson(threshold) + +# Need to add a method for RasterStack, unless it will be Tables.jl compatible +""" + remove_collinear(data; method, silent = false) + +Removes strongly correlated 
variables in `data` until the collinearity measure is below the threshold specified in `method`.
+
+`method` can currently be either `Gvif`, `Vif`, or `Pearson`, which use GVIF, VIF, or Pearson's r, respectively.
+GVIF and VIF are similar methods, but GVIF includes categorical variables whereas VIF ignores them.
+
+To run without showing information about collinearity scores, set `silent = true`.
+
+## Example
+```julia
+julia> import SpeciesDistributionModels as SDM
+julia> mydata = (a = 1:100, b = sqrt.(1:100), c = rand(100))
+julia> SDM.remove_collinear(mydata; method = SDM.Vif(10))
+[ Info: a has highest VIF score: 28.367942095054225
+[ Info: Removing a, 2 variables remaining
+[ Info: b has highest VIF score: 1.0077618445543057
+[ Info: All variables are below threshold, returning remaining variables
+(:b, :c)
+```
+
+"""
+function remove_collinear(data; method, silent::Bool = false)
+    schema = Tables.schema(data)
+    datakeys = schema.names
+    iscategorical = collect(schema.types .<: CategoricalArrays.CategoricalValue)
+    _remove_collinear(data, datakeys, method, ~silent, iscategorical)
+end
+
+_remove_collinear(data, datakeys, v::Vif, verbose, iscategorical) = (_vifstep(data, datakeys[.~iscategorical], v.threshold, verbose, StatsAPI.vif)..., datakeys[iscategorical]...)
+_remove_collinear(data, datakeys, v::Gvif, verbose, iscategorical) = _vifstep(data, datakeys, v.threshold, verbose, StatsAPI.gvif)
+_remove_collinear(data, datakeys, p::Pearson, verbose, iscategorical) = (_pearsonstep(data, datakeys[.~iscategorical], p.threshold, verbose)..., datakeys[iscategorical]...)
+
+function _vifstep(data, datakeys, threshold, verbose, vifmethod)
+    highest_vif = threshold + 1.
+    while highest_vif > threshold
+        # TODO: a direct (g)vif implementation could avoid fitting this throw-away model with a constant response
+        m = GLM.lm(StatsModels.FormulaTerm(StatsModels.term(1), StatsModels.term.(datakeys)), data)
+        vifresult = vifmethod(m)
+        maxvif = Base.findmax(vifresult)
+        highest_vif = maxvif[1]
+        if verbose
+            @info "$(datakeys[maxvif[2]]) has highest VIF score: $(maxvif[1])"
+        end
+        if isnan(highest_vif)
+            error("Cannot compute VIF. Some variables may have perfect collinearity")
+        end
+
+        if highest_vif > threshold
+            if verbose
+                @info "Removing $(datakeys[maxvif[2]]), $(length(datakeys)-1) variables remaining"
+            end
+            datakeys = datakeys[filter(x -> x != maxvif[2], 1:length(datakeys))] # not very elegant!
+        end
+    end
+
+    if verbose
+        @info "All variables are below threshold, returning remaining variables"
+    end
+
+    return datakeys
+end
+
+# Greedily drop the variable involved in the most correlated pairs (ties broken arbitrarily) until no pairs remain
+function _pearsonstep(data, datakeys, threshold, verbose)
+    data = Tables.columntable(data)[datakeys]
+    datamatrix = reduce(hcat, data)
+    c = abs.(Statistics.cor(datamatrix) - LinearAlgebra.I)
+    correlated_vars_idx = findall(LinearAlgebra.LowerTriangular(c) .> threshold)
+    if verbose
+        @info "Found $(length(correlated_vars_idx)) correlated variable pairs"
+        for idx in correlated_vars_idx
+            println("$(keys(data)[idx.I[1]]) ~ $(keys(data)[idx.I[2]]): $(c[idx])")
+        end
+    end
+
+    correlated_vars = Tuple.(correlated_vars_idx)
+    vars_to_remove = Int[]
+    while correlated_vars != []
+        cm = mapreduce(x -> collect(x), vcat, correlated_vars) |> StatsBase.countmap # count how often each variable occurs in a pair
+        to_remove = findmax(cm)[2]
+        correlated_vars = [c for c in correlated_vars if ~in(to_remove, c)]
+        append!(vars_to_remove, to_remove)
+    end
+    # keep the variables that were not flagged for removal
+    vars_remaining = keys(data)[setdiff(1:length(keys(data)), vars_to_remove)]
+    return vars_remaining
+end
\ No newline at end of file
diff --git a/test/runtests.jl b/test/runtests.jl
index 3345581..4968c25 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,6 +1,40 @@
-using SpeciesDistributionModels
+using SpeciesDistributionModels, CategoricalArrays
+import SpeciesDistributionModels as SDM
+
+import GLM: Distributions
+
+# some mock data
+n = 500
+backgrounddata = (a = rand(n), b = rand(n), c = categorical(rand(0:3, n)))
+presencedata = (a = rand(n), b = rand(n).^2, c = categorical(rand(Distributions.Binomial(3, 0.5), n)))
+
 using Test

 @testset "SpeciesDistributionModels.jl" begin
-    # Write your tests here.
+    models = [SDM.random_forest(), SDM.random_forest(; max_depth = 3), SDM.linear_model(), SDM.boosted_regression_tree()]
+    resamplers = [SDM.MLJBase.CV(; shuffle = true, nfolds = 5)]
+
+    ensemble = sdm(
+        presencedata, backgrounddata,
+        models,
+        resamplers
+    )
+
+    evaluation = SDM.evaluate(ensemble)
+end
+
+@testset "collinearity" begin
+    # mock data with a collinearity problem
+    data_with_collinearity = merge(backgrounddata, (; d = backgrounddata.a .+ rand(n), e = backgrounddata.a .+ rand(n), f = categorical(rand(Distributions.Binomial(3, 0.5), 500)) ))
+
+    rm_col_gvif = remove_collinear(data_with_collinearity; method = SDM.Gvif(; threshold = 2.), silent = true)
+    rm_col_vif = remove_collinear(data_with_collinearity; method = SDM.Vif(; threshold = 2.), silent = true)
+    rm_col_pearson = remove_collinear(data_with_collinearity; method = SDM.Pearson(; threshold = 0.65), silent = true)
+    @test rm_col_gvif == (:b, :c, :d, :e, :f)
+    @test rm_col_vif == (:b, :d, :e, :c, :f)
+    @test rm_col_pearson == (:b, :d, :e, :c, :f)
+
+    data_with_perfect_collinearity = (a = [1,2,3], b = [1,2,3])
+    Test.@test_throws Exception remove_collinear(data_with_perfect_collinearity; method = SDM.Gvif(; threshold = 2.), silent = true)
+    @test remove_collinear(data_with_perfect_collinearity; method = SDM.Pearson(; threshold = 0.65), silent = true) == (:a, )
 end
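
A minimal usage sketch of the API added in this patch, for readers who want to try it end to end. It only calls functions introduced above (`remove_collinear`, `Gvif`, `Pearson`, and the `silent` keyword); the toy data and the names `x1`, `x2`, `x3`, `group`, and `toydata` are illustrative and not part of the patch, and `CategoricalArrays` is assumed to be available, as in the test environment.

```julia
import SpeciesDistributionModels as SDM
using CategoricalArrays

# Toy data: `x2` is nearly collinear with `x1` by construction; `group` is categorical.
n = 200
x1 = rand(n)
toydata = (x1 = x1, x2 = x1 .+ 0.01 .* rand(n), x3 = rand(n), group = categorical(rand(0:2, n)))

# GVIF-based filtering: categorical variables are included in the score.
SDM.remove_collinear(toydata; method = SDM.Gvif(; threshold = 2.0))

# Pearson-based filtering on the numeric variables only, with info messages suppressed.
SDM.remove_collinear(toydata; method = SDM.Pearson(; threshold = 0.7), silent = true)
```

Both calls should drop one of `x1`/`x2` and return the names of the remaining variables as a tuple, mirroring the behaviour exercised in the new `collinearity` testset.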