Skip to content

Commit 3a7fa95

Browse files
committed
Fix for iterators with shape
1 parent c771e5d commit 3a7fa95

File tree

7 files changed

+52
-12
lines changed

7 files changed

+52
-12
lines changed

src/Containers/Containers.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ export DenseAxisArray, SparseAxisArray
2626
include("DenseAxisArray.jl")
2727
include("SparseAxisArray.jl")
2828
include("generate_container.jl")
29+
include("vectorized_product_iterator.jl")
30+
include("nested_iterator.jl")
2931
include("container.jl")
3032
include("macro.jl")
3133

src/Containers/SparseAxisArray.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
# License, v. 2.0. If a copy of the MPL was not distributed with this
44
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
55

6-
include("nested_iterator.jl")
7-
86
"""
97
struct SparseAxisArray{T,N,K<:NTuple{N, Any}} <: AbstractArray{T,N}
108
data::Dict{K,T}

src/Containers/container.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
const ArrayIndices{N} = Iterators.ProductIterator{NTuple{N, Base.OneTo{Int}}}
1+
const ArrayIndices{N} = VectorizedProductIterator{NTuple{N, Base.OneTo{Int}}}
22
container(f::Function, indices) = container(f, indices, default_container(indices))
33
default_container(::ArrayIndices) = Array
44
function container(f::Function, indices::ArrayIndices, ::Type{Array})
@@ -10,14 +10,14 @@ function _oneto(indices)
1010
end
1111
error("Index set for array is not one-based interval.")
1212
end
13-
function container(f::Function, indices::Iterators.ProductIterator,
13+
function container(f::Function, indices::VectorizedProductIterator,
1414
::Type{Array})
15-
container(f, Iterators.ProductIterator(_oneto.(indices.iterators)), Array)
15+
container(f, vectorized_product(_oneto.(indices.prod.iterators)...), Array)
1616
end
17-
default_container(::Iterators.ProductIterator) = DenseAxisArray
18-
function container(f::Function, indices::Iterators.ProductIterator,
17+
default_container(::VectorizedProductIterator) = DenseAxisArray
18+
function container(f::Function, indices::VectorizedProductIterator,
1919
::Type{DenseAxisArray})
20-
return DenseAxisArray(map(I -> f(I...), indices), indices.iterators...)
20+
return DenseAxisArray(map(I -> f(I...), indices), indices.prod.iterators...)
2121
end
2222
default_container(::NestedIterator) = SparseAxisArray
2323
function container(f::Function, indices,

src/Containers/macro.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,13 @@ function _build_ref_sets(_error::Function, expr)
135135
esc_idxvars = esc.(idxvars)
136136
idxfuns = [:(($(esc_idxvars[1:(i - 1)]...),) -> $(idxsets[i])) for i in 1:length(idxvars)]
137137
if condition == :()
138-
indices = :(Containers.NestedIterator(($(idxfuns...),)))
138+
indices = :(Containers.nested($(idxfuns...)))
139139
else
140140
condition_fun = :(($(esc_idxvars...),) -> $(esc(condition)))
141-
indices = :(Containers.NestedIterator(($(idxfuns...),), $condition_fun))
141+
indices = :(Containers.nested($(idxfuns...); condition = $condition_fun))
142142
end
143143
else
144-
indices = :(Base.Iterators.product(($(_explicit_oneto.(idxsets)...))))
144+
indices = :(Containers.vectorized_product($(_explicit_oneto.(idxsets)...)))
145145
end
146146
return idxvars, indices
147147
end

src/Containers/nested_iterator.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ struct NestedIterator{T}
2323
iterators::T # Tuple of functions
2424
condition::Function
2525
end
26-
NestedIterator(iterator) = NestedIterator(iterator, (args...) -> true)
26+
function nested(iterators...; condition = (args...) -> true)
27+
return NestedIterator(iterators, condition)
28+
end
2729
Base.IteratorSize(::Type{<:NestedIterator}) = Base.SizeUnknown()
2830
Base.IteratorEltype(::Type{<:NestedIterator}) = Base.EltypeUnknown()
2931
function next_iterate(it::NestedIterator, i, elems, states, iterator, elem_state)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""
2+
struct VectorizedProductIterator{T}
3+
prod::Iterators.ProductIterator{T}
4+
end
5+
6+
Same as `Base.Iterators.ProductIterator` except that it is independent
7+
on the `IteratorSize` of the elements of `prod.iterators`.
8+
For instance:
9+
* `size(Iterators.product(1, 2))` is `tuple()` while
10+
`size(VectorizedProductIterator(1, 2))` is `(1, 1)`.
11+
* `size(Iterators.product(ones(2, 3)))` is `(2, 3)` while
12+
`size(VectorizedProductIterator(ones(2, 3)))` is `(1, 1)`.
13+
"""
14+
struct VectorizedProductIterator{T}
15+
prod::Iterators.ProductIterator{T}
16+
end
17+
function vectorized_product(iterators...)
18+
return VectorizedProductIterator(Iterators.product(iterators...))
19+
end
20+
function Base.IteratorSize(::Type{<:VectorizedProductIterator{<:Tuple{Vararg{Any, N}}}}) where N
21+
return Base.HasShape{N}()
22+
end
23+
Base.IteratorEltype(::Type{<:VectorizedProductIterator}) = Base.EltypeUnknown()
24+
Base.size(it::VectorizedProductIterator) = _prod_size(it.prod.iterators)
25+
_prod_size(::Tuple{}) = ()
26+
_prod_size(t::Tuple) = (length(t[1]), _prod_size(Base.tail(t))...)
27+
Base.axes(it::VectorizedProductIterator) = _prod_indices(it.prod.iterators)
28+
_prod_indices(::Tuple{}) = ()
29+
_prod_indices(t::Tuple) = (Base.OneTo(length(t[1])), _prod_indices(Base.tail(t))...)
30+
Base.ndims(it::VectorizedProductIterator) = length(axes(it))
31+
Base.length(it::VectorizedProductIterator) = prod(size(it))
32+
Base.iterate(it::VectorizedProductIterator, args...) = iterate(it.prod, args...)

test/Containers/macro.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ using JuMP.Containers
1414
@test x isa Containers.DenseAxisArray{Int, 1}
1515
Containers.@container(x[i = 2:3, j = 1:2], i + j)
1616
@test x isa Containers.DenseAxisArray{Int, 2}
17+
Containers.@container(x[4], 0.0)
18+
@test x isa Containers.DenseAxisArray{Float64, 1}
19+
Containers.@container(x[4, 5], 0)
20+
@test x isa Containers.DenseAxisArray{Int, 2}
21+
Containers.@container(x[4, 1:3, 5], 0)
22+
@test x isa Containers.DenseAxisArray{Int, 3}
1723
end
1824
@testset "SparseAxisArray" begin
1925
Containers.@container(x[i = 1:3, j = 1:i], i + j)

0 commit comments

Comments
 (0)