Commit c4ef89a

Merge pull request JuliaTrustworthyAI#108 from john-waczak/main

pat-alt authored Nov 13, 2023
2 parents 444dc20 + 33cae6e commit c4ef89a
Showing 7 changed files with 70 additions and 41 deletions.
4 changes: 4 additions & 0 deletions src/conformal_models/conformal_models.jl
@@ -111,6 +111,8 @@ const tested_atomic_models = Dict(
:lasso => :(@load LassoRegressor pkg = MLJLinearModels),
:evo_tree => :(@load EvoTreeRegressor pkg = EvoTrees),
:nearest_neighbor => :(@load KNNRegressor pkg = NearestNeighborModels),
:decision_tree_regressor => :(@load DecisionTreeRegressor pkg = DecisionTree),
:random_forest_regressor => :(@load RandomForestRegressor pkg = DecisionTree),
# :light_gbm => :(@load LGBMRegressor pkg = LightGBM),
# :neural_network => :(@load NeuralNetworkRegressor pkg = MLJFlux),
# :symbolic_regression => :(@load SRRegressor pkg = SymbolicRegression),
@@ -119,6 +121,8 @@ const tested_atomic_models = Dict(
:logistic => :(@load LogisticClassifier pkg = MLJLinearModels),
:evo_tree => :(@load EvoTreeClassifier pkg = EvoTrees),
:nearest_neighbor => :(@load KNNClassifier pkg = NearestNeighborModels),
:decision_tree_classifier => :(@load DecisionTreeClassifier pkg = DecisionTree),
:random_forest_classifier => :(@load RandomForestClassifier pkg = DecisionTree),
# :light_gbm => :(@load LGBMClassifier pkg = LightGBM),
# :neural_network => :(@load NeuralNetworkClassifier pkg = MLJFlux),
),
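For context, the two new dictionary entries register DecisionTree.jl's tree and forest models as tested atomic models. A minimal usage sketch, assuming the package's exported `conformal_model` constructor and the standard MLJ machine workflow (data and variable names are illustrative):

using MLJ
using ConformalPrediction

DecisionTreeRegressor = @load DecisionTreeRegressor pkg = DecisionTree
X, y = make_regression(100, 2)          # toy data
model = DecisionTreeRegressor()
conf_model = conformal_model(model)     # wraps the atomic model
mach = machine(conf_model, X, y)
fit!(mach)
predict(mach, X)                        # interval-valued predictions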
29 changes: 18 additions & 11 deletions src/conformal_models/inductive_classification.jl
@@ -1,6 +1,6 @@
"""
score(conf_model::ConformalProbabilisticSet, fitresult, X, y::Union{Nothing,AbstractArray}=nothing)
Generic score method for the [`ConformalProbabilisticSet`](@ref). It computes nonconformity scores using the heuristic function `h` and the softmax probabilities of the true class. Method is dispatched for different Conformal Probabilistic Sets and atomic models.
"""
function score(
@@ -21,10 +21,9 @@ function split_data(conf_model::ConformalProbabilisticSet, X, y)
train, calibration = partition(eachindex(y), conf_model.train_ratio)
Xtrain = selectrows(X, train)
ytrain = y[train]
Xtrain, ytrain = MMI.reformat(conf_model.model, Xtrain, ytrain)
Xcal = selectrows(X, calibration)
ycal = y[calibration]
Xcal, ycal = MMI.reformat(conf_model.model, Xcal, ycal)

return Xtrain, ytrain, Xcal, ycal
end
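The recurring change in this file (and in the regression and transductive files below) removes the `MMI.reformat` calls from `split_data` and instead splices the reformatted data directly into each `MMI.fit`/`MMI.predict` call. A minimal sketch of that data front-end pattern, assuming only MLJModelInterface's documented `reformat`, `fit` and `predict` generics; the helper names are hypothetical:

import MLJModelInterface as MMI

# Hypothetical helper mirroring the pattern adopted in this PR: reformat
# immediately before fitting rather than at data-splitting time.
function fit_reformatted(model, verbosity, X, y)
    data = MMI.reformat(model, X, y)    # model-specific representation
    return MMI.fit(model, verbosity, data...)
end

# At predict time only the features are reformatted:
predict_reformatted(model, fitresult, Xnew) =
    MMI.predict(model, fitresult, MMI.reformat(model, Xnew)...)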

@@ -50,7 +49,7 @@ end
"""
score(conf_model::SimpleInductiveClassifier, ::Type{<:Supervised}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing)
Score method for the [`SimpleInductiveClassifier`](@ref) dispatched for any `<:Supervised` model.
"""
function score(
conf_model::SimpleInductiveClassifier,
@@ -59,7 +58,9 @@ function score(
X,
y::Union{Nothing,AbstractArray}=nothing,
)
p̂ = reformat_mlj_prediction(MMI.predict(conf_model.model, fitresult, X))
p̂ = reformat_mlj_prediction(
MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, X)...)
)
L = p̂.decoder.classes
probas = pdf(p̂, L)
scores = @.(conf_model.heuristic(y, probas))
@@ -87,8 +88,10 @@ function MMI.fit(conf_model::SimpleInductiveClassifier, verbosity, X, y)
# Data Splitting:
Xtrain, ytrain, Xcal, ycal = split_data(conf_model, X, y)

# Training:
fitresult, cache, report = MMI.fit(conf_model.model, verbosity, Xtrain, ytrain)
fitresult, cache, report = MMI.fit(
conf_model.model, verbosity, MMI.reformat(conf_model.model, Xtrain, ytrain)...
)

# Nonconformity Scores:
cal_scores, scores = score(conf_model, fitresult, Xcal, ycal)
@@ -161,8 +164,10 @@ function MMI.fit(conf_model::AdaptiveInductiveClassifier, verbosity, X, y)
# Data Splitting:
Xtrain, ytrain, Xcal, ycal = split_data(conf_model, X, y)

# Training:
fitresult, cache, report = MMI.fit(conf_model.model, verbosity, Xtrain, ytrain)
fitresult, cache, report = MMI.fit(
conf_model.model, verbosity, MMI.reformat(conf_model.model, Xtrain, ytrain)...
)

# Nonconformity Scores:
cal_scores, scores = score(conf_model, fitresult, Xcal, ycal)
@@ -174,7 +179,7 @@ end
"""
score(conf_model::AdaptiveInductiveClassifier, ::Type{<:Supervised}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing)
Score method for the [`AdaptiveInductiveClassifier`](@ref) dispatched for any `<:Supervised` model.
"""
function score(
conf_model::AdaptiveInductiveClassifier,
Expand All @@ -183,7 +188,9 @@ function score(
X,
y::Union{Nothing,AbstractArray}=nothing,
)
p̂ = reformat_mlj_prediction(MMI.predict(conf_model.model, fitresult, X))
p̂ = reformat_mlj_prediction(
MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, X)...)
)
L = p̂.decoder.classes
probas = pdf(p̂, L) # compute probabilities for all classes
scores = map(Base.Iterators.product(eachrow(probas), L)) do Z
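The `score` method above builds adaptive (cumulative-probability) nonconformity scores over all classes. A hedged sketch of that idea for a single observation, not the package's exact implementation:

# Accumulate class probabilities in descending order until the true
# class is included (APS-style adaptive score; illustrative only).
function adaptive_score(probas::AbstractVector{<:Real}, y_true::Integer)
    order = sortperm(probas; rev = true)    # classes by descending probability
    ranks = invperm(order)                  # rank of each class in that order
    return sum(probas[order][1:ranks[y_true]])
end

adaptive_score([0.5, 0.3, 0.2], 2)          # 0.5 + 0.3 = 0.8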
12 changes: 7 additions & 5 deletions src/conformal_models/inductive_regression.jl
@@ -33,16 +33,18 @@ function MMI.fit(conf_model::SimpleInductiveRegressor, verbosity, X, y)
train, calibration = partition(eachindex(y), conf_model.train_ratio)
Xtrain = selectrows(X, train)
ytrain = y[train]
Xtrain, ytrain = MMI.reformat(conf_model.model, Xtrain, ytrain)
Xcal = selectrows(X, calibration)
ycal = y[calibration]
Xcal, ycal = MMI.reformat(conf_model.model, Xcal, ycal)

# Training:
fitresult, cache, report = MMI.fit(conf_model.model, verbosity, Xtrain, ytrain)
fitresult, cache, report = MMI.fit(
conf_model.model, verbosity, MMI.reformat(conf_model.model, Xtrain, ytrain)...
)

# Nonconformity Scores:
ŷ = reformat_mlj_prediction(MMI.predict(conf_model.model, fitresult, Xcal))
ŷ = reformat_mlj_prediction(
MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, Xcal)...)
)
conf_model.scores = @.(conf_model.heuristic(ycal, ŷ))

return (fitresult, cache, report)
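For orientation, the calibration scores computed above feed into the usual split-conformal interval. A sketch assuming the common absolute-residual heuristic and a finite-sample-corrected quantile (illustrative, not the package's internal API):

using Statistics: quantile

heuristic(y, ŷ) = abs(y - ŷ)                # absolute residual

function conformal_interval(ŷ_new, cal_scores, α)
    n = length(cal_scores)
    # finite-sample corrected (1 - α) empirical quantile, capped at 1
    q̂ = quantile(cal_scores, min(1.0, ceil((n + 1) * (1 - α)) / n))
    return (ŷ_new - q̂, ŷ_new + q̂)
end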
15 changes: 10 additions & 5 deletions src/conformal_models/transductive_classification.jl
@@ -22,22 +22,27 @@ For the [`NaiveClassifier`](@ref) nonconformity scores are computed in-sample as

```math
S_i^{\text{IS}} = s(X_i, Y_i) = h(\hat\mu(X_i), Y_i), \ i \in \mathcal{D}_{\text{calibration}}
```

A typical choice for the heuristic function is ``h(\hat\mu(X_i), Y_i)=1-\hat\mu(X_i)_{Y_i}`` where ``\hat\mu(X_i)_{Y_i}`` denotes the softmax output of the true class and ``\hat\mu`` denotes the model fitted on training data ``\mathcal{D}_{\text{train}}``.
"""
function MMI.fit(conf_model::NaiveClassifier, verbosity, X, y)

# Setup:
Xtrain = selectrows(X, :)
ytrain = y[:]
Xtrain, ytrain = MMI.reformat(conf_model.model, Xtrain, ytrain)

# Training:
fitresult, cache, report = MMI.fit(conf_model.model, verbosity, Xtrain, ytrain)
fitresult, cache, report = MMI.fit(
conf_model.model, verbosity, MMI.reformat(conf_model.model, Xtrain, ytrain)...
)

# Nonconformity Scores:
ŷ =
pdf.(
reformat_mlj_prediction(MMI.predict(conf_model.model, fitresult, Xtrain)),
reformat_mlj_prediction(
MMI.predict(
conf_model.model, fitresult, MMI.reformat(conf_model.model, Xtrain)...
),
),
ytrain,
)
conf_model.scores = @.(conf_model.heuristic(y, ŷ))
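The heuristic named in the docstring above reduces to one minus the predicted probability of the true class. A short sketch with illustrative values:

heuristic(p_true) = 1.0 - p_true       # h(μ̂(Xᵢ), Yᵢ) = 1 - μ̂(Xᵢ)_{Yᵢ}

probas = [0.7, 0.2, 0.1]               # softmax output for one observation
y_true = 1                             # index of the true class
heuristic(probas[y_true])              # 0.3 — small score means conforming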
28 changes: 16 additions & 12 deletions src/conformal_models/transductive_regression.jl
@@ -34,13 +34,16 @@ function MMI.fit(conf_model::NaiveRegressor, verbosity, X, y)
# Setup:
Xtrain = selectrows(X, :)
ytrain = y[:]
Xtrain, ytrain = MMI.reformat(conf_model.model, Xtrain, ytrain)

# Training:
fitresult, cache, report = MMI.fit(conf_model.model, verbosity, Xtrain, ytrain)
fitresult, cache, report = MMI.fit(
conf_model.model, verbosity, MMI.reformat(conf_model.model, Xtrain, ytrain)...
)

# Nonconformity Scores:
ŷ = reformat_mlj_prediction(MMI.predict(conf_model.model, fitresult, Xtrain))
ŷ = reformat_mlj_prediction(
MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, Xtrain)...)
)
conf_model.scores = @.(conf_model.heuristic(ytrain, ŷ))

return (fitresult, cache, report)
@@ -101,10 +104,11 @@ function MMI.fit(conf_model::JackknifeRegressor, verbosity, X, y)
# Setup:
Xtrain = selectrows(X, :)
ytrain = y[:]
Xtrain, ytrain = MMI.reformat(conf_model.model, Xtrain, ytrain)

# Training:
fitresult, cache, report = MMI.fit(conf_model.model, verbosity, Xtrain, ytrain)
fitresult, cache, report = MMI.fit(
conf_model.model, verbosity, MMI.reformat(conf_model.model, Xtrain, ytrain)...
)

# Nonconformity Scores:
T = size(y, 1)
@@ -177,7 +181,7 @@ where ``\hat\mu_{-i}(X_i)`` denotes the leave-one-out prediction for ``X_i``. In
"""
function MMI.fit(conf_model::JackknifePlusRegressor, verbosity, X, y)

# Training:
fitresult, cache, report = ([], [], [])

# Nonconformity Scores:
@@ -266,7 +270,7 @@ where ``\hat\mu_{-i}(X_i)`` denotes the leave-one-out prediction for ``X_i``. In
"""
function MMI.fit(conf_model::JackknifeMinMaxRegressor, verbosity, X, y)

# Pre-allocate:
fitresult, cache, report = ([], [], [])

# Training and Nonconformity Scores:
@@ -370,7 +374,7 @@ function MMI.fit(conf_model::CVPlusRegressor, verbosity, X, y)
Dict(:fitresult => μ̂ₖ, :test => test, :cache => cache, :report => report)
end

# Pre-allocate:
fitresult, cache, report = ([], [], [])

# Nonconformity Scores:
Expand Down Expand Up @@ -405,10 +409,10 @@ end
For the [`CVPlusRegressor`](@ref) prediction intervals are computed in much the same way as for the [`JackknifePlusRegressor`](@ref). Specifically, we have,

```math
\hat{C}_{n,\alpha}(X_{n+1}) = \left[ \hat{q}_{n, \alpha}^{-} \{\hat\mu_{-\mathcal{D}_{k(i)}}(X_{n+1}) - S_i^{\text{CV}} \}, \hat{q}_{n, \alpha}^{+} \{\hat\mu_{-\mathcal{D}_{k(i)}}(X_{n+1}) + S_i^{\text{CV}}\} \right] , \ i \in \mathcal{D}_{\text{train}}
```

where ``\hat\mu_{-\mathcal{D}_{k(i)}}`` denotes the model fitted on training data with fold ``\mathcal{D}_{k(i)}`` that contains the ``i``-th point removed.
The [`JackknifePlusRegressor`](@ref) is a special case of the [`CVPlusRegressor`](@ref) for which ``K=n``.
"""
@@ -475,7 +479,7 @@ function MMI.fit(conf_model::CVMinMaxRegressor, verbosity, X, y)
Dict(:fitresult => μ̂ₖ, :test => test, :cache => cache, :report => report)
end

# Pre-allocate:
fitresult, cache, report = ([], [], [])

# Nonconformity Scores:
22 changes: 14 additions & 8 deletions test/Manifest.toml
@@ -1,8 +1,8 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.9.0"
julia_version = "1.9.3"
manifest_format = "2.0"
project_hash = "35130c42d0ed70ece3ae50bdfeedecf590b3fb1d"
project_hash = "395fbffc89c30a0e7ec8b0c754d4d91fbeabc585"

[[deps.ANSIColoredPrinters]]
git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c"
@@ -279,7 +279,7 @@ version = "0.1.25"
[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.0.2+0"
version = "1.0.5+0"

[[deps.CompositionsBase]]
git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad"
@@ -355,9 +355,9 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[deps.DecisionTree]]
deps = ["AbstractTrees", "DelimitedFiles", "LinearAlgebra", "Random", "ScikitLearnBase", "Statistics"]
git-tree-sha1 = "c6475a3ccad06cb1c2ebc0740c1bb4fe5a0731b7"
git-tree-sha1 = "526ca14aaaf2d5a0e242f3a8a7966eb9065d7d78"
uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
version = "0.12.3"
version = "0.12.4"

[[deps.DefineSingletons]]
git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c"
@@ -441,6 +441,12 @@ git-tree-sha1 = "98fdf08b707aaf69f524a6cd0a67858cefe0cfb6"
uuid = "792122b4-ca99-40de-a6bc-6742525f08b6"
version = "0.3.0"

[[deps.EpollShim_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "8e9441ee83492030ace98f9789a654a6d0b1f643"
uuid = "2702e6a9-849d-5ed8-8c21-79e8b8f9ee43"
version = "0.0.20230411+0"

[[deps.EvoTrees]]
deps = ["BSON", "CUDA", "CategoricalArrays", "Distributions", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "1b63fdc0acad47c3203398171c138835c1c40d69"
@@ -1238,7 +1244,7 @@ version = "0.42.2+0"
[[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
version = "1.9.0"
version = "1.9.2"

[[deps.PlotThemes]]
deps = ["PlotUtils", "Statistics"]
@@ -1684,7 +1690,7 @@ uuid = "41fe7b60-77ed-43a1-b4f0-825fd5a5650d"
version = "0.2.0"

[[deps.Wayland_jll]]
deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg", "XML2_jll"]
deps = ["Artifacts", "EpollShim_jll", "Expat_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg", "XML2_jll"]
git-tree-sha1 = "ed8d92d9774b077c53e1da50fd81a36af3744c1c"
uuid = "a2964d1f-97da-50d4-b82a-358c7fce9d89"
version = "1.21.0+0"
@@ -1893,7 +1899,7 @@ version = "0.15.1+0"
[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.7.0+0"
version = "5.8.0+0"

[[deps.libfdk_aac_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
1 change: 1 addition & 0 deletions test/Project.toml
@@ -1,5 +1,6 @@
[deps]
CompatHelperLocal = "5224ae11-6099-4aaa-941d-3aab004bd678"
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
LightGBM = "7acf609c-83a4-11e9-1ffb-b912bcd3b04a"
