Skip to content

Commit

Permalink
add collinearity methods (#6)
Browse files Browse the repository at this point in the history
* add methods for collinearity detection using vif or cor

* add statsmodels requirement

* get rid of max size argument

* fix tests

* add documentation for remove_collinear

* better handling of categorical variables
  • Loading branch information
tiemvanderdeure authored Jan 18, 2024
1 parent 0bc48db commit c47d772
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 5 deletions.
5 changes: 5 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
Lasso = "b4fcebef-c861-5a0f-a7e2-ba9dc32b180a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Loess = "4345ca2d-374a-55d4-8d30-97f9976e7612"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
Expand All @@ -22,11 +24,14 @@ Shapley = "855ca7ad-a6ef-4de2-9ca8-726fe2a39065"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
julia = "1.6"
StatsModels = "0.7.3"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
7 changes: 4 additions & 3 deletions src/SpeciesDistributionModels.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module SpeciesDistributionModels

import Tables, StatsBase, Statistics
import Tables, StatsBase, Statistics, StatsAPI, StatsModels, LinearAlgebra
import MLJBase, StatisticalMeasures, StatisticalMeasuresBase, ScientificTypesBase, CategoricalArrays
import GLM, PrettyTables, Rasters, EvoTrees, DecisionTree, Makie, Shapley, Loess

Expand All @@ -13,14 +13,15 @@ using Makie: Toggle, Label, scatter!, lines!, Axis, Figure, GridLayout, lift
using ScientificTypesBase: Continuous, OrderedFactor, Multiclass, Count

export SDMensemble, predict, sdm, select, machines, machine_keys, shap,
interactive_evaluation, interactive_response_curves
interactive_evaluation, interactive_response_curves,
remove_collinear

include("collinearity.jl")
include("models.jl")
include("ensemble.jl")
include("predict.jl")
include("explain.jl")
include("evaluate.jl")
include("plots.jl")


end
107 changes: 107 additions & 0 deletions src/collinearity.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
    AbstractCollinearityMethod

Supertype of the strategies that can be passed as `method` to `remove_collinear`.
Each concrete method stores the `threshold` above which variables are removed.
"""
abstract type AbstractCollinearityMethod end

"""
    Gvif(threshold)
    Gvif(; threshold)

Select variables using the generalized variance inflation factor (GVIF).
All variables, including categorical ones, take part in the selection.
"""
struct Gvif{T<:Real} <: AbstractCollinearityMethod
    threshold::T
end
Gvif(; threshold) = Gvif(threshold)

"""
    Vif(threshold)
    Vif(; threshold)

Select variables using the variance inflation factor (VIF).
Categorical variables are left out of the selection and are always retained.
"""
struct Vif{T<:Real} <: AbstractCollinearityMethod
    threshold::T
end
Vif(; threshold) = Vif(threshold)

"""
    Pearson(threshold)
    Pearson(; threshold)

Select variables using the absolute value of Pearson's correlation coefficient.
Categorical variables are left out of the selection and are always retained.
"""
struct Pearson{T<:Real} <: AbstractCollinearityMethod
    threshold::T
end
Pearson(; threshold) = Pearson(threshold)

# TODO: add a method for RasterStack, unless it becomes Tables.jl compatible
"""
    remove_collinear(data; method, silent = false)

Remove strongly correlated variables from `data` until the collinearity score is below the
threshold specified in `method`, and return a tuple with the names of the remaining variables.

`data` must be a table whose schema can be obtained with `Tables.schema`.

`method` can currently be either `Gvif`, `Vif` or `Pearson`, which use GVIF, VIF, or
Pearson's r, respectively. GVIF and VIF are similar methods, but GVIF includes categorical
variables whereas VIF ignores them. Columns whose element type is a
`CategoricalArrays.CategoricalValue` are treated as categorical; `Vif` and `Pearson` do not
score them and append them unchanged to the result.

To run without showing information about collinearity scores, set `silent = true`.

## Example
```julia
julia> import SpeciesDistributionModels as SDM
julia> mydata = (a = 1:100, b = sqrt.(1:100), c = rand(100))
julia> SDM.remove_collinear(mydata; method = SDM.Vif(10)) # scores vary, `c` is random
[ Info: a has highest VIF score: 28.367942095054225
[ Info: Removing a, 2 variables remaining
[ Info: b has highest VIF score: 1.0077618445543057
[ Info: All variables are below threshold, returning remaining variables
(:b, :c)
```
"""
function remove_collinear(data; method, silent::Bool = false)
    schema = Tables.schema(data)
    # `Tables.schema` may return `nothing`; fail with a clear message instead of an
    # obscure field-access error on `nothing`.
    schema === nothing &&
        throw(ArgumentError("cannot determine the column names and types of `data`"))
    datakeys = schema.names
    iscategorical = collect(schema.types .<: CategoricalArrays.CategoricalValue)
    return _remove_collinear(data, datakeys, method, !silent, iscategorical)
end

# Method-specific back ends of `remove_collinear`. `Vif` and `Pearson` only score the
# non-categorical variables and append the categorical ones untouched; `Gvif` scores
# every variable.
function _remove_collinear(data, datakeys, v::Vif, verbose, iscategorical)
    continuous_keys = datakeys[.!iscategorical]
    kept = _vifstep(data, continuous_keys, v.threshold, verbose, StatsAPI.vif)
    return (kept..., datakeys[iscategorical]...)
end

function _remove_collinear(data, datakeys, v::Gvif, verbose, iscategorical)
    return _vifstep(data, datakeys, v.threshold, verbose, StatsAPI.gvif)
end

function _remove_collinear(data, datakeys, p::Pearson, verbose, iscategorical)
    continuous_keys = datakeys[.!iscategorical]
    kept = _pearsonstep(data, continuous_keys, p.threshold, verbose)
    return (kept..., datakeys[iscategorical]...)
end

"""
    _vifstep(data, datakeys, threshold, verbose, vifmethod)

Iteratively drop the variable with the highest score until no score exceeds `threshold`,
and return the remaining entries of `datakeys`.

In every step a linear model with a constant response and the variables in `datakeys` as
predictors is fitted to `data`, and `vifmethod` (e.g. `StatsAPI.vif` or `StatsAPI.gvif`) is
applied to the fitted model. If `verbose` is `true`, progress is reported with `@info`.

Throws an error if the highest score is `NaN`.
"""
function _vifstep(data, datakeys, threshold, verbose, vifmethod)
    # With fewer than two variables there is nothing left to be collinear with, so stop
    # instead of fitting a model with zero or one predictor.
    while length(datakeys) > 1
        # TODO: make a custom implementation of gvif that works without the useless model
        formula = StatsModels.FormulaTerm(StatsModels.term(1), StatsModels.term.(datakeys))
        model = GLM.lm(formula, data)
        highest_vif, worst = Base.findmax(vifmethod(model))
        if verbose
            @info "$(datakeys[worst]) has highest VIF score: $(highest_vif)"
        end
        if isnan(highest_vif)
            error("Cannot compute VIF. Possible some variables have perfect collinearity")
        end

        # Every remaining variable is at or below the threshold: done.
        highest_vif > threshold || break

        if verbose
            @info "Removing $(datakeys[worst]), $(length(datakeys)-1) variables remaining"
        end
        datakeys = datakeys[setdiff(eachindex(datakeys), worst)]
    end

    if verbose
        @info "All variables are below threshold, returning remaining variables"
    end

    return datakeys
end

"""
    _pearsonstep(data, datakeys, threshold, verbose)

Return the entries of `datakeys` that remain after removing variables until no pair of the
remaining variables has an absolute Pearson correlation above `threshold`.

All pairs with a correlation above `threshold` are collected first. Then the variable that
takes part in the largest number of those pairs is removed, together with its pairs, until
no pair is left. Ties are broken by removing the variable that comes last in `datakeys`,
so the result does not depend on hash ordering. If `verbose` is `true`, the correlated
pairs are reported with `@info`.
"""
function _pearsonstep(data, datakeys, threshold, verbose)
    # A correlation needs two variables; with fewer there is nothing to remove.
    length(datakeys) < 2 && return Tuple(datakeys)

    # `NamedTuple{names}(nt)` selects columns and, unlike `nt[names]`, works on Julia 1.6.
    columns = NamedTuple{Tuple(datakeys)}(Tables.columntable(data))
    varnames = keys(columns)
    nvars = length(varnames)
    abscor = abs.(Statistics.cor(reduce(hcat, columns)))

    # Pairs (i, j) with i > j from the strict lower triangle, in column-major order.
    correlated_pairs = [
        (i, j) for j in 1:nvars for i in (j + 1):nvars if abscor[i, j] > threshold
    ]
    if verbose
        @info "Found $(length(correlated_pairs)) correlated variable pairs"
        for (i, j) in correlated_pairs
            @info "$(varnames[i]) ~ $(varnames[j]): $(abscor[i, j])"
        end
    end

    removed = Int[]
    while !isempty(correlated_pairs)
        # Count in how many of the remaining pairs each variable occurs.
        occurrences = zeros(Int, nvars)
        for (i, j) in correlated_pairs
            occurrences[i] += 1
            occurrences[j] += 1
        end
        worst = findlast(==(maximum(occurrences)), occurrences)
        push!(removed, worst)
        filter!(pair -> !(worst in pair), correlated_pairs)
    end

    return varnames[setdiff(1:nvars, removed)]
end
38 changes: 36 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,40 @@
using SpeciesDistributionModels
using SpeciesDistributionModels, CategoricalArrays
import SpeciesDistributionModels as SDM

import GLM: Distributions

# Mock presence and background data shared by the test sets in this file: two continuous
# variables (`a`, `b`) and one categorical variable (`c`), each with `n` rows.
# NOTE(review): `rand` is not seeded here, so the data differ between runs — confirm that
# the threshold-based expectations in the tests are robust to this.
n = 500
backgrounddata = (a = rand(n), b = rand(n), c = categorical(rand(0:3, n)))
presencedata = (a = rand(n), b = rand(n).^2, c = categorical(rand(Distributions.Binomial(3, 0.5), n)))

using Test

@testset "SpeciesDistributionModels.jl" begin
    # Smoke test: build an ensemble from the mock data and evaluate it.
    models = [
        SDM.random_forest(),
        SDM.random_forest(; max_depth = 3),
        SDM.linear_model(),
        SDM.boosted_regression_tree(),
    ]
    resamplers = [SDM.MLJBase.CV(; shuffle = true, nfolds = 5)]

    ensemble = sdm(presencedata, backgrounddata, models, resamplers)

    evaluation = SDM.evaluate(ensemble)
end

@testset "collinearity" begin
    # mock data with a collinearity problem: `d` and `e` are `a` plus noise,
    # `f` is an additional categorical variable
    data_with_collinearity = merge(
        backgrounddata,
        (;
            d = backgrounddata.a .+ rand(n),
            e = backgrounddata.a .+ rand(n),
            f = categorical(rand(Distributions.Binomial(3, 0.5), n)),
        ),
    )

    rm_col_gvif = remove_collinear(data_with_collinearity; method = SDM.Gvif(; threshold = 2.), silent = true)
    rm_col_vif = remove_collinear(data_with_collinearity; method = SDM.Vif(; threshold = 2.), silent = true)
    rm_col_pearson = remove_collinear(data_with_collinearity; method = SDM.Pearson(; threshold = 0.65), silent = true)
    @test rm_col_gvif == (:b, :c, :d, :e, :f)
    @test rm_col_vif == (:b, :d, :e, :c, :f)
    @test rm_col_pearson == (:b, :d, :e, :c, :f)

    data_with_perfect_collinearity = (a = [1,2,3], b = [1,2,3])
    # Use the real `silent` keyword: an unsupported keyword would raise a `MethodError`
    # and make this test pass without exercising the collinearity check at all.
    @test_throws Exception remove_collinear(data_with_perfect_collinearity; method = SDM.Gvif(; threshold = 2.), silent = true)
    @test remove_collinear(data_with_perfect_collinearity; method = SDM.Pearson(; threshold = 0.65), silent = true) == (:a, )
end

0 comments on commit c47d772

Please sign in to comment.