From cec88fd5ae7f80617ce4c0a1bbd415c0c30639dc Mon Sep 17 00:00:00 2001 From: Thibaut Lienart Date: Fri, 10 Jun 2022 10:42:26 +0200 Subject: [PATCH 1/5] closes #784 by fixing typename + tests --- Project.toml | 3 ++- src/interface/data_utils.jl | 20 ++++++++++++++------ test/interface/data_utils.jl | 15 +++++++++++++-- 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 727b7685..4f1bc639 100644 --- a/Project.toml +++ b/Project.toml @@ -49,6 +49,7 @@ Tables = "0.2, 1.0" julia = "1.6" [extras] +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -59,4 +60,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" [targets] -test = ["DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"] +test = ["DataFrames", "DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"] diff --git a/src/interface/data_utils.jl b/src/interface/data_utils.jl index 715410e8..d1dbb791 100644 --- a/src/interface/data_utils.jl +++ b/src/interface/data_utils.jl @@ -97,14 +97,22 @@ function MMI.selectrows(::FI, ::Val{:table}, X, r) end function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Symbol, Integer}) - cols = Tables.columntable(X) # named tuple of vectors - return cols[c] + if !isdataframe(X) + cols = Tables.columntable(X) # named tuple of vectors + return cols[c] + else + return X[!, c] + end end function MMI.selectcols(::FI, ::Val{:table}, X, c::AbstractArray) - cols = Tables.columntable(X) # named tuple of vectors - newcols = project(cols, c) - return Tables.materializer(X)(newcols) + if !isdataframe(X) + cols = Tables.columntable(X) # named tuple of vectors + newcols = project(cols, c) + return Tables.materializer(X)(newcols) + else + return X[!, c] + end end # ------------------------------- @@ -124,7 +132,7 @@ function project(t::NamedTuple, indices::AbstractArray{<:Integer}) end # utils for selectrows -typename(X) = split(string(supertype(typeof(X)).name), '.')[end] +typename(X) = split(string(supertype(typeof(X))), '.')[end] isdataframe(X) = typename(X) == "AbstractDataFrame" # ---------------------------------------------------------------- diff --git a/test/interface/data_utils.jl b/test/interface/data_utils.jl index d75cfb47..84d35a5e 100644 --- a/test/interface/data_utils.jl +++ b/test/interface/data_utils.jl @@ -1,3 +1,5 @@ +import DataFrames + rng = StableRNGs.StableRNG(123) @testset "categorical" begin @@ -23,7 +25,7 @@ end b = categorical(["a", "b", "c"]) c = categorical(["a", "b", "c"]; ordered=true) X = (x1=x, x2=z, x3=b, x4=c) - @test MLJModelInterface.scitype(x) == ST.scitype(x) + @test MLJModelInterface.scitype(x) == ST.scitype(x) @test MLJModelInterface.scitype(y) == ST.scitype(y) @test MLJModelInterface.scitype(z) == ST.scitype(z) @test MLJModelInterface.scitype(a) == ST.scitype(a) @@ -39,7 +41,7 @@ end b = categorical(["a", "b", "c"]) c = categorical(["a", "b", "c"]; ordered=true) X = (x1=x, x2=z, x3=b, x4=c) - @test_throws ArgumentError MLJModelInterface.schema(x) + @test_throws ArgumentError MLJModelInterface.schema(x) @test MLJModelInterface.schema(X) == ST.schema(X) end @@ -197,4 +199,13 @@ end @test selectcols(tt, :w) == v end +# https://github.com/JuliaAI/MLJBase.jl/issues/784 +@testset "typename and dataframes" begin + df = DataFrames.DataFrame(x=[1,2,3], y=[2,3,4], z=[4,5,6]) + @test MLJBase.typename(df) == "AbstractDataFrame" + @test MLJBase.isdataframe(df) + @test selectrows(df, 2:3) == df[2:3, :] + @test selectcols(df, [:x, :z]) == df[!, [:x, :z]] +end + true From 12d2ae3e06fadcdd43f98db1aa5f5e89ba84812e Mon Sep 17 00:00:00 2001 From: Thibaut Lienart Date: Fri, 10 Jun 2022 14:56:52 +0200 Subject: [PATCH 2/5] Update src/interface/data_utils.jl Co-authored-by: Okon Samuel <39421418+OkonSamuel@users.noreply.github.com> --- src/interface/data_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface/data_utils.jl b/src/interface/data_utils.jl index d1dbb791..3517731b 100644 --- a/src/interface/data_utils.jl +++ b/src/interface/data_utils.jl @@ -98,7 +98,7 @@ end function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Symbol, Integer}) if !isdataframe(X) - cols = Tables.columntable(X) # named tuple of vectors + cols = Tables.columns(X) return cols[c] else return X[!, c] From abce696198ca639b3761e0bdb7ff9b987619c9e9 Mon Sep 17 00:00:00 2001 From: Thibaut Lienart Date: Fri, 10 Jun 2022 14:57:11 +0200 Subject: [PATCH 3/5] Update src/interface/data_utils.jl Co-authored-by: Okon Samuel <39421418+OkonSamuel@users.noreply.github.com> --- src/interface/data_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface/data_utils.jl b/src/interface/data_utils.jl index 3517731b..89e01da6 100644 --- a/src/interface/data_utils.jl +++ b/src/interface/data_utils.jl @@ -99,7 +99,7 @@ end function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Symbol, Integer}) if !isdataframe(X) cols = Tables.columns(X) - return cols[c] + return Tables.getcolumn(cols, c) else return X[!, c] end From 7f0f2fddaee29516cf8875cdedf1aaec4e7f6010 Mon Sep 17 00:00:00 2001 From: Thibaut Lienart Date: Fri, 10 Jun 2022 14:57:18 +0200 Subject: [PATCH 4/5] Update src/interface/data_utils.jl Co-authored-by: Okon Samuel <39421418+OkonSamuel@users.noreply.github.com> --- src/interface/data_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface/data_utils.jl b/src/interface/data_utils.jl index 89e01da6..1e7aaa0c 100644 --- a/src/interface/data_utils.jl +++ b/src/interface/data_utils.jl @@ -105,7 +105,7 @@ function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Symbol, Integer}) end end -function MMI.selectcols(::FI, ::Val{:table}, X, c::AbstractArray) +function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Colon, AbstractArray}) if !isdataframe(X) cols = Tables.columntable(X) # named tuple of vectors newcols = project(cols, c) From ca69c2f5a2a6e7e4aee4fcba0192722adab9feb2 Mon Sep 17 00:00:00 2001 From: Okon Samuel <39421418+OkonSamuel@users.noreply.github.com> Date: Fri, 10 Jun 2022 15:11:26 +0100 Subject: [PATCH 5/5] cleanup code --- src/interface/data_utils.jl | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/interface/data_utils.jl b/src/interface/data_utils.jl index 1e7aaa0c..0362d250 100644 --- a/src/interface/data_utils.jl +++ b/src/interface/data_utils.jl @@ -97,21 +97,17 @@ function MMI.selectrows(::FI, ::Val{:table}, X, r) end function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Symbol, Integer}) - if !isdataframe(X) - cols = Tables.columns(X) - return Tables.getcolumn(cols, c) - else - return X[!, c] - end + cols = Tables.columns(X) + return Tables.getcolumn(cols, c) end function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Colon, AbstractArray}) - if !isdataframe(X) + if isdataframe(X) + return X[!, c] + else cols = Tables.columntable(X) # named tuple of vectors newcols = project(cols, c) return Tables.materializer(X)(newcols) - else - return X[!, c] end end