diff --git a/Project.toml b/Project.toml index 727b7685..4f1bc639 100644 --- a/Project.toml +++ b/Project.toml @@ -49,6 +49,7 @@ Tables = "0.2, 1.0" julia = "1.6" [extras] +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -59,4 +60,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" [targets] -test = ["DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"] +test = ["DataFrames", "DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"] diff --git a/src/interface/data_utils.jl b/src/interface/data_utils.jl index 715410e8..0362d250 100644 --- a/src/interface/data_utils.jl +++ b/src/interface/data_utils.jl @@ -97,14 +97,18 @@ function MMI.selectrows(::FI, ::Val{:table}, X, r) end function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Symbol, Integer}) - cols = Tables.columntable(X) # named tuple of vectors - return cols[c] + cols = Tables.columns(X) + return Tables.getcolumn(cols, c) end -function MMI.selectcols(::FI, ::Val{:table}, X, c::AbstractArray) - cols = Tables.columntable(X) # named tuple of vectors - newcols = project(cols, c) - return Tables.materializer(X)(newcols) +function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Colon, AbstractArray}) + if isdataframe(X) + return X[!, c] + else + cols = Tables.columntable(X) # named tuple of vectors + newcols = project(cols, c) + return Tables.materializer(X)(newcols) + end end # ------------------------------- @@ -124,7 +128,7 @@ function project(t::NamedTuple, indices::AbstractArray{<:Integer}) end # utils for selectrows -typename(X) = split(string(supertype(typeof(X)).name), '.')[end] +typename(X) = split(string(supertype(typeof(X))), '.')[end] isdataframe(X) = typename(X) == "AbstractDataFrame" # ---------------------------------------------------------------- diff --git a/test/interface/data_utils.jl b/test/interface/data_utils.jl index d75cfb47..84d35a5e 100644 --- a/test/interface/data_utils.jl +++ b/test/interface/data_utils.jl @@ -1,3 +1,5 @@ +import DataFrames + rng = StableRNGs.StableRNG(123) @testset "categorical" begin @@ -23,7 +25,7 @@ end b = categorical(["a", "b", "c"]) c = categorical(["a", "b", "c"]; ordered=true) X = (x1=x, x2=z, x3=b, x4=c) - @test MLJModelInterface.scitype(x) == ST.scitype(x) + @test MLJModelInterface.scitype(x) == ST.scitype(x) @test MLJModelInterface.scitype(y) == ST.scitype(y) @test MLJModelInterface.scitype(z) == ST.scitype(z) @test MLJModelInterface.scitype(a) == ST.scitype(a) @@ -39,7 +41,7 @@ end b = categorical(["a", "b", "c"]) c = categorical(["a", "b", "c"]; ordered=true) X = (x1=x, x2=z, x3=b, x4=c) - @test_throws ArgumentError MLJModelInterface.schema(x) + @test_throws ArgumentError MLJModelInterface.schema(x) @test MLJModelInterface.schema(X) == ST.schema(X) end @@ -197,4 +199,13 @@ end @test selectcols(tt, :w) == v end +# https://github.com/JuliaAI/MLJBase.jl/issues/784 +@testset "typename and dataframes" begin + df = DataFrames.DataFrame(x=[1,2,3], y=[2,3,4], z=[4,5,6]) + @test MLJBase.typename(df) == "AbstractDataFrame" + @test MLJBase.isdataframe(df) + @test selectrows(df, 2:3) == df[2:3, :] + @test selectcols(df, [:x, :z]) == df[!, [:x, :z]] +end + true