From 9ea8c553103efd3206dc1765c698aad6318cca8e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 22 Dec 2024 19:49:53 -0500 Subject: [PATCH] feat: introduce new trait for special class column --- src/InterfaceDynamicExpressions.jl | 3 +++ src/MLJInterface.jl | 20 +++++++++++++------- src/ParametricExpression.jl | 9 ++++++--- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/InterfaceDynamicExpressions.jl b/src/InterfaceDynamicExpressions.jl index b30957064..8b79230aa 100644 --- a/src/InterfaceDynamicExpressions.jl +++ b/src/InterfaceDynamicExpressions.jl @@ -359,4 +359,7 @@ function DE.EvaluationHelpersModule._grad_evaluator( ) end +# Allows special handling of class columns in MLJInterface.jl +handles_class_column(::Type{<:AbstractExpression}) = false + end diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 8d2a0045e..d29fc278b 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -25,6 +25,7 @@ using DynamicQuantities: dimension using LossFunctions: SupervisedLoss using ..InterfaceDynamicQuantitiesModule: get_dimensions_type +using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE using ..CoreModule: Options, Dataset, AbstractMutationWeights, MutationWeights, LOSS_TYPE, ComplexityMapping using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS @@ -212,7 +213,10 @@ function _update( options, class, ) - if isnothing(class) && MMI.istable(X) && :class in MMI.schema(X).names + if IDE.handles_class_column(m.expression_type) && + isnothing(class) && + MMI.istable(X) && + :class in MMI.schema(X).names names_without_class = filter(!=(:class), MMI.schema(X).names) new_X = MMI.selectcols(X, collect(names_without_class)) new_class = MMI.selectcols(X, :class) @@ -486,12 +490,14 @@ function _predict(m::M, fitresult, Xnew, idx, class) where {M<:AbstractSRRegress ) return _predict(m, fitresult, Xnew.data, Xnew.idx, class) end - if isnothing(class) && MMI.istable(Xnew) && haskey(Xnew, :class) - if !(Xnew isa NamedTuple) - error("Classes can only be specified with named tuples.") - end - Xnew2 = Base.structdiff(Xnew, (; Xnew.class)) - return _predict(m, fitresult, Xnew2, idx, Xnew.class) + if IDE.handles_class_column(m.expression_type) && + isnothing(class) && + MMI.istable(Xnew) && + :class in MMI.schema(Xnew).names + names_without_class = filter(!=(:class), MMI.schema(Xnew).names) + Xnew2 = MMI.selectcols(Xnew, collect(names_without_class)) + class = MMI.selectcols(Xnew, :class) + return _predict(m, fitresult, Xnew2, idx, class) end if fitresult.has_class diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index a5664fd45..5b92db8c3 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -17,7 +17,7 @@ using Random: default_rng, AbstractRNG using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, AbstractMutationWeights using ..PopMemberModule: PopMember -using ..InterfaceDynamicExpressionsModule: expected_array_type +using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE using ..LossFunctionsModule: LossFunctionsModule as LF using ..ExpressionBuilderModule: ExpressionBuilderModule as EB using ..MutateModule: MutateModule as MM @@ -65,7 +65,7 @@ function DE.eval_tree_array( options::AbstractOptions; kws..., ) - A = expected_array_type(X, typeof(tree)) + A = IDE.expected_array_type(X, typeof(tree)) out, complete = DE.eval_tree_array( tree, X, @@ -80,7 +80,7 @@ end function LF.eval_tree_dispatch( tree::ParametricExpression, dataset::Dataset, options::AbstractOptions, idx ) - A = expected_array_type(dataset.X, typeof(tree)) + A = IDE.expected_array_type(dataset.X, typeof(tree)) out, complete = DE.eval_tree_array( tree, LF.maybe_getindex(dataset.X, :, idx), @@ -181,4 +181,7 @@ function MF.mutate_constant( end end +# ParametricExpression handles class columns +IDE.handles_class_column(::Type{<:ParametricExpression}) = true + end