Skip to content

Commit

Permalink
Create MLJ compliant doc strings (#138)
Browse files Browse the repository at this point in the history
* make doc-strings mlj-compliant
  • Loading branch information
ablaom authored Feb 1, 2023
1 parent bd9d915 commit 9d98288
Show file tree
Hide file tree
Showing 6 changed files with 472 additions and 150 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
DocStringExtensions = "0.8, 0.9"
IterativeSolvers = "0.8, 0.9"
LinearMaps = "2.6, 3.2"
MLJModelInterface = "0.3, 0.4, 1.0"
MLJModelInterface = "1.4"
Optim = "0.20, 0.21, 1"
Parameters = "0.12"
julia = "1.6, 1"
127 changes: 96 additions & 31 deletions src/mlj/classifiers.jl
Original file line number Diff line number Diff line change
@@ -1,33 +1,58 @@
#= ===================
LOGISTIC CLASSIFIER
=================== =#
=================== =#

"""
Logistic Classifier (typically called "Logistic Regression"). This model is
a standard classifier for both binary and multiclass classification.
In the binary case it corresponds to the LogisticLoss, in the multiclass to the
Multinomial (softmax) loss. An elastic net penalty can be applied with
overall objective function
$(doc_header(LogisticClassifier))
``L(y, Xθ) + n⋅λ|θ|₂²/2 + n⋅γ|θ|₁``
This model is more commonly known as "logistic regression". It is a standard classifier
for both binary and multiclass classification. The objective function applies either a
logistic loss (binary target) or multinomial (softmax) loss, and has a mixed L1/L2
penalty:
where ``L`` is either the logistic or multinomial loss and ``λ`` and ``γ`` indicate
the strength of the L2 (resp. L1) regularisation components and
``n`` is the number of samples `size(X, 1)`.
With `scale_penalty_with_samples = false` the objective function is
``L(y, Xθ) + λ|θ|₂²/2 + γ|θ|₁``
``L(y, Xθ) + n⋅λ|θ|₂²/2 + n⋅γ|θ|₁``.
## Parameters
Here ``L`` is either `MLJLinearModels.LogisticLoss` or `MLJLinearModels.MultiClassLoss`,
``λ`` and ``γ`` indicate
the strength of the L2 (resp. L1) regularization components and
``n`` is the number of training observations.
With `scale_penalty_with_samples = false` the objective function is instead
``L(y, Xθ) + λ|θ|₂²/2 + γ|θ|₁``.
# Training data
In MLJ or MLJBase, bind an instance `model` to data with
mach = machine(model, X, y)
where:
- `X` is any table of input features (eg, a `DataFrame`) whose columns
have `Continuous` scitype; check column scitypes with `schema(X)`
- `y` is the target, which can be any `AbstractVector` whose element
scitype is `<:OrderedFactor` or `<:Multiclass`; check the scitype
with `scitype(y)`
Train the machine using `fit!(mach, rows=...)`.
# Hyperparameters
$TYPEDFIELDS
$(example_docstring("LogisticClassifier", nclasses = 2))
See also [`MultinomialClassifier`](@ref).
"""
@with_kw_noshow mutable struct LogisticClassifier <: MMI.Probabilistic
"strength of the regulariser if `penalty` is `:l2` or `:l1` and strength of the L2
regulariser if `penalty` is `:en`."
"strength of the regularizer if `penalty` is `:l2` or `:l1` and strength of the L2
regularizer if `penalty` is `:en`."
lambda::Real = eps()
"strength of the L1 regulariser if `penalty` is `:en`."
"strength of the L1 regularizer if `penalty` is `:en`."
gamma::Real = 0.0
"the penalty to use, either `:l2`, `:l1`, `:en` (elastic net) or `:none`."
penalty::SymStr = :l2
Expand All @@ -37,7 +62,18 @@ $(example_docstring("LogisticClassifier", nclasses = 2))
penalize_intercept::Bool = false
"whether to scale the penalty with the number of samples."
scale_penalty_with_samples::Bool = true
"type of solver to use, default if `nothing`."
"""some instance of `MLJLinearModels.S` where `S` is one of: `LBFGS`, `Newton`,
`NewtonCG`, `ProxGrad`; but subject to the following restrictions:
- If `gamma > 0` (L1 norm penalized) then only `ProxGrad` is allowed.
- Unless `scitype(y) <: Finite{2}` (binary target) `Newton` is disallowed.
If `solver = nothing` (default) then `ProxGrad(accel=true)` (FISTA) is used,
unless `gamma = 0`, in which case `LBFGS()` is used.
Solver aliases: `FISTA(; kwargs...) = ProxGrad(accel=true, kwargs...)`,
`ISTA(; kwargs...) = ProxGrad(accel=false, kwargs...)`"""
solver::Option{Solver} = nothing
end

Expand All @@ -50,27 +86,49 @@ glr(m::LogisticClassifier, nclasses::Integer) =
scale_penalty_with_samples=m.scale_penalty_with_samples,
nclasses=nclasses)

descr(::Type{LogisticClassifier}) = "Classifier corresponding to the loss function ``L(y, Xθ) + λ|θ|₂²/2 + γ|θ|₁`` where `L` is the logistic loss."

#= ======================
MULTINOMIAL CLASSIFIER
====================== =#

"""
See `LogisticClassifier`, it's the same except that multiple classes are assumed
by default. The other parameters are the same.
$(doc_header(MultinomialClassifier))
This model coincides with [`LogisticClassifier`](@ref), except certain optimizations
possible in the special binary case will not be applied. Its hyperparameters are
identical.
# Training data
In MLJ or MLJBase, bind an instance `model` to data with
mach = machine(model, X, y)
where:
- `X` is any table of input features (eg, a `DataFrame`) whose columns
have `Continuous` scitype; check column scitypes with `schema(X)`
- `y` is the target, which can be any `AbstractVector` whose element
scitype is `<:OrderedFactor` or `<:Multiclass`; check the scitype
with `scitype(y)`
Train the machine using `fit!(mach, rows=...)`.
## Parameters
# Hyperparameters
$TYPEDFIELDS
$(example_docstring("LogisticClassifier", nclasses = 3))
$(example_docstring("MultinomialClassifier", nclasses = 3))
See also [`LogisticClassifier`](@ref).
"""
@with_kw_noshow mutable struct MultinomialClassifier <: MMI.Probabilistic
"strength of the regulariser if `penalty` is `:l2` or `:l1`.
Strength of the L2 regulariser if `penalty` is `:en`."
"strength of the regularizer if `penalty` is `:l2` or `:l1`.
Strength of the L2 regularizer if `penalty` is `:en`."
lambda::Real = eps()
"strength of the L1 regulariser if `penalty` is `:en`."
"strength of the L1 regularizer if `penalty` is `:en`."
gamma::Real = 0.0
"the penalty to use, either `:l2`, `:l1`, `:en` (elastic net) or `:none`."
penalty::SymStr = :l2
Expand All @@ -80,7 +138,18 @@ $(example_docstring("LogisticClassifier", nclasses = 3))
penalize_intercept::Bool = false
"whether to scale the penalty with the number of samples."
scale_penalty_with_samples::Bool = true
"type of solver to use, default if `nothing`."
"""some instance of `MLJLinearModels.S` where `S` is one of: `LBFGS`,
`NewtonCG`, `ProxGrad`; but subject to the following restrictions:
- If `gamma > 0` (L1 norm penalized) then `ProxGrad` is disallowed.
- Unless `scitype(y) <: Finite{2}` (binary target) `Newton` is disallowed.
If `solver = nothing` (default) then `ProxGrad(accel=true)` (FISTA) is used,
unless `gamma = 0`, in which case `LBFGS()` is used.
Solver aliases: `FISTA(; kwargs...) = ProxGrad(accel=true, kwargs...)`,
`ISTA(; kwargs...) = ProxGrad(accel=false, kwargs...)`"""
solver::Option{Solver} = nothing
end

Expand All @@ -91,7 +160,3 @@ glr(m::MultinomialClassifier, nclasses::Integer) =
penalize_intercept=m.penalize_intercept,
scale_penalty_with_samples=m.scale_penalty_with_samples,
nclasses=nclasses)

descr(::Type{MultinomialClassifier}) =
"Classifier corresponding to the loss function " *
"``L(y, Xθ) + λ|θ|₂²/2 + γ|θ|₁`` where `L` is the multinomial loss."
40 changes: 40 additions & 0 deletions src/mlj/doc_tools.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
const DOC_SOLVERS =
"Different solver options exist, as indicated under "*
"\"Hyperparameters\" below. "

function example_docstring(m; nclasses = nothing)
"""
## Example
using MLJ
X, y = $(nclasses == nothing ? "make_regression()" : "make_blobs(centers = $nclasses)")
mach = fit!(machine($m(), X, y))
predict(mach, X)
fitted_params(mach)
"""
end

function doc_header(ModelType)
name = MLJModelInterface.name(ModelType)
human_name = MLJModelInterface.human_name(ModelType)

"""
$name
A model type for constructing a $human_name, based on
[MLJLinearModels.jl](https://github.com/alan-turing-institute/MLJLinearModels.jl), and
implementing the MLJ model interface.
From MLJ, the type can be imported using
$name = @load $name pkg=MLJLinearModels
Do `model = $name()` to construct an instance with default
hyper-parameters.
"""
end

const DOC_PROXGRAD = "Aliases `ISTA` and `FSTA` correspond to "*
"`ProxGrad` with the option `acceleration=false` or `true` respectively. "
54 changes: 20 additions & 34 deletions src/mlj/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,7 @@ export LinearRegressor, RidgeRegressor, LassoRegressor, ElasticNetRegressor,

const SymStr = Union{Symbol,String}

function example_docstring(m; nclasses = nothing)
"""
## Example
using MLJ
X, y = $(nclasses == nothing ? "make_regression()" : "make_blobs(centers = $nclasses)")
mach = fit!(machine($m(), X, y))
predict(mach, X)
fitted_params(mach)
"""
end

include("doc_tools.jl")
include("regressors.jl")
include("classifiers.jl")

Expand Down Expand Up @@ -121,8 +109,6 @@ function MMI.fitted_params(m::Union{CLF_MODELS...}, (θ, features, classes, c))
return _fitted_params(θ, features, nothing)
end

@static VERSION < v"1.1" && (eachrow(A::AbstractVecOrMat) = (view(A, i, :) for i in axes(A, 1)))

coef_vec(W::AbstractMatrix, features) =
[feature => coef for (feature, coef) in zip(features, eachrow(W))]
coef_vec::AbstractVector, features) =
Expand All @@ -135,30 +121,30 @@ coef_vec(θ::AbstractVector, ::Nothing) = θ
======================= =#

MMI.metadata_pkg.(ALL_MODELS,
name="MLJLinearModels",
uuid="6ee0df7b-362f-4a72-a706-9e79364fb692",
url="https://github.com/alan-turing-institute/MLJLinearModels.jl",
julia=true,
license="MIT",
package_name="MLJLinearModels",
package_uuid="6ee0df7b-362f-4a72-a706-9e79364fb692",
package_url="https://github.com/alan-turing-institute/MLJLinearModels.jl",
is_pure_julia=true,
package_license="MIT",
is_wrapper=false)

descr_(M) = descr(M) *
"\n→ based on [MLJLinearModels](https://github.com/alan-turing-institute/MLJLinearModels.jl)" *
"\n→ do `@load $(MMI.name(M)) pkg=\"MLJLinearModels\" to use the model.`" *
"\n→ do `?$(MMI.name(M))` for documentation."
lp_(M) = "MLJLinearModels.$(MMI.name(M))"

for M in REG_MODELS
MMI.metadata_model(M,
input=MMI.Table(MMI.Continuous),
target=AbstractVector{MMI.Continuous},
weights=false,
descr=descr_(M), path=lp_(M))
MMI.metadata_model(
M,
input_scitype=MMI.Table(MMI.Continuous),
target_scitype=AbstractVector{MMI.Continuous},
load_path=lp_(M),
)
end
for M in CLF_MODELS
MMI.metadata_model(M,
input=MMI.Table(MMI.Continuous),
target=AbstractVector{<:MMI.Finite},
weights=false,
descr=descr_(M), path=lp_(M))
MMI.metadata_model(
M,
input_scitype=MMI.Table(MMI.Continuous),
target_scitype=AbstractVector{<:MMI.Finite},
load_path=lp_(M),
)
end

MMI.human_name(::Type{<:LADRegressor}) = "least absolute deviation regressor"
Loading

0 comments on commit 9d98288

Please sign in to comment.