Skip to content

Commit

Permalink
feat: introduce new trait for special class column
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Dec 23, 2024
1 parent 1bf13ef commit 9ea8c55
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 10 deletions.
3 changes: 3 additions & 0 deletions src/InterfaceDynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,4 +359,7 @@ function DE.EvaluationHelpersModule._grad_evaluator(
)
end

# Allows special handling of class columns in MLJInterface.jl
handles_class_column(::Type{<:AbstractExpression}) = false

end
20 changes: 13 additions & 7 deletions src/MLJInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ using DynamicQuantities:
dimension
using LossFunctions: SupervisedLoss
using ..InterfaceDynamicQuantitiesModule: get_dimensions_type
using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE
using ..CoreModule:
Options, Dataset, AbstractMutationWeights, MutationWeights, LOSS_TYPE, ComplexityMapping
using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS
Expand Down Expand Up @@ -212,7 +213,10 @@ function _update(
options,
class,
)
if isnothing(class) && MMI.istable(X) && :class in MMI.schema(X).names
if IDE.handles_class_column(m.expression_type) &&
isnothing(class) &&
MMI.istable(X) &&
:class in MMI.schema(X).names
names_without_class = filter(!=(:class), MMI.schema(X).names)
new_X = MMI.selectcols(X, collect(names_without_class))
new_class = MMI.selectcols(X, :class)
Expand Down Expand Up @@ -486,12 +490,14 @@ function _predict(m::M, fitresult, Xnew, idx, class) where {M<:AbstractSRRegress
)
return _predict(m, fitresult, Xnew.data, Xnew.idx, class)
end
if isnothing(class) && MMI.istable(Xnew) && haskey(Xnew, :class)
if !(Xnew isa NamedTuple)
error("Classes can only be specified with named tuples.")
end
Xnew2 = Base.structdiff(Xnew, (; Xnew.class))
return _predict(m, fitresult, Xnew2, idx, Xnew.class)
if IDE.handles_class_column(m.expression_type) &&
isnothing(class) &&
MMI.istable(Xnew) &&
:class in MMI.schema(Xnew).names
names_without_class = filter(!=(:class), MMI.schema(Xnew).names)
Xnew2 = MMI.selectcols(Xnew, collect(names_without_class))
class = MMI.selectcols(Xnew, :class)
return _predict(m, fitresult, Xnew2, idx, class)
end

if fitresult.has_class
Expand Down
9 changes: 6 additions & 3 deletions src/ParametricExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using Random: default_rng, AbstractRNG

using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, AbstractMutationWeights
using ..PopMemberModule: PopMember
using ..InterfaceDynamicExpressionsModule: expected_array_type
using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE
using ..LossFunctionsModule: LossFunctionsModule as LF
using ..ExpressionBuilderModule: ExpressionBuilderModule as EB
using ..MutateModule: MutateModule as MM
Expand Down Expand Up @@ -65,7 +65,7 @@ function DE.eval_tree_array(
options::AbstractOptions;
kws...,
)
A = expected_array_type(X, typeof(tree))
A = IDE.expected_array_type(X, typeof(tree))
out, complete = DE.eval_tree_array(
tree,
X,
Expand All @@ -80,7 +80,7 @@ end
function LF.eval_tree_dispatch(
tree::ParametricExpression, dataset::Dataset, options::AbstractOptions, idx
)
A = expected_array_type(dataset.X, typeof(tree))
A = IDE.expected_array_type(dataset.X, typeof(tree))
out, complete = DE.eval_tree_array(
tree,
LF.maybe_getindex(dataset.X, :, idx),
Expand Down Expand Up @@ -181,4 +181,7 @@ function MF.mutate_constant(
end
end

# ParametricExpression handles class columns
IDE.handles_class_column(::Type{<:ParametricExpression}) = true

end

0 comments on commit 9ea8c55

Please sign in to comment.