Skip to content

Commit

Permalink
change type structure (#5)
Browse files Browse the repository at this point in the history
* first draft for evaluate function

* set up machine and group types

* implement select and data for ensemble and groups

* implemented hierarchies in evaluate

* add some convenience functions

* add scientifictypesbase as dep

* predict for any Tables.jl compatible table

* add statisticalmeasures

* import pdf from MLJBase

* uncommit playground
  • Loading branch information
tiemvanderdeure authored Jan 4, 2024
1 parent 366f53a commit 0bc48db
Show file tree
Hide file tree
Showing 6 changed files with 377 additions and 119 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "1.0.0-DEV"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e"
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
Expand All @@ -16,8 +17,10 @@ MLJGLMInterface = "caf8df21-4939-456d-ac9c-5fefbfb04c0c"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Rasters = "a3a2b9e3-a471-40c9-b274-f788e487c689"
ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
Shapley = "855ca7ad-a6ef-4de2-9ca8-726fe2a39065"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Expand Down
21 changes: 0 additions & 21 deletions playground.jl

This file was deleted.

9 changes: 7 additions & 2 deletions src/SpeciesDistributionModels.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
module SpeciesDistributionModels

import Tables, StatsBase, Statistics
import MLJBase, StatisticalMeasures, GLM, PrettyTables, Rasters, EvoTrees, DecisionTree, Makie, Shapley, Loess
import MLJBase, StatisticalMeasures, StatisticalMeasuresBase, ScientificTypesBase, CategoricalArrays
import GLM, PrettyTables, Rasters, EvoTrees, DecisionTree, Makie, Shapley, Loess

using MLJBase: pdf

using Rasters: Raster, RasterStack

using Makie: Toggle, Label, scatter!, lines!, Axis, Figure, GridLayout, lift

import CategoricalArrays.CategoricalArray
using ScientificTypesBase: Continuous, OrderedFactor, Multiclass, Count

export SDMensemble, predict, sdm, select, machines, machine_keys, shap,
interactive_evaluation, interactive_response_curves
Expand All @@ -16,6 +19,8 @@ include("models.jl")
include("ensemble.jl")
include("predict.jl")
include("explain.jl")
include("evaluate.jl")
include("plots.jl")


end
244 changes: 159 additions & 85 deletions src/ensemble.jl
Original file line number Diff line number Diff line change
@@ -1,78 +1,139 @@
struct SDMensemble <: AbstractVector{NamedTuple}
trained_models::Vector{<:NamedTuple} # Contains the trained models
predictors::Vector{Symbol} # Vector of names of predictor variables # Better tuple of vector???
n_presences::Int
n_absences::Int
models::NamedTuple
resamplers::NamedTuple
data
#folds::Vector{<:Vector}
# An SDMmachine holds one machine, plus metadata describing how it was fitted.
# Instances are built by `_fit_sdm_model`, which passes the fitted machine, the keys of
# the predictor NamedTuple, the fold number and the train/test rows of that fold.
struct SDMmachine
# The wrapped machine; `data` reads `machine.args[1].data` and `machine.args[2].data` from it
machine
# Names of the predictor variables, as a tuple of Symbols
predictors::NTuple{<:Any, Symbol}
# Position of the fold among the resampler's train/test pairs (counted by `enumerate`)
fold
# Rows handed to `fit!` through `rows = train`
train_rows
# Remaining rows of the same train/test pair; not used for fitting
# NOTE(review): presumably consumed by the evaluation code - confirm in evaluate.jl
test_rows
end

# Groups have multiple machines with identical model and resampler.
# A group behaves as a vector of SDMmachine: indexing and `size` are forwarded to `sdm_machines`.
struct SDMgroup <: AbstractVector{SDMmachine}
# One SDMmachine per fold of the resampler
sdm_machines::Vector{SDMmachine}
# The model shared by every machine in the group
model
# The resampler that produced the folds
resampler
# Key identifying the model; used in `machine_keys` and in the `show` output
model_name
# Key identifying the resampler; used in `machine_keys` and in the `show` output
resampler_name
end

# Ensembles have multiple groups with potentially different models and different resamplers, but identical data.
# An ensemble behaves as a vector of SDMgroup: indexing and `size` are forwarded to `groups`.
struct SDMensemble <: AbstractVector{SDMgroup}
groups::Vector{SDMgroup} # The groups of trained machines, one per model/resampler combination
end

# Alias for methods that accept either a single group or a whole ensemble.
# Declared `const` so the alias is a constant binding rather than an untyped, reassignable global.
const SDMgroupOrEnsemble = Union{SDMgroup, SDMensemble}

# The data is not stored on the ensemble; it is read back from the wrapped machine.
# `args` appears to hold the data exactly as it was given to the machine:
# the first argument carries the predictors, the second the response.
function data(mach::SDMmachine)
    predictor_source, response_source = mach.machine.args[1], mach.machine.args[2]
    return (predictor = predictor_source.data, response = response_source.data)
end

# Groups and ensembles share their data, so the first element is representative.
function data(group::SDMgroup)
    return data(group[1])
end

function data(ensemble::SDMensemble)
    return data(ensemble[1])
end



#Base.getproperty(ensemble::SDMensemble, key::Symbol) = getproperty.(ensemble, key)
# getproperty directly into models???

Base.getindex(ensemble::SDMensemble, i) = ensemble.trained_models[i]
Base.size(ensemble::SDMensemble) = size(ensemble.trained_models)
# AbstractVector interface: an ensemble indexes into its groups, a group into its machines.
Base.getindex(ensemble::SDMensemble, i) = ensemble.groups[i]
Base.getindex(group::SDMgroup, i) = group.sdm_machines[i]
Base.size(ensemble::SDMensemble) = size(ensemble.groups)
Base.size(group::SDMgroup) = size(group.sdm_machines)

# Total number of machines across all groups of the ensemble.
# `init = 0` makes an ensemble without groups count as 0 machines instead of
# throwing the "reducing over an empty collection" error.
n_machines(ensemble::SDMensemble) = sum(length, ensemble.groups; init = 0)

#trained_models(ensemble) = getfield(ensemble, :trained_models)
machines(ensemble::SDMensemble) = getfield.(ensemble.trained_models, :machine)
machine_keys(ensemble::SDMensemble) = getfield.(ensemble.trained_models, :machine_key)
# Collect the wrapped machine of every SDMmachine in the group
machines(group::SDMgroup) = [sdm_machine.machine for sdm_machine in group]

# Function to conveniently select some models from the ensemble
function select(ensemble::SDMensemble, indices::Vector{Int})
trained_models = ensemble.trained_models[indices]
# Return the vector of SDMmachines stored in the group (the field itself, not a copy)
sdm_machines(group::SDMgroup) = group.sdm_machines

model_keys = unique(getfield.(trained_models, :model_key))
resampler_keys = unique(getfield.(trained_models, :resampler_key))
# machine_keys builds one key per machine in the group, of the form
# "<model_name>_<resampler_name>_<fold>", which is unique within an ensemble
# as long as each model/resampler combination occurs once
function machine_keys(group::SDMgroup)
    return [join((group.model_name, group.resampler_name, sdm_machine.fold), "_") for sdm_machine in group]
end

return SDMensemble(
trained_models,
ensemble.predictors,
ensemble.n_presences,
ensemble.n_absences,
ensemble.models[model_keys],
ensemble.resamplers[resampler_keys],
ensemble.data
)
# Ensemble-level versions of the group accessors: apply the accessor to every group
# and concatenate the per-group results with vcat
machines(ensemble::SDMensemble) = mapreduce(machines, vcat, ensemble)
machine_keys(ensemble::SDMensemble) = mapreduce(machine_keys, vcat, ensemble)
sdm_machines(ensemble::SDMensemble) = mapreduce(sdm_machines, vcat, ensemble)

function select(
ensemble::SDMensemble;
models = keys(ensemble.models),
resamplers = keys(ensemble.resamplers),
folds = 1:1000)

select(
ensemble,
findall(
m -> m.model_key in models && m.resampler_key in resamplers && m.fold in folds,
ensemble.trained_models
## Select methods

# Function to conveniently select some models from groups or ensembles
function select(group::SDMgroup, machine_indices::Vector{<:Int})
if length(machine_indices) == 0
return nothing
else
return SDMgroup(
group.sdm_machines[machine_indices],
group.model,
group.resampler,
group.model_name,
group.resampler_name
)
)
end
end

function Base.show(io::IO, mime::MIME"text/plain", ensemble::SDMensemble)
println(io, "SDMensemble with $(Base.length(ensemble)) models")
# Select whole groups from an ensemble by their position.
# Returns `nothing` when no indices are given, otherwise a new SDMensemble
# containing only the requested groups.
function select(ensemble::SDMensemble, group_indices::Vector{<:Int})
    isempty(group_indices) && return nothing
    chosen_groups = ensemble.groups[group_indices]
    return SDMensemble(chosen_groups)
end

println(io, "Model performance:")
# Select individual machines from an ensemble by their position in the flattened
# list of machines (group 1 first, then group 2, and so on).
#
# Keywords:
#   machines - indices between 1 and `n_machines(ensemble)`
#
# Returns `nothing` when `machines` is empty; otherwise a new SDMensemble holding
# only the groups that keep at least one machine, each reduced to the selected machines.
#
# Throws a BoundsError when any index lies outside `1:n_machines(ensemble)`.
# Indices below 1 are rejected too, instead of being dropped silently.
function select(ensemble::SDMensemble; machines::Vector{<:Int})
    isempty(machines) && return nothing

    all(in(1:n_machines(ensemble)), machines) || throw(BoundsError(ensemble, machines))

    selected_groups = SDMgroup[]
    offset = 0 # number of machines in the groups already visited
    for group in ensemble
        last_index = offset + length(group)
        # Translate the ensemble-wide indices that fall into this group to group-local ones
        local_indices = filter(in((offset + 1):last_index), machines) .- offset
        selected = select(group, local_indices)
        # `select` on a group returns `nothing` when no machine of that group was requested
        isnothing(selected) || push!(selected_groups, selected)
        offset = last_index
    end

    return SDMensemble(selected_groups)
end

aucs = auc_by_model(ensemble)
data = hcat(collect(keys(ensemble.models)), aucs)
header = (["model_key", "auc"])
PrettyTables.pretty_table(io, data; header = header)
## Show methods
# Multi-line display of an ensemble: machine and group counts, the number of
# presences and absences in the response, the predictors with their scitypes,
# and a table with one row per group.
function Base.show(io::IO, mime::MIME"text/plain", ensemble::SDMensemble)
    ensemble_data = data(ensemble)
    response = ensemble_data.response
    n_presence = sum(response .== true)
    n_absence = sum(response .== false)

    predictor_schema = MLJBase.schema(ensemble_data.predictor)
    predictor_descriptions = [
        "$key ($scitype)"
        for (key, scitype) in zip(predictor_schema.names, predictor_schema.scitypes)
    ]

    println(io, "SDMensemble with $(n_machines(ensemble)) machines across $(Base.length(ensemble)) groups")
    println(io, "Occurence data: Presence-Absence with $n_presence presences and $n_absence absences")
    println(io, "Predictors: $(join(predictor_descriptions, ", "))")

    # One row per group: model name, resampler name, number of machines
    table_cols = hcat(
        map(group -> group.model_name, ensemble.groups),
        map(group -> group.resampler_name, ensemble.groups),
        map(Base.length, ensemble.groups),
    )
    PrettyTables.pretty_table(io, table_cols; header = ["model", "resampler", "machines"])
end

#println(io, "model keys: ", keys(ensemble.models))
#show(io, mime, Tables.rows(ensemble))
# Two-line display of a group: number of machines, then model and resampler names
function Base.show(io::IO, mime::MIME"text/plain", group::SDMgroup)
    summary_line = "SDMgroup with $(Base.length(group)) machines"
    detail_line = "Model $(group.model_name) and resampler $(group.resampler_name)"
    println(io, summary_line)
    println(io, detail_line)
end


## Table interface !! Is this still valid?
Tables.istable(::Type{SDMensemble}) = true
Tables.schema(ensemble::SDMensemble) = Tables.schema(ensemble.trained_models[1])
Tables.rows(ensemble::SDMensemble) = Tables.rows(ensemble.trained_models)
Tables.columns(ensemble::SDMensemble) = Tables.columns(ensemble.trained_models)
Tables.schema(ensemble::SDMensemble) = Tables.schema(ensemble.groups[1])
Tables.rows(ensemble::SDMensemble) = Tables.rows(ensemble.groups)
Tables.columns(ensemble::SDMensemble) = Tables.columns(ensemble.groups)

# Turns models into a NamedTuple with unique keys
function givenames(models::Vector)
function _givenames(models::Vector)
names = map(models) do model
replace(MLJBase.name(model), r"Classifier$"=>"")
end
Expand All @@ -84,22 +145,40 @@ function givenames(models::Vector)
return NamedTuple{Tuple(Symbol.(names))}(models)
end

function auc_by_model(ensemble)
mapreduce(vcat, keys(ensemble.models)) do key
Statistics.mean([model.auc for model in ensemble.trained_models if model.model_key == key])
# Fit a single model on the training rows of one fold and wrap the result.
# The machine is bound to the full data; only `train` is used for fitting.
# Returns an SDMmachine carrying the machine, the predictor names, the fold
# number and the train/test rows.
function _fit_sdm_model(predictor_values::NamedTuple, response_values, model, fold, train, test, verbosity)
    predictor_names = keys(predictor_values)
    fold_machine = MLJBase.machine(model, predictor_values, response_values)
    MLJBase.fit!(fold_machine; rows = train, verbosity = verbosity)
    return SDMmachine(fold_machine, predictor_names, fold, train, test)
end

# Fit one model on every fold of one resampler and bundle the results.
# `folds` is iterated as (train, test) pairs; the fold number is the position
# of the pair in that iteration. Returns an SDMgroup with one machine per fold.
function _fit_sdm_group(
    predictor_values::NamedTuple,
    response_values,
    model,
    resampler,
    folds,
    model_name,
    resampler_name,
    verbosity
)
    fitted_machines = [
        _fit_sdm_model(predictor_values, response_values, model, fold, train, test, verbosity)
        for (fold, (train, test)) in enumerate(folds)
    ]
    return SDMgroup(fitted_machines, model, resampler, model_name, resampler_name)
end

function sdm(
presences,
absence,
models,
resamplers;
var_keys::Vector{Symbol} = [key for key in Tables.schema(absence).names if in(key, Tables.schema(presences).names)],
var_keys::Vector{Symbol} = intersect(Tables.schema(absence).names, Tables.schema(presences).names),
scitypes::Vector{DataType} = [MLJBase.scitype(Tables.schema(presences).types) for key in var_keys],
verbosity::Int = 0
)

)
@assert Tables.istable(presences) && Tables.istable(absence)

n_presence = Base.length(Tables.rows(presences)) ##
Expand All @@ -108,37 +187,32 @@ function sdm(

# merge presence and absence data into one namedtuple of vectors
predictor_values = NamedTuple{Tuple(var_keys)}([[Tables.columns(absence)[var]; Tables.columns(presences)[var]] for var in var_keys])
response_values = CategoricalArray(
response_values = CategoricalArrays.categorical(
[falses(n_absence); trues(n_presence)];
levels = [false, true], ordered = true)

models_ = givenames(models)
resamplers_ = givenames(resamplers)
levels = [false, true], ordered = false
)

trained_models = mapreduce(vcat, keys(resamplers_)) do resampler_key
models_ = _givenames(models)
resamplers_ = _givenames(resamplers)

sdm_groups = mapreduce(vcat, collect(keys(resamplers_))) do resampler_key
resampler = resamplers_[resampler_key]
folds = MLJBase.train_test_pairs(resampler, 1:n_total, response_values) ## get indices
mapreduce(vcat, keys(models_)) do model_key
map(collect(keys(models_))) do model_key
model = models_[model_key]
map(enumerate(folds)) do (f, (train, test))
mach = MLJBase.machine(model, predictor_values, response_values)
MLJBase.fit!(mach; rows = train, verbosity = verbosity)
y_hat = MLJBase.predict(mach, rows = test)
auc = StatisticalMeasures.auc(y_hat, response_values[test])
machine_key = Symbol(String(model_key) * "_" * String(resampler_key) * "_" * string(f))
return (; machine = mach, auc = auc, model_key, resampler_key, fold = f, machine_key, train, test)
# Probably make a Type for this
end
_fit_sdm_group(
predictor_values,
response_values,
model,
resampler,
folds,
model_key,
resampler_key,
verbosity
)
end
end

return SDMensemble(
trained_models,
var_keys,
n_presence,
n_absence,
models_,
resamplers_,
(predictor = predictor_values, response = response_values)
)
end
return SDMensemble(sdm_groups)

end
Loading

0 comments on commit 0bc48db

Please sign in to comment.