Skip to content

Commit

Permalink
Refactor expression handling (#292)
Browse files Browse the repository at this point in the history
* refactor

* better error

* add warning

* remove redundant stuff

* eachrow_replace rewrite

* new tests for eachrow

* argument order membernames

* another Base.Generator

* add test set for get column expression

* add warning back

* fix test

* fixes

* use depwarn

* Apply suggestions from code review

Co-authored-by: Milan Bouchet-Valat <[email protected]>

Co-authored-by: Milan Bouchet-Valat <[email protected]>
  • Loading branch information
pdeffebach and nalimilan committed Sep 11, 2021
1 parent 9b696ed commit 6ba85a7
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 161 deletions.
46 changes: 26 additions & 20 deletions src/eachrow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,41 @@
##
##############################################################################


# Recursive function that traverses the syntax tree of e, replaces instances of
# ":(:(x))" with ":x[row]".
eachrow_replace(x) = x
eachrow_replace(e::QuoteNode) = Expr(:ref, e, :row)

function eachrow_replace(e::Expr)
# Traverse the syntax tree of e
if onearg(e, :cols)
@warn "cols(x) is deprecated, use \$x instead"
# cols(:x) becomes cols(:x)[row]
return Expr(:ref, Expr(:call, :cols, e.args[2]), :row)
elseif is_column_expr(e)
return Expr(:ref, Expr(:$, e.args[1]), :row)
if onearg(e, :^)
return e.args[2]
end

if e.head == :.
if e.args[1] isa QuoteNode
e.args[1] = Expr(:ref, e.args[1], :row)
return e
else
return e
end
# Traverse the syntax tree of e
col = get_column_expr(e)
if col !== nothing
return :($e[row])
# equivalent to protect_replace_syms
elseif e.head == :.
x_new = eachrow_replace(e.args[1])
y = e.args[2]
y_new = y isa Expr ? eachrow_replace(y) : y

return Expr(:., x_new, y_new)
else
mapexpr(eachrow_replace, e)
end

Expr(e.head, (isempty(e.args) ? e.args : map(eachrow_replace, e.args))...)
end

eachrow_replace(e::QuoteNode) = Expr(:ref, e, :row)
# Guarded variant of `eachrow_replace`: only `Expr` inputs are rewritten.
# Anything else (symbols, quoted names, literals) is handed back untouched.
function protect_eachrow_replace(e::Expr)
    return eachrow_replace(e)
end

protect_eachrow_replace(e) = e

# Set the base case for helper, i.e. for when expand hits an object of type
# other than Expr (generally a Symbol or a literal).
eachrow_replace(x) = x
# Rewrite a dotted expression `x.y` for row-wise evaluation.
#
# The left-hand side (`e.args[1]`) goes through `eachrow_replace`. The
# right-hand side (`e.args[2]`) goes through `protect_eachrow_replace`, which
# only rewrites it when it is itself an `Expr`.
#
# `membernames` is not used; it is kept so the signature stays unchanged for
# any existing caller.
#
# Fix: the helper names were misspelled (`eachrow_repalce`,
# `protect_eachrow_repalce`), so every call raised `UndefVarError`.
function eachrow_replace_dotted(e, membernames)
    x_new = eachrow_replace(e.args[1])
    y_new = protect_eachrow_replace(e.args[2])
    return Expr(:., x_new, y_new)
end

function eachrow_find_newcols(e::Expr, newcol_decl)
if e.head == :macrocall && e.args[1] == Symbol("@newcol")
Expand Down
4 changes: 2 additions & 2 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1548,7 +1548,7 @@ function combine_helper(x, args...; deprecation_warning = false)

fe = first(exprs)
if length(exprs) == 1 &&
!(fe isa QuoteNode || onearg(fe, :cols) || is_column_expr(fe)) &&
get_column_expr(fe) === nothing &&
!(fe.head == :(=) || fe.head == :kw)

@warn "Returning a Table object from @by and @combine now requires `$(DOLLAR)AsTable` on the LHS."
Expand Down Expand Up @@ -1668,7 +1668,7 @@ function by_helper(x, what, args...)
exprs, outer_flags = create_args_vector(args...)
fe = first(exprs)
if length(exprs) == 1 &&
!(fe isa QuoteNode || onearg(fe, :cols) || is_column_expr(fe)) &&
get_column_expr(fe) === nothing &&
!(fe.head == :(=) || fe.head == :kw)

@warn "Returning a Table object from @by and @combine now requires `\$AsTable` on the LHS."
Expand Down
220 changes: 81 additions & 139 deletions src/parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,40 +8,62 @@ end
# `onearg(e, f)` tests whether `e` is a call expression whose callee equals `f`
# and which carries exactly one argument (i.e. `e.args` has two entries).
# Anything that is not an `Expr` never matches.
onearg(e, f) = false

function onearg(e::Expr, f)
    e.head == :call || return false
    length(e.args) == 2 || return false
    return e.args[1] == f
end

# `is_column_expr` is true only for expressions whose head is `$`.
is_column_expr(x) = false

function is_column_expr(e::Expr)
    return e.head == :$
end

"""
    get_column_expr(x)

Return the column identifier carried by `x`, or `nothing` when `x` is not a
column reference.

* A `QuoteNode` is returned as is.
* An expression beginning with `$DOLLAR` yields its first argument.
* A one-argument call to `cols` yields that argument and emits a
  deprecation warning.
"""
get_column_expr(x) = nothing

get_column_expr(x::QuoteNode) = x

function get_column_expr(e::Expr)
    if e.head == :$
        return e.args[1]
    elseif onearg(e, :cols)
        # Legacy `cols(x)` escaping: still honoured, but flagged as deprecated.
        Base.depwarn("cols is deprecated use $DOLLAR to escape column names instead", :cols)
        return e.args[2]
    else
        return nothing
    end
end

mapexpr(f, e) = Expr(e.head, map(f, e.args)...)
# Build a new expression with the same head as `e`, whose arguments are the
# results of applying `f` to each argument of `e`. The input is not modified.
function mapexpr(f, e)
    head = e.head
    transformed = (f(arg) for arg in e.args)
    return Expr(head, transformed...)
end

replace_syms!(x, membernames) = x
replace_syms!(q::QuoteNode, membernames) =
replace_syms!(Meta.quot(q.value), membernames)
function replace_syms!(e::Expr, membernames)
replace_syms!(membernames, x) = x
replace_syms!(membernames, q::QuoteNode) = addkey!(membernames, q)

function replace_syms!(membernames, e::Expr)
if onearg(e, :^)
e.args[2]
elseif onearg(e, :_I_)
@warn "_I_() for escaping variables is deprecated, use cols() instead"
addkey!(membernames, :($(e.args[2])))
elseif onearg(e, :cols)
@warn "cols(x) for escaping variables is deprecated, use \$x instead"
addkey!(membernames, :($(e.args[2])))
elseif is_column_expr(e)
addkey!(membernames, :($(e.args[1])))
elseif e.head == :quote
addkey!(membernames, Meta.quot(e.args[1]) )
return e.args[2]
end

col = get_column_expr(e)
if col !== nothing
return addkey!(membernames, col)
elseif e.head == :.
replace_dotted!(e, membernames)
return replace_dotted!(membernames, e)
else
mapexpr(x -> replace_syms!(x, membernames), e)
return mapexpr(x -> replace_syms!(membernames, x), e)
end
end

# Guarded variant of `replace_syms!`: only `Expr` inputs are rewritten.
# Anything else (including a `QuoteNode`) is handed back untouched.
function protect_replace_syms!(membernames, e::Expr)
    return replace_syms!(membernames, e)
end

protect_replace_syms!(membernames, e) = e

# Rewrite a dotted expression `x.y`.
#
# The left-hand side is processed with `replace_syms!`. The right-hand side is
# processed with `protect_replace_syms!`, which only rewrites it when it is
# itself an `Expr`. A fresh `:.` expression is returned.
function replace_dotted!(membernames, e)
    return Expr(
        :.,
        replace_syms!(membernames, e.args[1]),
        protect_replace_syms!(membernames, e.args[2]),
    )
end

is_simple_non_broadcast_call(x) = false
function is_simple_non_broadcast_call(expr::Expr)
expr.head == :call &&
length(expr.args) >= 2 &&
expr.args[1] isa Symbol &&
all(x -> x isa QuoteNode || onearg(x, :cols) || is_column_expr(x), expr.args[2:end])
all(a -> get_column_expr(a) !== nothing, expr.args[2:end])
end

is_simple_broadcast_call(x) = false
Expand All @@ -51,21 +73,14 @@ function is_simple_broadcast_call(expr::Expr)
expr.args[1] isa Symbol &&
expr.args[2] isa Expr &&
expr.args[2].head == :tuple &&
all(x -> x isa QuoteNode || onearg(x, :cols) || is_column_expr(x), expr.args[2].args)
all(a -> get_column_expr(a) !== nothing, expr.args[2].args)
end

function args_to_selectors(v)
t = map(v) do arg
if arg isa QuoteNode
arg
elseif onearg(arg, :cols)
@warn "cols(x) is deprecated, use \$x instead"
arg.args[2]
elseif is_column_expr(arg)
arg.args[1]
else
throw(ArgumentError("This path should not be reached, arg: $(arg)"))
end
t = Base.Generator(v) do arg
col = get_column_expr(arg)
col === nothing && throw(ArgumentError("This path should not be reached, arg: $(arg)"))
col
end

:(DataFramesMeta.make_source_concrete($(Expr(:vect, t...))))
Expand Down Expand Up @@ -190,11 +205,9 @@ function get_source_fun(function_expr; exprflags = deepcopy(DEFAULT_FLAGS))
else
membernames = Dict{Any, Symbol}()

body = replace_syms!(function_expr, membernames)

body = replace_syms!(membernames, function_expr)
source = :(DataFramesMeta.make_source_concrete($(Expr(:vect, keys(membernames)...))))
inputargs = Expr(:tuple, values(membernames)...)

fun = quote
$inputargs -> begin
$body
Expand Down Expand Up @@ -245,121 +258,65 @@ function fun_to_vec(ex::Expr;
check_macro_flags_consistency(final_flags)

if gensym_names
ex = Expr(:kw, gensym(), ex)
ex = Expr(:kw, QuoteNode(gensym()), ex)
end

# :x
# handled below via dispatch on ::QuoteNode

# Fix any references to `cols` and replace them
# with $
if onearg(ex, :cols)
ex = Expr(:$, ex.args[2])
end

# $:x
if is_column_expr(ex)
return ex.args[1]
ex_col = get_column_expr(ex)
if ex_col !== nothing
return ex_col
end

if no_dest
source, fun = get_source_fun(ex, exprflags = final_flags)
src, fun = get_source_fun(ex, exprflags = final_flags)
return quote
$source => $fun
$src => $fun
end
end

@assert ex.head == :kw || ex.head == :(=)
lhs = ex.args[1]

# fix cols
if onearg(lhs, :cols)
lhs = Expr(:$, lhs.args[2])
end

rhs = MacroTools.unblock(ex.args[2])

if onearg(rhs, :cols)
@warn "cols(x) is deprecated, use \$x instead"
rhs = Expr(:$, rhs.args[2])
end

if is_macro_head(rhs, "@byrow")
s = "In keyword argument inputs, `@byrow` must be on the left hand side. " *
"Did you write `y = @byrow f(:x)` instead of `@byrow y = f(:x)`?"
throw(ArgumentError(s))
end

# y = ...
if lhs isa Symbol
msg = "Using an un-quoted Symbol on the LHS is deprecated. " *
"Write $(QuoteNode(lhs)) = ... instead."
Base.depwarn(msg, "")
lhs = QuoteNode(lhs)
if !(ex.head == :kw || ex.head == :(=))
throw(ArgumentError("Malformed expression in DataFramesMeta.jl macro"))
end

# :y = :x
if lhs isa QuoteNode && rhs isa QuoteNode
source = rhs
dest = lhs
lhs = let t = ex.args[1]
if t isa Symbol
t = QuoteNode(t)
msg = "Using an un-quoted Symbol on the LHS is deprecated. " *
"Write $t = ... instead."

return quote
$source => $dest
@warn msg
end
end

# :y = $:x
if lhs isa QuoteNode && is_column_expr(rhs)
source = rhs.args[1]
dest = lhs

return quote
$source => $dest
s = get_column_expr(t)
if s === nothing
throw(ArgumentError("Malformed expression oh LHS in DataFramesMeta.jl macro"))
end
end

# $:y = :x
if is_column_expr(lhs) && rhs isa QuoteNode
source = rhs
dest = lhs.args[1]

return quote
$source => $dest
end
s
end

# $:y = $:x
if is_column_expr(lhs) && is_column_expr(rhs)
source = rhs.args[1]
dest = lhs.args[1]
return quote
$source => $dest
end
end

# :y = f(:x)
# :y = f($:x)
# :y = :x + 1
# :y = $:x + 1
source, fun = get_source_fun(rhs; exprflags = final_flags)
if lhs isa QuoteNode
rhs = MacroTools.unblock(ex.args[2])
rhs_col = get_column_expr(rhs)
if rhs_col !== nothing
src = rhs_col
dest = lhs
return quote
$source => $fun => $dest
end
return :($src => $dest)
end

# $:y = f(:x)
if is_column_expr(lhs)
dest = lhs.args[1]

return quote
$source => $fun => $dest
end
if is_macro_head(rhs, "@byrow") || is_macro_head(rhs, "@passmissing")
s = "In keyword argument inputs, `@byrow` and `@passmissing`" *
"must be on the left hand side. " *
"Did you write `y = @byrow f(:x)` instead of `@byrow y = f(:x)`?"
throw(ArgumentError(s))
end

throw(ArgumentError("This path should not be reached"))
dest = lhs
src, fun = get_source_fun(rhs; exprflags = final_flags)
return :($src => $fun => $dest)
end

fun_to_vec(ex::QuoteNode;
no_dest::Bool=false,
gensym_names::Bool=false,
Expand All @@ -376,21 +333,6 @@ function make_source_concrete(x::AbstractVector)
end
end

protect_replace_syms!(e, membernames) = e
function protect_replace_syms!(e::Expr, membernames)
if e.head == :quote
e
else
replace_syms!(e, membernames)
end
end

function replace_dotted!(e, membernames)
x_new = replace_syms!(e.args[1], membernames)
y_new = protect_replace_syms!(e.args[2], membernames)
Expr(:., x_new, y_new)
end

# Varargs entry point: bundle all positional arguments into one `:block`
# expression and forward it, together with `wrap_byrow`, to
# `create_args_vector` again.
# NOTE(review): presumably a method taking a single expression is defined
# elsewhere in this file and handles the block; confirm.
function create_args_vector(args...; wrap_byrow::Bool=false)
    block = Expr(:block, args...)
    return create_args_vector(block; wrap_byrow = wrap_byrow)
end
Expand Down
Loading

0 comments on commit 6ba85a7

Please sign in to comment.