Skip to content


add collinearity methods (#6)
Browse files Browse the repository at this point in the history
* add methods for collinearity detection using vif or cor

* add statsmodels requirement

* get rid of max size argument

* fix tests

* add documentation for remove_collinear

* better handling of categorical variables
  • Loading branch information
tiemvanderdeure authored Jan 18, 2024
1 parent 0bc48db commit c47d772
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 5 deletions.
5 changes: 5 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
Lasso = "b4fcebef-c861-5a0f-a7e2-ba9dc32b180a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Loess = "4345ca2d-374a-55d4-8d30-97f9976e7612"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
Expand All @@ -22,11 +24,14 @@ Shapley = "855ca7ad-a6ef-4de2-9ca8-726fe2a39065"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

julia = "1.6"
StatsModels = "0.7.3"

Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
7 changes: 4 additions & 3 deletions src/SpeciesDistributionModels.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module SpeciesDistributionModels

import Tables, StatsBase, Statistics
import Tables, StatsBase, Statistics, StatsAPI, StatsModels, LinearAlgebra
import MLJBase, StatisticalMeasures, StatisticalMeasuresBase, ScientificTypesBase, CategoricalArrays
import GLM, PrettyTables, Rasters, EvoTrees, DecisionTree, Makie, Shapley, Loess

Expand All @@ -13,14 +13,15 @@ using Makie: Toggle, Label, scatter!, lines!, Axis, Figure, GridLayout, lift
using ScientificTypesBase: Continuous, OrderedFactor, Multiclass, Count

export SDMensemble, predict, sdm, select, machines, machine_keys, shap,
interactive_evaluation, interactive_response_curves
interactive_evaluation, interactive_response_curves,


107 changes: 107 additions & 0 deletions src/collinearity.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Supertype of every collinearity-detection strategy accepted by `remove_collinear`.
abstract type AbstractCollinearityMethod end

"""
    Gvif(threshold)
    Gvif(; threshold)

Collinearity method based on the generalized variance inflation factor (GVIF).
Variables are removed until no GVIF exceeds `threshold`. Unlike [`Vif`](@ref),
categorical variables take part in the screening.
"""
struct Gvif{T<:Real} <: AbstractCollinearityMethod
    threshold::T
end
Gvif(; threshold) = Gvif(threshold)

"""
    Vif(threshold)
    Vif(; threshold)

Collinearity method based on the variance inflation factor (VIF).
Variables are removed until no VIF exceeds `threshold`. Categorical variables are
ignored by the screening and always kept.
"""
struct Vif{T<:Real} <: AbstractCollinearityMethod
    threshold::T
end
Vif(; threshold) = Vif(threshold)

"""
    Pearson(threshold)
    Pearson(; threshold)

Collinearity method based on the absolute value of Pearson's correlation coefficient.
Variables are removed until no pairwise correlation exceeds `threshold`. Categorical
variables are ignored by the screening and always kept.
"""
struct Pearson{T<:Real} <: AbstractCollinearityMethod
    threshold::T
end
Pearson(; threshold) = Pearson(threshold)

# Need to add a method for RasterStack, unless it will be Tables.jl compatible
"""
    remove_collinear(data; method, silent = false)

Remove strongly correlated variables from `data` until the collinearity score of the
remaining variables is below the threshold specified in `method`. Return the names of
the remaining variables as a tuple of `Symbol`s.

# Arguments
- `data`: a Tables.jl-compatible table whose column names and element types are known.

# Keywords
- `method`: either `Gvif`, `Vif` or `Pearson`, which use GVIF, VIF, or Pearson's r,
  respectively. GVIF and VIF are similar methods, but GVIF includes categorical
  variables whereas VIF ignores them. `Vif` and `Pearson` never remove categorical
  variables; these are appended after the remaining non-categorical ones.
- `silent`: set to `true` to run without showing information about collinearity scores.

# Throws
- `ArgumentError` if the schema of `data` cannot be determined.

# Example
The scores shown depend on the random column `c` and will differ between runs.
```julia-repl
julia> import SpeciesDistributionModels as SDM

julia> mydata = (a = 1:100, b = sqrt.(1:100), c = rand(100));

julia> SDM.remove_collinear(mydata; method = SDM.Vif(10))
[ Info: a has highest VIF score: 28.367942095054225
[ Info: Removing a, 2 variables remaining
[ Info: b has highest VIF score: 1.0077618445543057
[ Info: All variables are below threshold, returning remaining variables
(:b, :c)
```
"""
function remove_collinear(data; method, silent::Bool = false)
    schema = Tables.schema(data)
    # `Tables.schema` returns `nothing` when names/types are unknown; fail with a clear
    # message instead of an obscure field-access error on `nothing`.
    if schema === nothing
        throw(ArgumentError("cannot determine the column names and types of `data`"))
    end
    datakeys = schema.names
    # One flag per column: `true` where the column holds categorical values.
    iscategorical = Bool[T <: CategoricalArrays.CategoricalValue for T in schema.types]
    return _remove_collinear(data, datakeys, method, !silent, iscategorical)
end

# Method-specific back ends of `remove_collinear`.
# `Vif` and `Pearson` screen only the non-categorical variables and append every
# categorical variable, untouched, after the ones that survive the screening.
# `Gvif` screens all variables together.
function _remove_collinear(data, datakeys, v::Vif, verbose, iscategorical)
    numeric_keys = datakeys[.~iscategorical]
    kept = _vifstep(data, numeric_keys, v.threshold, verbose, StatsAPI.vif)
    return (kept..., datakeys[iscategorical]...)
end

function _remove_collinear(data, datakeys, v::Gvif, verbose, iscategorical)
    return _vifstep(data, datakeys, v.threshold, verbose, StatsAPI.gvif)
end

function _remove_collinear(data, datakeys, p::Pearson, verbose, iscategorical)
    numeric_keys = datakeys[.~iscategorical]
    kept = _pearsonstep(data, numeric_keys, p.threshold, verbose)
    return (kept..., datakeys[iscategorical]...)
end

# Stepwise (G)VIF elimination.
#
# Repeatedly fits a linear model with `datakeys` as predictors, computes the scores with
# `vifmethod` (e.g. `StatsAPI.vif` or `StatsAPI.gvif`) and drops the variable with the
# highest score, until the highest score no longer exceeds `threshold`.
#
# Arguments
# - `data`: table passed to `GLM.lm`.
# - `datakeys`: names of the candidate variables (a tuple of `Symbol`s).
# - `threshold`: highest score a variable may have and still be kept.
# - `verbose`: whether to log progress with `@info`.
# - `vifmethod`: function mapping the fitted model to one score per variable.
#
# Returns the names of the remaining variables.
# Throws an `ErrorException` when the highest score is `NaN`.
function _vifstep(data, datakeys, threshold, verbose, vifmethod)
    # (G)VIF is only defined for two or more predictors, so stop as soon as fewer
    # remain. This also covers empty or single-variable input.
    while length(datakeys) > 1
        # TODO: make a custom implementation of gvif that works without the useless model
        # The response is the constant 1; only the design matrix matters for the scores.
        formula = StatsModels.FormulaTerm(StatsModels.term(1), StatsModels.term.(datakeys))
        model = GLM.lm(formula, data)
        highest_vif, worst = findmax(vifmethod(model))
        if verbose
            @info "$(datakeys[worst]) has highest VIF score: $(highest_vif)"
        end
        if isnan(highest_vif)
            error("Cannot compute VIF. Possible some variables have perfect collinearity")
        end

        highest_vif > threshold || break

        if verbose
            @info "Removing $(datakeys[worst]), $(length(datakeys)-1) variables remaining"
        end
        # Indexing with a vector of positions keeps `datakeys` a tuple.
        datakeys = datakeys[setdiff(eachindex(datakeys), worst)]
    end

    if verbose
        @info "All variables are below threshold, returning remaining variables"
    end

    return datakeys
end

# Pearson-correlation based elimination.
#
# Finds every pair of variables in `datakeys` whose absolute Pearson correlation exceeds
# `threshold`, then repeatedly removes the variable that occurs in the most remaining
# pairs until no pair is left.
# Ties between equally frequent variables are resolved by `findmax` over the count
# dictionary, i.e. they follow that dictionary's iteration order.
#
# Arguments
# - `data`: Tables.jl-compatible table holding the columns named in `datakeys`.
# - `datakeys`: names of the candidate variables (`Symbol`s).
# - `threshold`: highest absolute correlation a pair may have and still be kept.
# - `verbose`: whether to report the correlated pairs.
#
# Returns the names of the remaining variables as a tuple, in their original order.
function _pearsonstep(data, datakeys, threshold, verbose)
    varnames = Tuple(datakeys)
    # With fewer than two variables there is no pair to test (and no correlation
    # matrix to build), so everything is kept.
    if length(varnames) < 2
        if verbose
            @info "Found 0 correlated variable pairs"
        end
        return varnames
    end

    # `NamedTuple{names}(nt)` selects the requested columns and also works on Julia 1.6.
    columns = NamedTuple{varnames}(Tables.columntable(data))
    datamatrix = reduce(hcat, columns)
    # Subtracting the identity zeroes the diagonal, so a variable never pairs with itself.
    abscor = abs.(Statistics.cor(datamatrix) - LinearAlgebra.I)
    # The matrix is symmetric; the lower triangle lists each pair exactly once.
    correlated_vars_idx = findall(LinearAlgebra.LowerTriangular(abscor) .> threshold)
    if verbose
        @info "Found $(length(correlated_vars_idx)) correlated variable pairs"
        for idx in correlated_vars_idx
            println("$(varnames[idx.I[1]]) ~ $(varnames[idx.I[2]]): $(abscor[idx])")
        end
    end

    correlated_vars = Tuple.(correlated_vars_idx)
    vars_to_remove = Int[]
    while !isempty(correlated_vars)
        # count how often each variable occurs in the remaining pairs
        counts = StatsBase.countmap(mapreduce(collect, vcat, correlated_vars))
        to_remove = findmax(counts)[2]
        # every pair involving the removed variable is resolved
        filter!(pair -> !(to_remove in pair), correlated_vars)
        push!(vars_to_remove, to_remove)
    end
    return varnames[setdiff(eachindex(varnames), vars_to_remove)]
end
38 changes: 36 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,40 @@
using SpeciesDistributionModels
using SpeciesDistributionModels, CategoricalArrays
import SpeciesDistributionModels as SDM

import GLM: Distributions

# some mock data
# `n` is the number of rows in each mock table.
n = 500
# Background table: `a` and `b` are drawn with `rand`, `c` is a categorical column
# sampled from 0:3.
backgrounddata = (a = rand(n), b = rand(n), c = categorical(rand(0:3, n)))
# Presence table with the same columns: `b` is squared and `c` is sampled from
# `Binomial(3, 0.5)`, so its columns are generated differently from the background.
presencedata = (a = rand(n), b = rand(n).^2, c = categorical(rand(Distributions.Binomial(3, 0.5), n)))

using Test

@testset "SpeciesDistributionModels.jl" begin
# Write your tests here.
models = [SDM.random_forest(), SDM.random_forest(; max_depth = 3), SDM.linear_model(), SDM.boosted_regression_tree()]
resamplers = [SDM.MLJBase.CV(; shuffle = true, nfolds = 5)]

ensemble = sdm(
presencedata, backgrounddata,

evaluation = SDM.evaluate(ensemble)

@testset "collinearity" begin
    # mock data with a collinearity problem: `d` and `e` are both `a` plus noise,
    # `f` is an extra categorical variable
    data_with_collinearity = merge(
        backgrounddata,
        (;
            d = backgrounddata.a .+ rand(n),
            e = backgrounddata.a .+ rand(n),
            f = categorical(rand(Distributions.Binomial(3, 0.5), n)),
        ),
    )

    rm_col_gvif = remove_collinear(data_with_collinearity; method = SDM.Gvif(; threshold = 2.), silent = true)
    rm_col_vif = remove_collinear(data_with_collinearity; method = SDM.Vif(; threshold = 2.), silent = true)
    rm_col_pearson = remove_collinear(data_with_collinearity; method = SDM.Pearson(; threshold = 0.65), silent = true)
    # Gvif screens all variables together, so the original column order is kept.
    @test rm_col_gvif == (:b, :c, :d, :e, :f)
    # Vif and Pearson return the non-categorical variables first, then the categorical ones.
    @test rm_col_vif == (:b, :d, :e, :c, :f)
    @test rm_col_pearson == (:b, :d, :e, :c, :f)

    data_with_perfect_collinearity = (a = [1,2,3], b = [1,2,3])
    # Only pass keywords that `remove_collinear` accepts: with an unknown keyword the
    # call throws a MethodError and this test would pass without ever reaching the
    # collinearity code.
    @test_throws Exception remove_collinear(data_with_perfect_collinearity; method = SDM.Gvif(; threshold = 2.), silent = true)
    @test remove_collinear(data_with_perfect_collinearity; method = SDM.Pearson(; threshold = 0.65), silent = true) == (:a, )
end

0 comments on commit c47d772

Please sign in to comment.