From 0bc48db945046b4099ba7db5724423f895106e94 Mon Sep 17 00:00:00 2001 From: Tiem van der Deure Date: Thu, 4 Jan 2024 10:23:32 +0100 Subject: [PATCH] change type structure (#5) * first draft for evaluate function * set up machine and group types * implement select and data for ensemble and groups * implemented hierarchies in evaluate * add some convenience functions * add scientifictypesbase as dep * predict for any Tables.jl compatible table * add statisticalmeasures * import pdf from MLJBase * uncommit playground --- Project.toml | 3 + playground.jl | 21 --- src/SpeciesDistributionModels.jl | 9 +- src/ensemble.jl | 244 ++++++++++++++++++++----------- src/evaluate.jl | 186 +++++++++++++++++++++++ src/predict.jl | 33 +++-- 6 files changed, 377 insertions(+), 119 deletions(-) delete mode 100644 playground.jl create mode 100644 src/evaluate.jl diff --git a/Project.toml b/Project.toml index 237d608..effac33 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "1.0.0-DEV" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" +CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e" DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" @@ -16,8 +17,10 @@ MLJGLMInterface = "caf8df21-4939-456d-ac9c-5fefbfb04c0c" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Rasters = "a3a2b9e3-a471-40c9-b274-f788e487c689" +ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161" Shapley = "855ca7ad-a6ef-4de2-9ca8-726fe2a39065" StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541" +StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" diff --git a/playground.jl b/playground.jl deleted file mode 100644 index 82d8be9..0000000 --- a/playground.jl +++ /dev/null @@ -1,21 +0,0 @@ -using Revise - -using SpeciesDistributionModels, GLMakie -import SpeciesDistributionModels as SDM - -presences = (a = rand(200), b = rand(200), c = rand(200)) -background = (a = rand(200), b = sqrt.(rand(200)), c = rand(200).^2) - -models = [ - SDM.linear_model(), - SDM.boosted_regression_tree(), - SDM.random_forest(), - SDM.random_forest(; n_trees = 10, max_depth = 3)] - -ensemble = sdm(presences, background, models, [SDM.MLJBase.CV(; shuffle = true)]) - -interactive_evaluation(ensemble) - -shapley = shap(ensemble; n_samples = 5) - -interactive_response_curves(shapley) diff --git a/src/SpeciesDistributionModels.jl b/src/SpeciesDistributionModels.jl index 638da25..1adda08 100644 --- a/src/SpeciesDistributionModels.jl +++ b/src/SpeciesDistributionModels.jl @@ -1,13 +1,16 @@ module SpeciesDistributionModels import Tables, StatsBase, Statistics -import MLJBase, StatisticalMeasures, GLM, PrettyTables, Rasters, EvoTrees, DecisionTree, Makie, Shapley, Loess +import MLJBase, StatisticalMeasures, StatisticalMeasuresBase, ScientificTypesBase, CategoricalArrays +import GLM, PrettyTables, Rasters, EvoTrees, DecisionTree, Makie, Shapley, Loess + +using MLJBase: pdf using Rasters: Raster, RasterStack using Makie: Toggle, Label, scatter!, lines!, Axis, Figure, GridLayout, lift -import CategoricalArrays.CategoricalArray +using ScientificTypesBase: Continuous, OrderedFactor, Multiclass, Count export SDMensemble, predict, sdm, select, machines, machine_keys, shap, interactive_evaluation, interactive_response_curves @@ -16,6 +19,8 @@ include("models.jl") include("ensemble.jl") include("predict.jl") include("explain.jl") +include("evaluate.jl") include("plots.jl") + end diff --git a/src/ensemble.jl b/src/ensemble.jl index 3b9b23f..b179d77 100644 --- a/src/ensemble.jl +++ b/src/ensemble.jl @@ -1,78 +1,139 @@ -struct SDMensemble <: AbstractVector{NamedTuple} - trained_models::Vector{<:NamedTuple} # Contains the trained models - predictors::Vector{Symbol} # Vector of names of predictor variables # Better tuple of vector??? - n_presences::Int - n_absences::Int - models::NamedTuple - resamplers::NamedTuple - data - #folds::Vector{<:Vector} +# Machines have 1 machine, plus metadata +struct SDMmachine + machine + predictors::NTuple{<:Any, Symbol} + fold + train_rows + test_rows end +# Groups have multiple machines with identical model and resampler +struct SDMgroup <: AbstractVector{SDMmachine} + sdm_machines::Vector{SDMmachine} + model + resampler + model_name + resampler_name +end + +# Ensembles have multiple groups with potentially different models and different resamplers, but identical data +struct SDMensemble <: AbstractVector{SDMgroup} + groups::Vector{SDMgroup} # Contains the trained models +end + +SDMgroupOrEnsemble = Union{SDMgroup, SDMensemble} + +# Rather than storing it in ensemble, access it like this? It looks like args stores exactly the data as it is given to the machine +data(mach::SDMmachine) = (predictor = mach.machine.args[1].data, response = mach.machine.args[2].data) +data(group::SDMgroup) = data(group[1]) +data(ensemble::SDMensemble) = data(ensemble[1]) + + + #Base.getproperty(ensemble::SDMensemble, key::Symbol) = getproperty.(ensemble, key) # getproperty directly into models??? -Base.getindex(ensemble::SDMensemble, i) = ensemble.trained_models[i] -Base.size(ensemble::SDMensemble) = size(ensemble.trained_models) +Base.getindex(ensemble::SDMensemble, i) = ensemble.groups[i] +Base.getindex(group::SDMgroup, i) = group.sdm_machines[i] +Base.size(ensemble::SDMensemble) = size(ensemble.groups) +Base.size(group::SDMgroup) = size(group.sdm_machines) +n_machines(ensemble::SDMensemble) = mapreduce(group -> length(group.sdm_machines), +, ensemble) -#trained_models(ensemble) = getfield(ensemble, :trained_models) -machines(ensemble::SDMensemble) = getfield.(ensemble.trained_models, :machine) -machine_keys(ensemble::SDMensemble) = getfield.(ensemble.trained_models, :machine_key) +machines(group::SDMgroup) = map(m -> m.machine, group) -# Function to convienently select some models from the ensemble -function select(ensemble::SDMensemble, indices::Vector{Int}) - trained_models = ensemble.trained_models[indices] +sdm_machines(group::SDMgroup) = group.sdm_machines - model_keys = unique(getfield.(trained_models, :model_key)) - resampler_keys = unique(getfield.(trained_models, :resampler_key)) +# machine_key generates a unique key for a machine +machine_keys(group::SDMgroup) = ["$(group.model_name)_$(group.resampler_name)_$(m.fold)" for m in group] - return SDMensemble( - trained_models, - ensemble.predictors, - ensemble.n_presences, - ensemble.n_absences, - ensemble.models[model_keys], - ensemble.resamplers[resampler_keys], - ensemble.data - ) +# A bunch of functions are applied to an ensemble by applying to each group and reducing with vcat +for f in (:machines, :machine_keys, :sdm_machines) + @eval ($f)(ensemble::SDMensemble) = mapreduce(group -> ($f)(group), vcat, ensemble) end -function select( - ensemble::SDMensemble; - models = keys(ensemble.models), - resamplers = keys(ensemble.resamplers), - folds = 1:1000) - - select( - ensemble, - findall( - m -> m.model_key in models && m.resampler_key in resamplers && m.fold in folds, - ensemble.trained_models +## Select methods + +# Function to convienently select some models from groups or ensembles +function select(group::SDMgroup, machine_indices::Vector{<:Int}) + if length(machine_indices) == 0 + return nothing + else + return SDMgroup( + group.sdm_machines[machine_indices], + group.model, + group.resampler, + group.model_name, + group.resampler_name ) - ) + end end -function Base.show(io::IO, mime::MIME"text/plain", ensemble::SDMensemble) - println(io, "SDMensemble with $(Base.length(ensemble)) models") +function select(ensemble::SDMensemble, group_indices::Vector{<:Int}) + if length(group_indices) == 0 + return nothing + else + return SDMensemble( + ensemble.groups[group_indices], + ) + end +end - println(io, "Model performance:") +function select(ensemble::SDMensemble; machines::Vector{<:Int}) + if length(machines) == 0 + return nothing + else + Statistics.maximum(machines) <= n_machines(ensemble) || throw(BoundsError(ensemble, machines)) + + j = 0 + groups = SDMgroup[] + groups = map(ensemble) do group + i = j + j = i + length(group) + select(group, filter(x -> x > i && x <= j, machines) .- i) + end + + return SDMensemble( + filter(!Base.isnothing, groups), + ) + end +end - aucs = auc_by_model(ensemble) - data = hcat(collect(keys(ensemble.models)), aucs) - header = (["model_key", "auc"]) - PrettyTables.pretty_table(io, data; header = header) +## Show methods +function Base.show(io::IO, mime::MIME"text/plain", ensemble::SDMensemble) + n_presence = sum(data(ensemble).response .== true) + n_absence = sum(data(ensemble).response .== false) + + sc = MLJBase.schema(data(ensemble).predictor) + sci = sc.scitypes + nam = sc.names + + println(io, "SDMensemble with $(n_machines(ensemble)) machines across $(Base.length(ensemble)) groups") + println(io, "Occurence data: Presence-Absence with $n_presence presences and $n_absence absences") + println(io, "Predictors: $(join(["$key ($scitype)" for (key, scitype) in zip(nam, sci)], ", "))") + + model_names = getfield.(ensemble.groups, :model_name) + resampler_names = getfield.(ensemble.groups, :resampler_name) + n_models = Base.length.(ensemble.groups) + table_cols = hcat(model_names, resampler_names, n_models) + header = (["model", "resampler", "machines"]) + PrettyTables.pretty_table(io, table_cols; header = header) + +end - #println(io, "model keys: ", keys(ensemble.models)) - #show(io, mime, Tables.rows(ensemble)) +function Base.show(io::IO, mime::MIME"text/plain", group::SDMgroup) + println(io, "SDMgroup with $(Base.length(group)) machines") + println(io, "Model $(group.model_name) and resampler $(group.resampler_name)") end + +## Table interface !! Is this still valid? Tables.istable(::Type{SDMensemble}) = true -Tables.schema(ensemble::SDMensemble) = Tables.schema(ensemble.trained_models[1]) -Tables.rows(ensemble::SDMensemble) = Tables.rows(ensemble.trained_models) -Tables.columns(ensemble::SDMensemble) = Tables.columns(ensemble.trained_models) +Tables.schema(ensemble::SDMensemble) = Tables.schema(ensemble.groups[1]) +Tables.rows(ensemble::SDMensemble) = Tables.rows(ensemble.groups) +Tables.columns(ensemble::SDMensemble) = Tables.columns(ensemble.groups) # Turns models into a NamedTuple with unique keys -function givenames(models::Vector) +function _givenames(models::Vector) names = map(models) do model replace(MLJBase.name(model), r"Classifier$"=>"") end @@ -84,10 +145,29 @@ function givenames(models::Vector) return NamedTuple{Tuple(Symbol.(names))}(models) end -function auc_by_model(ensemble) - mapreduce(vcat, keys(ensemble.models)) do key - Statistics.mean([model.auc for model in ensemble.trained_models if model.model_key == key]) +function _fit_sdm_model(predictor_values::NamedTuple, response_values, model, fold, train, test, verbosity) + mach = MLJBase.machine(model, predictor_values, response_values) + MLJBase.fit!(mach; rows = train, verbosity = verbosity) + return SDMmachine(mach, keys(predictor_values), fold, train, test) +end + +function _fit_sdm_group( + predictor_values::NamedTuple, + response_values, + model, + resampler, + folds, + model_name, + resampler_name, + verbosity + ) + + machines = map(enumerate(folds)) do (f, (train, test)) + _fit_sdm_model(predictor_values, response_values, model, f, train, test, verbosity) end + + return SDMgroup(machines, model, resampler, model_name, resampler_name) + end function sdm( @@ -95,11 +175,10 @@ function sdm( absence, models, resamplers; - var_keys::Vector{Symbol} = [key for key in Tables.schema(absence).names if in(key, Tables.schema(presences).names)], + var_keys::Vector{Symbol} = intersect(Tables.schema(absence).names, Tables.schema(presences).names), scitypes::Vector{DataType} = [MLJBase.scitype(Tables.schema(presences).types) for key in var_keys], verbosity::Int = 0 - ) - +) @assert Tables.istable(presences) && Tables.istable(absence) n_presence = Base.length(Tables.rows(presences)) ## @@ -108,37 +187,32 @@ function sdm( # merge presence and absence data into one namedtuple of vectors predictor_values = NamedTuple{Tuple(var_keys)}([[Tables.columns(absence)[var]; Tables.columns(presences)[var]] for var in var_keys]) - response_values = CategoricalArray( + response_values = CategoricalArrays.categorical( [falses(n_absence); trues(n_presence)]; - levels = [false, true], ordered = true) - - models_ = givenames(models) - resamplers_ = givenames(resamplers) + levels = [false, true], ordered = false + ) - trained_models = mapreduce(vcat, keys(resamplers_)) do resampler_key + models_ = _givenames(models) + resamplers_ = _givenames(resamplers) + + sdm_groups = mapreduce(vcat, collect(keys(resamplers_))) do resampler_key resampler = resamplers_[resampler_key] folds = MLJBase.train_test_pairs(resampler, 1:n_total, response_values) ## get indices - mapreduce(vcat, keys(models_)) do model_key + map(collect(keys(models_))) do model_key model = models_[model_key] - map(enumerate(folds)) do (f, (train, test)) - mach = MLJBase.machine(model, predictor_values, response_values) - MLJBase.fit!(mach; rows = train, verbosity = verbosity) - y_hat = MLJBase.predict(mach, rows = test) - auc = StatisticalMeasures.auc(y_hat, response_values[test]) - machine_key = Symbol(String(model_key) * "_" * String(resampler_key) * "_" * string(f)) - return (; machine = mach, auc = auc, model_key, resampler_key, fold = f, machine_key, train, test) - # Probably make a Type for this - end + _fit_sdm_group( + predictor_values, + response_values, + model, + resampler, + folds, + model_key, + resampler_key, + verbosity + ) end end - return SDMensemble( - trained_models, - var_keys, - n_presence, - n_absence, - models_, - resamplers_, - (predictor = predictor_values, response = response_values) - ) -end + return SDMensemble(sdm_groups) + +end \ No newline at end of file diff --git a/src/evaluate.jl b/src/evaluate.jl new file mode 100644 index 0000000..067b079 --- /dev/null +++ b/src/evaluate.jl @@ -0,0 +1,186 @@ +struct SDMmachineEvaluation + sdm_machine::SDMmachine + measures + results +end + +struct SDMgroupEvaluation <: AbstractVector{SDMmachineEvaluation} + machine_evaluations::Vector{SDMmachineEvaluation} + group::SDMgroup + measures + results +end + +struct SDMensembleEvaluation <: AbstractVector{SDMgroupEvaluation} + group_evaluations::Vector{SDMgroupEvaluation} + ensemble::SDMensemble + measures + results +end + +SDMgroupOrEnsembleEvaluation = Union{SDMgroupEvaluation, SDMensembleEvaluation} + +ScoreType = NamedTuple{(:score, :threshold), Tuple{Float64, Union{Missing, Float64}}} + +Base.getindex(ensemble::SDMensembleEvaluation, i) = ensemble.group_evaluations[i] +Base.getindex(group::SDMgroupEvaluation, i) = group.machine_evaluations[i] + +Base.size(ensemble::SDMensembleEvaluation) = Base.size(ensemble.group_evaluations) +Base.size(group::SDMgroupEvaluation) = Base.size(group.machine_evaluations) + +function machine_evaluations(groupeval::SDMgroupEvaluation; mean = false) + map((:train, :test)) do set + map(keys(groupeval.measures)) do key + r = map(groupeval) do e + e.results[set][key].score + end + + if mean + Statistics.mean(r) + else + r + end + + end |> NamedTuple{keys(groupeval.measures)} + end |> NamedTuple{(:train, :test)} +end + +function machine_evaluations(ensembleeval::SDMensembleEvaluation; mean = false) + map((:train, :test)) do set + map(keys(ensembleeval.measures)) do key + r = mapreduce(vcat, ensembleeval) do groupeval + map(groupeval) do e + e.results[set][key].score + end + end + + if mean + Statistics.mean(r) + else + r + end + + end |> NamedTuple{keys(ensembleeval.measures)} + end |> NamedTuple{(:train, :test)} +end + +## Show methods +function Base.show(io::IO, mime::MIME"text/plain", evaluation::SDMmachineEvaluation) + println(io, "SDMmachineEvaluation") + + measures = collect(keys(evaluation.measures)) + train_scores = round.(getfield.(collect(evaluation.results.train), :score); digits = 2) + test_scores = round.(getfield.(collect(evaluation.results.test), :score); digits = 2) + + table_cols = hcat(measures, train_scores, test_scores) + header = (["measure", "train_score", "test_score"]) + PrettyTables.pretty_table(io, table_cols; header = header) +end + +function Base.show(io::IO, mime::MIME"text/plain", evaluation::SDMgroupOrEnsembleEvaluation) + measures = collect(keys(evaluation.measures)) + train_scores, test_scores = machine_evaluations(evaluation, mean = true) + + group_scores = map(measures) do key + evaluation.results[key].score + end + + println(io, "$(typeof(evaluation)) with $(length(measures)) performance measures") + + table_cols = hcat(measures, collect(group_scores), collect(train_scores), collect(test_scores)) + header = (["measure", "performance of avg", "avg. train performance", "avg. test performance"]) + PrettyTables.pretty_table(io, table_cols; header = header) + +end + +## Core evuator +# internal method to get a vector of scores from y_hats, ys, and a namedtuple of measures +function _evaluate(y_hat, y, measures) + map(measures) do measure + # If the measures is threshold independent + if StatisticalMeasuresBase.kind_of_proxy(measure) == StatisticalMeasures.LearnAPI.Distribution() + return ScoreType((score = measure(y_hat, y), threshold = missing)) + else # else the measure uses thresholds + # first get all possible thresholded values + scores = pdf.(y_hat, true) + thresholds = unique(scores) + thresholded_scores = map(t -> CategoricalArrays.categorical(scores .>= t, levels = [false, true]), thresholds) + + # find the max value and corresponding threshold for measure + all_scores = measure.(thresholded_scores, Ref(y)) + max_score = findmax(all_scores) + return ScoreType((score = max_score[1], threshold = thresholds[max_score[2]])) + end + end +end + +# Evaluate a single SDMmachine +function evaluate( + sdm_machine::SDMmachine; + measures = (; + auc = StatisticalMeasures.auc, + log_loss = StatisticalMeasures.log_loss, + kappa = StatisticalMeasures.kappa + ) +) + results = map((train = sdm_machine.train_rows, test = sdm_machine.test_rows)) do rows + y_hat = MLJBase.predict(sdm_machine.machine, rows = rows) + y = data(sdm_machine).response[rows] + _evaluate(y_hat, y, measures) + end + + return SDMmachineEvaluation(sdm_machine, measures, results) +end + +# Evaluate a group +function evaluate( + group::SDMgroup; + measures = (; + auc = StatisticalMeasures.auc, + log_loss = StatisticalMeasures.log_loss, + kappa = StatisticalMeasures.kappa + ) +) + machine_evaluations = map(m -> (evaluate(m; measures = measures)), group) + + # average group prediction + y_hat = mapreduce(+, machines(group)) do mach + MLJBase.predict(mach) # MLJBase.predict because StatisticalMeasures expect UniverateFiniteArrays. + end / length(group) + + y = data(group).response + group_evaluation = _evaluate(y_hat, y, measures) + + return SDMgroupEvaluation( + machine_evaluations, + group, + measures, + group_evaluation + ) +end + +function evaluate( + ensemble::SDMensemble, + measures = (; + auc = StatisticalMeasures.auc, + log_loss = StatisticalMeasures.log_loss, + kappa = StatisticalMeasures.kappa) + ) + + group_evaluations = map(m -> (evaluate(m; measures = measures)), ensemble) + + # average ensemble prediction + y_hat = mapreduce(+, machines(ensemble)) do mach + MLJBase.predict(mach) + end / n_machines(ensemble) + + y = data(ensemble).response + ensemble_evaluation = _evaluate(y_hat, y, measures) + + return SDMensembleEvaluation( + group_evaluations, + ensemble, + measures, + ensemble_evaluation + ) +end \ No newline at end of file diff --git a/src/predict.jl b/src/predict.jl index 88b6b78..5763ba9 100644 --- a/src/predict.jl +++ b/src/predict.jl @@ -1,13 +1,22 @@ -function predict(ensemble::SDMensemble, data::NamedTuple) - # take data based on variable names - data_ = data[ensemble.predictors] +function _unsafe_predict(mach::SDMmachine, data) + CategoricalDistributions.pdf.(MLJBase.predict(mach.machine, data), true) +end - map(machines(ensemble)) do mach - # predict each machine and get the probability of true - Float64.(MLJBase.predict(mach, data_).prob_given_ref[2]) + +function predict(mach::SDMmachine, d) + data_ = Tables.columntable(d)[keys(data(mach).predictor)] + _unsafe_predict(mach, data_) +end + +function predict(s::SDMgroupOrEnsemble, d) + data_ = Tables.columntable(d)[keys(data(s).predictor)] + + mapreduce(hcat, sdm_machines(s)) do mach + _unsafe_predict(mach, data_) end end +#= on pause until RasterStacks are compatible with Tables.jl function predict(ensemble::SDMensemble, data::Rasters.RasterStack) preds = Tuple(ensemble.predictors) @@ -38,13 +47,15 @@ function predict(ensemble::SDMensemble, data::Rasters.RasterStack) return outraster end +=# -function predict(ensemble::SDMensemble, rows::Symbol) - y_hat_y = map(ensemble.trained_models) do model - y_hat = MLJBase.predict(model.machine, rows = model[rows]) - y = ensemble.data.response[model[rows]] +# inernal convenience function to predict just train or test rows for each machine +function _predict(s::SDMgroupOrEnsemble, rows::Symbol) + y_hat_y = map(sdm_machines(s)) do sdm_mach + y_hat = MLJBase.predict(sdm_mach.machine, rows = sdm_mach[rows]) + y = data(s).response[sdm_mach[rows]] return (;y_hat, y) end - return NamedTuple{Tuple(machine_keys(ensemble))}(y_hat_y) + return (y_hat_y) end \ No newline at end of file