diff --git a/Project.toml b/Project.toml index 218074b..2e38c8e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJIteration" uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55" authors = ["Anthony D. Blaom "] -version = "0.4.4" +version = "0.4.5" [deps] IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c" diff --git a/src/traits.jl b/src/traits.jl index 289425c..28db5c6 100644 --- a/src/traits.jl +++ b/src/traits.jl @@ -12,21 +12,22 @@ MLJBase.package_url(::Type{<:EitherIteratedModel}) = MLJBase.package_license(::Type{<:EitherIteratedModel}) = "MIT" # inherited traits: -for T in [:DeterministicIteratedModel, :ProbabilisticIteratedModel] - for trait in [:supports_weights, - :supports_class_weights, - :is_pure_julia, - :input_scitype, - :output_scitype, - :target_scitype] +for trait in [:supports_weights, + :supports_class_weights, + :is_pure_julia, + :input_scitype, + :output_scitype, + :target_scitype] + quote + # needed because traits are not always deducable from + # the type (eg, `target_scitype` and `Pipeline` models): + MLJBase.$trait(imodel::EitherIteratedModel) = $trait(imodel.model) + end |> eval + for T in [:DeterministicIteratedModel, :ProbabilisticIteratedModel] quote # try to get trait at level of types ("failure" here just # means falling back to `Unknown`): MLJBase.$trait(::Type{<:$T{M}}) where M = MLJBase.$trait(M) - - # needed because traits are not always deducable from - # the type (eg, `target_scitype` and `Pipeline` models): - MLJBase.$trait(imodel::EitherIteratedModel) = $trait(imodel.model) end |> eval end end