diff --git a/src/Unrolled.jl b/src/Unrolled.jl index 2f9afc0..1c62ca9 100644 --- a/src/Unrolled.jl +++ b/src/Unrolled.jl @@ -14,6 +14,12 @@ function unrolled_filter end for sequence-like types, obviously. """ function type_length end +""" +`type_size(::Type, dim)` returns the size of an array in the specified dimension. +""" +function type_size end + + include("range.jl") const expansion_funs = Dict{Function, Function}() @@ -52,6 +58,7 @@ end type_length(tup::Type{T}) where {T<:Tuple} = length(tup.parameters) # Default fall-back type_length(typ::Type) = length(typ) +type_size(typ::Type, i) = size(typ, i) """ `function_argument_name(arg_expr)` @@ -80,6 +87,13 @@ macro unroll(fundef) "Can only unroll a loop over one of the function's arguments") return Expr(:($), Expr(:call, :($Unrolled.type_length), seq_var)) end + function seq_type_size(seq_var, dim) + @assert(seq_var in all_args, + "Can only unroll a loop over one of the function's arguments") + @assert(dim isa Integer, + "Dimension argument must be an integer") + return Expr(:($), Expr(:call, :($Unrolled.type_size), seq_var, dim)) + end process(x) = x function process(expr::Expr) if expr.args[1]==Symbol("@unroll") @@ -88,6 +102,9 @@ macro unroll(fundef) for var_ in 1:length(seq_) loopbody__ end => :($Unrolled.@unroll_loop(for $var in 1:$(seq_type_length(seq)); $(loopbody...) end)) + for var_ in 1:size(seq_, dim_) loopbody__ end => + :($Unrolled.@unroll_loop(for $var in 1:$(seq_type_size(seq, dim)); + $(loopbody...) end)) for var_ in seq_ loopbody__ end => :($Unrolled.@unroll_loop($(seq_type(seq)), for $var in $seq; $(loopbody...) end)) diff --git a/test/runtests.jl b/test/runtests.jl index 5effcd1..bfd09f9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -106,3 +106,18 @@ end end return log_scale ? exp.(state) : state end + +# Unrolling with a sized argument +struct CartesianIndexSpace{dims} <: AbstractArray{Int, 2}; end +Base.size(cis::Type{CartesianIndexSpace{dims}}) where {dims} = dims +Base.size(cis::Type{<:CartesianIndexSpace}, i::Integer) = size(cis)[i] +Base.size(cis::CartesianIndexSpace) = size(typeof(cis)) + +@unroll function do_count(cis) + n = 0 + @unroll for i = 1:size(cis, 2) + n += 1 + end + n +end +@test do_count(CartesianIndexSpace{(1,4)}()) == 4