Skip to content

Commit

Permalink
address selectcols issue #991
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Jan 5, 2025
1 parent 82b4e7f commit 6b8b899
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
16 changes: 13 additions & 3 deletions src/interface/data_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Symbol, Integer})
return Tables.getcolumn(cols, c)
end

function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Colon, AbstractArray})
function MMI.selectcols(::FI, ::Val{:table}, X, c)
if isdataframe(X)
return X[!, c]
else
Expand All @@ -115,18 +115,28 @@ end
# utils for `select`*

# project named tuple onto a tuple with only specified `labels` or indices:
function project(t::NamedTuple, labels::AbstractArray{Symbol})
function project(t::NamedTuple, labels::Union{AbstractArray{Symbol},NTuple{<:Any,Symbol}})
return NamedTuple{tuple(labels...)}(t)
end

project(t::NamedTuple, label::Colon) = t
project(t::NamedTuple, label::Symbol) = project(t, [label,])
project(t::NamedTuple, i::Integer) = project(t, [i,])

function project(t::NamedTuple, indices::AbstractArray{<:Integer})
function project(
t::NamedTuple,
indices::AbstractArray{<:Integer},
)
return NamedTuple{tuple(keys(t)[indices]...)}(tuple([t[i] for i in indices]...))
end

function project(
t::NamedTuple,
indices::NTuple{<:Any,<:Integer},
)
return NamedTuple{tuple(keys(t)[[indices...]]...)}(tuple([t[i] for i in indices]...))
end

# utils for selectrows
typename(X) = split(string(supertype(typeof(X))), '.')[end]
isdataframe(X) = typename(X) == "AbstractDataFrame"
Expand Down
4 changes: 3 additions & 1 deletion test/interface/data_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ end
s = schema(tt)
@test nrows(tt) == N

@test selectcols(tt, 4:6) ==
@test selectcols(tt, 4:6) == selectcols(tt, (4, 5, 6)) ==
selectcols(tt, (:x4, :x5, :z)) ==
selectcols(tt, [:x4, :x5, :z]) ==
selectcols(TypedTables.Table(x4=tt.x4, x5=tt.x5, z=tt.z), :)
@test selectcols(tt, [:x1, :z]) ==
selectcols(TypedTables.Table(x1=tt.x1, z=tt.z), :)
Expand Down

0 comments on commit 6b8b899

Please sign in to comment.