Skip to content

Commit

Permalink
Merge pull request #5 from Keno/kf/size
Browse files Browse the repository at this point in the history
Allow unrolling on size(x, dim)
  • Loading branch information
cstjean authored Feb 25, 2019
2 parents d5bde5c + 44a538a commit 37df04f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/Unrolled.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ function unrolled_filter end
for sequence-like types, obviously. """
function type_length end

"""
`type_size(::Type, dim)` returns the size of an array in the specified dimension.
"""
function type_size end


include("range.jl")

const expansion_funs = Dict{Function, Function}()
Expand Down Expand Up @@ -52,6 +58,7 @@ end
type_length(tup::Type{T}) where {T<:Tuple} = length(tup.parameters)
# Default fall-back
type_length(typ::Type) = length(typ)
type_size(typ::Type, i) = size(typ, i)

""" `function_argument_name(arg_expr)`
Expand Down Expand Up @@ -80,6 +87,13 @@ macro unroll(fundef)
"Can only unroll a loop over one of the function's arguments")
return Expr(:($), Expr(:call, :($Unrolled.type_length), seq_var))
end
function seq_type_size(seq_var, dim)
@assert(seq_var in all_args,
"Can only unroll a loop over one of the function's arguments")
@assert(dim isa Integer,
"Dimension argument must be an integer")
return Expr(:($), Expr(:call, :($Unrolled.type_size), seq_var, dim))
end
process(x) = x
function process(expr::Expr)
if expr.args[1]==Symbol("@unroll")
Expand All @@ -88,6 +102,9 @@ macro unroll(fundef)
for var_ in 1:length(seq_) loopbody__ end =>
:($Unrolled.@unroll_loop(for $var in 1:$(seq_type_length(seq));
$(loopbody...) end))
for var_ in 1:size(seq_, dim_) loopbody__ end =>
:($Unrolled.@unroll_loop(for $var in 1:$(seq_type_size(seq, dim));
$(loopbody...) end))
for var_ in seq_ loopbody__ end =>
:($Unrolled.@unroll_loop($(seq_type(seq)),
for $var in $seq; $(loopbody...) end))
Expand Down
15 changes: 15 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,18 @@ end
end
return log_scale ? exp.(state) : state
end

# Unrolling with a sized argument
struct CartesianIndexSpace{dims} <: AbstractArray{Int, 2}; end
Base.size(cis::Type{CartesianIndexSpace{dims}}) where {dims} = dims
Base.size(cis::Type{<:CartesianIndexSpace}, i::Integer) = size(cis)[i]
Base.size(cis::CartesianIndexSpace) = size(typeof(cis))

@unroll function do_count(cis)
n = 0
@unroll for i = 1:size(cis, 2)
n += 1
end
n
end
@test do_count(CartesianIndexSpace{(1,4)}()) == 4

0 comments on commit 37df04f

Please sign in to comment.