Skip to content

Commit

Permalink
add datatype to OneHotEncoder and ContinuousEncoder
Browse files Browse the repository at this point in the history
  • Loading branch information
tiemvanderdeure committed Jul 5, 2024
1 parent 62bd007 commit e6bbbff
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions src/builtins/Transformers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,7 @@ end
drop_last::Bool = false
ordered_factor::Bool = true
ignore::Bool = false
datatype::DataType = Float64
end

# we store the categorical refs for each feature to be encoded and the
Expand Down Expand Up @@ -874,7 +875,7 @@ function MMI.transform(transformer::OneHotEncoder, fitresult, X)
names = last.(pairs)
cols_to_add = map(refs) do ref
if ismissing(ref) missing
else float.(hot(col, ref))
else transformer.datatype.(hot(col, ref))
end
end
append!(new_cols, cols_to_add)
Expand All @@ -893,6 +894,7 @@ end
# Hyperparameters of the `ContinuousEncoder` unsupervised transformer.
# Keyword construction with the defaults below is generated by
# `@with_kw_noshow`.
@with_kw_noshow mutable struct ContinuousEncoder <: Unsupervised
# Forwarded unchanged as the `drop_last` option of the internal
# `OneHotEncoder` built during `fit`.
# NOTE(review): presumably controls whether the indicator column for the
# last level of each factor is omitted — confirm against `OneHotEncoder`.
drop_last::Bool = false
# Forwarded as the `ordered_factor` option of the internal `OneHotEncoder`
# built during `fit`.
one_hot_ordered_factors::Bool = false
# Target element type. Forwarded as the `datatype` option of the internal
# `OneHotEncoder`; in `transform`, when this is not `Float64`, every column
# of the coerced table is additionally converted by broadcasting `datatype`
# over it.
datatype::DataType = Float64
end

function MMI.fit(transformer::ContinuousEncoder, verbosity::Int, X)
Expand All @@ -918,7 +920,8 @@ function MMI.fit(transformer::ContinuousEncoder, verbosity::Int, X)
# fit the one-hot encoder:
hot_encoder =
OneHotEncoder(ordered_factor=transformer.one_hot_ordered_factors,
drop_last=transformer.drop_last)
drop_last=transformer.drop_last,
datatype=transformer.datatype)
hot_fitresult, _, hot_report = MMI.fit(hot_encoder, verbosity - 1, X)

new_features = setdiff(hot_report.new_features, features_to_be_dropped)
Expand Down Expand Up @@ -957,8 +960,15 @@ function MMI.transform(transformer::ContinuousEncoder, fitresult, X)
X1 = transform(hot_encoder, hot_fitresult, X0)

# convert remaining to continuous:
return coerce(X1, Count=>Continuous, OrderedFactor=>Continuous)

X2 = coerce(X1, Count=>Continuous, OrderedFactor=>Continuous)
if transformer.datatype == Float64
return X2
else
X3 = map(Tables.columntable(X2)) do c
transformer.datatype.(c)
end
return MMI.table(X3, prototype=X)
end
end


Expand Down

0 comments on commit e6bbbff

Please sign in to comment.