From e6bbbffad88b62b0feba7735c10e76720fcc967a Mon Sep 17 00:00:00 2001 From: tiemvanderdeure Date: Fri, 5 Jul 2024 17:00:24 +0200 Subject: [PATCH] add datatype to OneHotEncoder and ContinuousEncoder --- src/builtins/Transformers.jl | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/builtins/Transformers.jl b/src/builtins/Transformers.jl index fd00d43..c1a6261 100644 --- a/src/builtins/Transformers.jl +++ b/src/builtins/Transformers.jl @@ -732,6 +732,7 @@ end drop_last::Bool = false ordered_factor::Bool = true ignore::Bool = false + datatype::DataType = Float64 end # we store the categorical refs for each feature to be encoded and the @@ -874,7 +875,7 @@ function MMI.transform(transformer::OneHotEncoder, fitresult, X) names = last.(pairs) cols_to_add = map(refs) do ref if ismissing(ref) missing - else float.(hot(col, ref)) + else transformer.datatype.(hot(col, ref)) end end append!(new_cols, cols_to_add) @@ -893,6 +894,7 @@ end @with_kw_noshow mutable struct ContinuousEncoder <: Unsupervised drop_last::Bool = false one_hot_ordered_factors::Bool = false + datatype::DataType = Float64 end function MMI.fit(transformer::ContinuousEncoder, verbosity::Int, X) @@ -918,7 +920,8 @@ function MMI.fit(transformer::ContinuousEncoder, verbosity::Int, X) # fit the one-hot encoder: hot_encoder = OneHotEncoder(ordered_factor=transformer.one_hot_ordered_factors, - drop_last=transformer.drop_last) + drop_last=transformer.drop_last, + datatype=transformer.datatype) hot_fitresult, _, hot_report = MMI.fit(hot_encoder, verbosity - 1, X) new_features = setdiff(hot_report.new_features, features_to_be_dropped) @@ -957,8 +960,15 @@ function MMI.transform(transformer::ContinuousEncoder, fitresult, X) X1 = transform(hot_encoder, hot_fitresult, X0) # convert remaining to continuous: - return coerce(X1, Count=>Continuous, OrderedFactor=>Continuous) - + X2 = coerce(X1, Count=>Continuous, OrderedFactor=>Continuous) + if transformer.datatype == Float64 + return X2 + else + X3 = map(Tables.columntable(X2)) do c + transformer.datatype.(c) + end + return MMI.table(X3, prototype=X) + end end