diff --git a/src/builtins/Constant.jl b/src/builtins/Constant.jl index 83a39b4d..65d0d12b 100644 --- a/src/builtins/Constant.jl +++ b/src/builtins/Constant.jl @@ -11,17 +11,17 @@ struct ConstantRegressor{D<:Distributions.Sampleable} <: Probabilistic distribution_type::Type{D} end -ConstantRegressor(; distribution_type=Distributions.Normal) = +ConstantRegressor(; distribution_type = Distributions.Normal) = ConstantRegressor(distribution_type) -function MLJBase.fit(::ConstantRegressor{D}, verbosity::Int, X, y) where D +function MLJBase.fit(::ConstantRegressor{D}, verbosity::Int, X, y) where {D} fitresult = Distributions.fit(D, y) - cache = nothing - report = NamedTuple() + cache = nothing + report = NamedTuple() return fitresult, cache, report end -MLJBase.fitted_params(::ConstantRegressor, fitresult) = (target_distribution=fitresult,) +MLJBase.fitted_params(::ConstantRegressor, fitresult) = (target_distribution = fitresult,) MLJBase.predict(::ConstantRegressor, fitresult, Xnew) = fill(fitresult, nrows(Xnew)) @@ -33,12 +33,13 @@ struct DeterministicConstantRegressor <: Deterministic end function MLJBase.fit(::DeterministicConstantRegressor, verbosity::Int, X, y) fitresult = mean(y) - cache = nothing - report = NamedTuple() + cache = nothing + report = NamedTuple() return fitresult, cache, report end -MLJBase.predict(::DeterministicConstantRegressor, fitresult, Xnew) = fill(fitresult, nrows(Xnew)) +MLJBase.predict(::DeterministicConstantRegressor, fitresult, Xnew) = + fill(fitresult, nrows(Xnew)) ## ## THE CONSTANT CLASSIFIER @@ -56,14 +57,14 @@ obtain the training target mode instead. struct ConstantClassifier <: Probabilistic end # here `args` is `y` or `y, w`: -function MLJBase.fit(::ConstantClassifier, verbosity::Int, X, y, w=nothing) +function MLJBase.fit(::ConstantClassifier, verbosity::Int, X, y, w = nothing) fitresult = Distributions.fit(MLJBase.UnivariateFinite, y, w) - cache = nothing - report = NamedTuple + cache = nothing + report = NamedTuple return fitresult, cache, report end -MLJBase.fitted_params(::ConstantClassifier, fitresult) = (target_distribution=fitresult,) +MLJBase.fitted_params(::ConstantClassifier, fitresult) = (target_distribution = fitresult,) MLJBase.predict(::ConstantClassifier, fitresult, Xnew) = fill(fitresult, nrows(Xnew)) @@ -76,51 +77,65 @@ struct DeterministicConstantClassifier <: Deterministic end function MLJBase.fit(::DeterministicConstantClassifier, verbosity::Int, X, y) # dump missing target values and make into a regular array: fitresult = mode(skipmissing(y) |> collect) # a CategoricalValue or CategoricalString - cache = nothing - report = NamedTuple() + cache = nothing + report = NamedTuple() return fitresult, cache, report end -MLJBase.predict(::DeterministicConstantClassifier, fitresult, Xnew) = fill(fitresult, nrows(Xnew)) +MLJBase.predict(::DeterministicConstantClassifier, fitresult, Xnew) = + fill(fitresult, nrows(Xnew)) ## ## METADATA ## metadata_pkg.( - (ConstantRegressor, ConstantClassifier, - DeterministicConstantRegressor, DeterministicConstantClassifier), - name = "MLJModels", - uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7", - url = "https://github.com/alan-turing-institute/MLJModels.jl", - julia = true, - license = "MIT", - is_wrapper = false) - -metadata_model(ConstantRegressor, - input = Table(Scientific), - target = AbstractVector{Continuous}, + ( + ConstantRegressor, + ConstantClassifier, + DeterministicConstantRegressor, + DeterministicConstantClassifier, + ), + name = "MLJModels", + uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7", + url = "https://github.com/alan-turing-institute/MLJModels.jl", + julia = true, + license = "MIT", + is_wrapper = false, +) + +metadata_model( + ConstantRegressor, + input = Table(Scientific), + target = AbstractVector{Continuous}, weights = false, - descr = "Constant regressor (Probabilistic).", - path = "MLJModels.ConstantRegressor") - -metadata_model(DeterministicConstantRegressor, - input = Table(Scientific), - target = AbstractVector{Continuous}, + descr = "Constant regressor (Probabilistic).", + path = "MLJModels.ConstantRegressor", +) + +metadata_model( + DeterministicConstantRegressor, + input = Table(Scientific), + target = AbstractVector{Continuous}, weights = false, - descr = "Constant regressor (Deterministic).", - path = "MLJModels.DeterministicConstantRegressor") - -metadata_model(ConstantClassifier, - input = Table(Scientific), - target = AbstractVector{<:Finite}, + descr = "Constant regressor (Deterministic).", + path = "MLJModels.DeterministicConstantRegressor", +) + +metadata_model( + ConstantClassifier, + input = Table(Scientific), + target = AbstractVector{<:Finite}, weights = true, - descr = "Constant classifier (Probabilistic).", - path = "MLJModels.ConstantClassifier") - -metadata_model(DeterministicConstantClassifier, - input = Table(Scientific), - target = AbstractVector{<:Finite}, + descr = "Constant classifier (Probabilistic).", + path = "MLJModels.ConstantClassifier", +) + +metadata_model( + DeterministicConstantClassifier, + input = Table(Scientific), + target = AbstractVector{<:Finite}, weights = false, - descr = "Constant classifier (Deterministic).", - path = "MLJModels.DeterministicConstantClassifier") + descr = "Constant classifier (Deterministic).", + path = "MLJModels.DeterministicConstantClassifier", +)