-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add methods for collinearity detection using vif or cor * add statsmodels requirement * get rid of max size argument * fix tests * add documentation for remove_collinear * better handing of categorical variables
- Loading branch information
1 parent
0bc48db
commit c47d772
Showing
4 changed files
with
152 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
"""
    AbstractCollinearityMethod

Supertype of the strategies that can be passed as `method` to `remove_collinear`.
Each concrete subtype defined here stores a single numeric `threshold`.
"""
abstract type AbstractCollinearityMethod end

"""
    Gvif(threshold)
    Gvif(; threshold)

Select variables using the generalized variance inflation factor (GVIF).
Variables are removed while the highest score exceeds `threshold`.
"""
struct Gvif{T<:Real} <: AbstractCollinearityMethod
    threshold::T
end
Gvif(; threshold) = Gvif(threshold)

"""
    Vif(threshold)
    Vif(; threshold)

Select variables using the variance inflation factor (VIF).
Variables are removed while the highest score exceeds `threshold`.
"""
struct Vif{T<:Real} <: AbstractCollinearityMethod
    threshold::T
end
Vif(; threshold) = Vif(threshold)

"""
    Pearson(threshold)
    Pearson(; threshold)

Select variables using the absolute value of Pearson's correlation coefficient.
Variable pairs whose absolute correlation exceeds `threshold` count as collinear.
"""
struct Pearson{T<:Real} <: AbstractCollinearityMethod
    threshold::T
end
Pearson(; threshold) = Pearson(threshold)
|
||
# TODO: add a method for RasterStack, unless it becomes Tables.jl compatible
"""
    remove_collinear(data; method, silent = false)

Remove strongly correlated variables from `data` until the collinearity score is below the
threshold specified in `method`, and return the names of the remaining variables as a tuple.

`data` must be a Tables.jl-compatible table.
`method` can currently be either `Gvif`, `Vif` or `Pearson`, which use GVIF, VIF, or
Pearson's r, respectively.
GVIF and VIF are similar methods, but GVIF includes categorical variables whereas VIF
ignores them. With `Vif` and `Pearson`, categorical variables are never removed and are
listed after the remaining non-categorical variables.

To run without showing information about collinearity scores, set `silent = true`.

## Example
```julia
julia> import SpeciesDistributionModels as SDM

julia> mydata = (a = 1:100, b = sqrt.(1:100), c = rand(100))

julia> SDM.remove_collinear(mydata; method = SDM.Vif(10))
[ Info: a has highest VIF score: 28.367942095054225
[ Info: Removing a, 2 variables remaining
[ Info: b has highest VIF score: 1.0077618445543057
[ Info: All variables are below threshold, returning remaining variables
(:b, :c)
```
"""
function remove_collinear(data; method, silent::Bool = false)
    # Materialize as a column table first: `Tables.schema` may return `nothing` for
    # sources whose schema is unknown, but it is always defined for a column table.
    columns = Tables.columntable(data)
    schema = Tables.schema(columns)
    datakeys = schema.names
    # A Vector{Bool} is needed because the name tuple is indexed with it downstream.
    # `nonmissingtype` lets categorical columns that contain missing values be detected.
    iscategorical = Bool[
        nonmissingtype(T) <: CategoricalArrays.CategoricalValue for T in schema.types
    ]
    return _remove_collinear(columns, datakeys, method, !silent, iscategorical)
end
|
||
# Dispatch on the collinearity method.
# `Vif` and `Pearson` only screen the non-categorical variables; the categorical ones are
# appended, untouched, after the variables that survive the screening.
# `Gvif` screens every variable, categorical or not.

function _remove_collinear(data, datakeys, v::Vif, verbose, iscategorical)
    numeric_keys = datakeys[.!iscategorical]
    categorical_keys = datakeys[iscategorical]
    kept = _vifstep(data, numeric_keys, v.threshold, verbose, StatsAPI.vif)
    return (kept..., categorical_keys...)
end

function _remove_collinear(data, datakeys, v::Gvif, verbose, iscategorical)
    return _vifstep(data, datakeys, v.threshold, verbose, StatsAPI.gvif)
end

function _remove_collinear(data, datakeys, p::Pearson, verbose, iscategorical)
    numeric_keys = datakeys[.!iscategorical]
    categorical_keys = datakeys[iscategorical]
    kept = _pearsonstep(data, numeric_keys, p.threshold, verbose)
    return (kept..., categorical_keys...)
end
|
||
"""
    _vifstep(data, datakeys, threshold, verbose, vifmethod)

Repeatedly fit a linear model on the variables in `datakeys`, score them with `vifmethod`
(e.g. `StatsAPI.vif` or `StatsAPI.gvif`) and drop the variable with the highest score,
until the highest score no longer exceeds `threshold`.

Returns the remaining entries of `datakeys`. If `verbose` is `true`, progress is logged
with `@info`. Throws an error if the highest score is `NaN`.
"""
function _vifstep(data, datakeys, threshold, verbose, vifmethod)
    # A variance inflation factor needs at least two predictors, so stop as soon as fewer
    # than two variables remain instead of fitting a model that cannot be scored.
    while length(datakeys) > 1
        # TODO: make a custom implementation of gvif that works without the useless model
        formula = StatsModels.FormulaTerm(StatsModels.term(1), StatsModels.term.(datakeys))
        model = GLM.lm(formula, data)
        highest_vif, worst_idx = findmax(vifmethod(model))
        worst_key = datakeys[worst_idx]

        if verbose
            @info "$(worst_key) has highest VIF score: $(highest_vif)"
        end
        if isnan(highest_vif)
            error("Cannot compute VIF. Possible some variables have perfect collinearity")
        end

        # Every remaining variable is at or below the threshold: done.
        highest_vif > threshold || break

        if verbose
            @info "Removing $(worst_key), $(length(datakeys)-1) variables remaining"
        end
        datakeys = datakeys[setdiff(eachindex(datakeys), worst_idx)]
    end

    if verbose
        @info "All variables are below threshold, returning remaining variables"
    end

    return datakeys
end
|
||
"""
    _pearsonstep(data, datakeys, threshold, verbose)

Find all pairs of variables in `datakeys` whose absolute Pearson correlation exceeds
`threshold`, then greedily remove the variable that occurs in the most remaining pairs
until no such pair is left.

Ties are broken deterministically in favour of keeping earlier columns: among the
variables that occur in the most pairs, the one that comes last in `datakeys` is removed.

Returns a tuple with the names of the remaining variables. If `verbose` is `true`, the
correlated pairs are logged with `@info`.
"""
function _pearsonstep(data, datakeys, threshold, verbose)
    # A correlation needs two variables; with fewer there is nothing to remove.
    length(datakeys) < 2 && return Tuple(datakeys)

    columns = Tables.columntable(data)[datakeys]
    varnames = keys(columns)
    nvars = length(varnames)
    abscor = abs.(Statistics.cor(reduce(hcat, values(columns))))

    # Visit each unordered pair once (strict lower triangle), in column-major order.
    correlated_pairs = [
        (i, j) for j in 1:nvars for i in (j + 1):nvars if abscor[i, j] > threshold
    ]

    if verbose
        @info "Found $(length(correlated_pairs)) correlated variable pairs"
        for (i, j) in correlated_pairs
            @info "$(varnames[i]) ~ $(varnames[j]): $(abscor[i, j])"
        end
    end

    vars_to_remove = Int[]
    while !isempty(correlated_pairs)
        # Count how often each variable occurs in the remaining correlated pairs.
        occurrences = zeros(Int, nvars)
        for (i, j) in correlated_pairs
            occurrences[i] += 1
            occurrences[j] += 1
        end
        # `findlast` makes the tie-break explicit: the latest column is removed.
        to_remove = findlast(==(maximum(occurrences)), occurrences)
        push!(vars_to_remove, to_remove)
        filter!(pair -> !(to_remove in pair), correlated_pairs)
    end

    return varnames[setdiff(1:nvars, vars_to_remove)]
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,40 @@ | ||
using SpeciesDistributionModels | ||
using SpeciesDistributionModels, CategoricalArrays | ||
import SpeciesDistributionModels as SDM | ||
|
||
import GLM: Distributions | ||
|
||
# Mock data shared by all test sets: two continuous variables (`a`, `b`) and one
# categorical variable (`c`), with `n` rows each for background and presence points.
n = 500
backgrounddata = (
    a = rand(n),
    b = rand(n),
    c = categorical(rand(0:3, n)),
)
presencedata = (
    a = rand(n),
    b = rand(n) .^ 2,
    c = categorical(rand(Distributions.Binomial(3, 0.5), n)),
)
|
||
using Test | ||
|
||
@testset "SpeciesDistributionModels.jl" begin
    # Smoke test: fit an ensemble on the mock data and evaluate it.
    # There are no assertions; the test set only fails if one of the calls throws.
    candidate_models = [
        SDM.random_forest(),
        SDM.random_forest(; max_depth = 3),
        SDM.linear_model(),
        SDM.boosted_regression_tree(),
    ]
    resampling = [SDM.MLJBase.CV(; shuffle = true, nfolds = 5)]

    fitted_ensemble = sdm(presencedata, backgrounddata, candidate_models, resampling)

    SDM.evaluate(fitted_ensemble)
end
|
||
@testset "collinearity" begin
    # Mock data with a collinearity problem: `d` and `e` are both `a` plus noise,
    # and `f` is an extra categorical variable.
    data_with_collinearity = merge(
        backgrounddata,
        (;
            d = backgrounddata.a .+ rand(n),
            e = backgrounddata.a .+ rand(n),
            f = categorical(rand(Distributions.Binomial(3, 0.5), n)),
        ),
    )

    rm_col_gvif = remove_collinear(
        data_with_collinearity; method = SDM.Gvif(; threshold = 2.), silent = true
    )
    rm_col_vif = remove_collinear(
        data_with_collinearity; method = SDM.Vif(; threshold = 2.), silent = true
    )
    rm_col_pearson = remove_collinear(
        data_with_collinearity; method = SDM.Pearson(; threshold = 0.65), silent = true
    )
    # Gvif keeps the original column order; Vif and Pearson list categorical columns last.
    @test rm_col_gvif == (:b, :c, :d, :e, :f)
    @test rm_col_vif == (:b, :d, :e, :c, :f)
    @test rm_col_pearson == (:b, :d, :e, :c, :f)

    data_with_perfect_collinearity = (a = [1,2,3], b = [1,2,3])
    # Use a keyword that `remove_collinear` actually accepts, so that the exception comes
    # from the VIF computation and not from a MethodError on an unsupported keyword.
    @test_throws Exception remove_collinear(
        data_with_perfect_collinearity; method = SDM.Gvif(; threshold = 2.), silent = true
    )
    @test remove_collinear(
        data_with_perfect_collinearity;
        method = SDM.Pearson(; threshold = 0.65),
        silent = true,
    ) == (:a, )
end