diff --git a/src/builtins/Transformers.jl b/src/builtins/Transformers.jl index 4f4b6a9a..cfeefcd1 100644 --- a/src/builtins/Transformers.jl +++ b/src/builtins/Transformers.jl @@ -860,7 +860,7 @@ function MMI.fit(transformer::OneHotEncoder, verbosity::Int, X) if T <: allowed_scitypes && ftr in specified_features ref_name_pairs_given_feature[ftr] = Pair{<:Unsigned,Symbol}[] shift = transformer.drop_last ? 1 : 0 - levels = classes(first(col)) + levels = classes(col) fitted_levels_given_feature[ftr] = levels if verbosity > 0 @info "Spawning $(length(levels)-shift) sub-features "* diff --git a/test/builtins/Transformers.jl b/test/builtins/Transformers.jl index ee8e403c..f4d19933 100644 --- a/test/builtins/Transformers.jl +++ b/test/builtins/Transformers.jl @@ -522,6 +522,16 @@ end @test_throws Exception Xt.favourite_number__10 @test_throws Exception Xt.name__Mary @test report.new_features == collect(MLJBase.schema(Xt).names) + + # Test when the first value is missing + X = (name=categorical([missing, "John", "Mary", "John"]),) + t = OneHotEncoder() + f, _, _ = MLJBase.fit(t, 0, X) + Xt = MLJBase.transform(t, f, X) + @test Xt.name__John[1] === Xt.name__Mary[1] === missing + @test Xt.name__John[2:end] == Union{Missing, Float64}[1.0, 0.0, 1.0] + @test Xt.name__Mary[2:end] == Union{Missing, Float64}[0.0, 1.0, 0.0] + end