Skip to content

Commit 05f4242

Browse files
committed
Create containers with map instead of for loops
1 parent 8ce7535 commit 05f4242

19 files changed

+536
-418
lines changed

src/Containers/Containers.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,7 @@ export DenseAxisArray, SparseAxisArray
2626
include("DenseAxisArray.jl")
2727
include("SparseAxisArray.jl")
2828
include("generate_container.jl")
29+
include("container.jl")
30+
include("macro.jl")
2931

3032
end

src/Containers/SparseAxisArray.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# License, v. 2.0. If a copy of the MPL was not distributed with this
44
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
55

6+
include("nested_iterator.jl")
7+
68
"""
79
struct SparseAxisArray{T,N,K<:NTuple{N, Any}} <: AbstractArray{T,N}
810
data::Dict{K,T}

src/Containers/container.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
const ArrayIndices{N} = Base.Iterators.ProductIterator{NTuple{N, Base.OneTo{Int}}}
2+
container(f::Function, indices) = container(f, indices, default_container(indices))
3+
default_container(::ArrayIndices) = Array
4+
function container(f::Function, indices::ArrayIndices, ::Type{Array})
5+
return map(I -> f(I...), indices)
6+
end
7+
default_container(::Base.Iterators.ProductIterator) = DenseAxisArray
8+
function container(f::Function, indices::Base.Iterators.ProductIterator,
9+
::Type{DenseAxisArray})
10+
return DenseAxisArray(map(I -> f(I...), indices), indices.iterators...)
11+
end
12+
default_container(::NestedIterator) = SparseAxisArray
13+
function container(f::Function, indices,
14+
::Type{SparseAxisArray})
15+
mappings = map(I -> I => f(I...), indices)
16+
data = Dict(mappings)
17+
if length(mappings) != length(data)
18+
unique_indices = Set()
19+
duplicate = nothing
20+
for index in indices
21+
if index in unique_indices
22+
duplicate = index
23+
break
24+
end
25+
push!(unique_indices, index)
26+
end
27+
# TODO compute idx
28+
error("Repeated index ", duplicate, ". Index sets must have unique elements.")
29+
end
30+
return SparseAxisArray(Dict(data))
31+
end

src/Containers/macro.jl

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
using Base.Meta
2+
3+
"""
4+
_extract_kw_args(args)
5+
6+
Process the arguments to a macro, separating out the keyword arguments.
7+
Return a tuple of (flat_arguments, keyword arguments, and requested_container),
8+
where `requested_container` is a symbol to be passed to `parse_container`.
9+
"""
10+
function _extract_kw_args(args)
11+
kw_args = filter(x -> isexpr(x, :(=)) && x.args[1] != :container , collect(args))
12+
flat_args = filter(x->!isexpr(x, :(=)), collect(args))
13+
requested_container = :Auto
14+
for kw in args
15+
if isexpr(kw, :(=)) && kw.args[1] == :container
16+
requested_container = kw.args[2]
17+
end
18+
end
19+
return flat_args, kw_args, requested_container
20+
end
21+
22+
function _try_parse_idx_set(arg::Expr)
23+
# [i=1] and x[i=1] parse as Expr(:vect, Expr(:(=), :i, 1)) and
24+
# Expr(:ref, :x, Expr(:kw, :i, 1)) respectively.
25+
if arg.head === :kw || arg.head === :(=)
26+
@assert length(arg.args) == 2
27+
return true, arg.args[1], arg.args[2]
28+
elseif isexpr(arg, :call) && arg.args[1] === :in
29+
return true, arg.args[2], arg.args[3]
30+
else
31+
return false, nothing, nothing
32+
end
33+
end
34+
function _explicit_oneto(index_set)
35+
s = Meta.isexpr(index_set,:escape) ? index_set.args[1] : index_set
36+
if Meta.isexpr(s,:call) && length(s.args) == 3 && s.args[1] == :(:) && s.args[2] == 1
37+
return :(Base.OneTo($index_set))
38+
else
39+
return index_set
40+
end
41+
end
42+
43+
function _expr_is_splat(ex::Expr)
44+
if ex.head == :(...)
45+
return true
46+
elseif ex.head == :escape
47+
return _expr_is_splat(ex.args[1])
48+
end
49+
return false
50+
end
51+
_expr_is_splat(::Any) = false
52+
53+
"""
54+
_parse_ref_sets(expr::Expr)
55+
56+
Helper function for macros to construct container objects. Takes an `Expr` that
57+
specifies the container, e.g. `:(x[i=1:3,[:red,:blue]],k=S; i+k <= 6)`, and
58+
returns:
59+
60+
1. `idxvars`: Names for the index variables, e.g. `[:i, gensym(), :k]`
61+
2. `idxsets`: Sets used for indexing, e.g. `[1:3, [:red,:blue], S]`
62+
3. `condition`: Expr containing any conditional imposed on indexing, or `:()` if none is present
63+
"""
64+
function _parse_ref_sets(_error::Function, expr::Expr)
65+
c = copy(expr)
66+
idxvars = Any[]
67+
idxsets = Any[]
68+
# On 0.7, :(t[i;j]) is a :ref, while t[i,j;j] is a :typed_vcat.
69+
# In both cases :t is the first arg.
70+
if isexpr(c, :typed_vcat) || isexpr(c, :ref)
71+
popfirst!(c.args)
72+
end
73+
condition = :()
74+
if isexpr(c, :vcat) || isexpr(c, :typed_vcat)
75+
# Parameters appear as plain args at the end.
76+
if length(c.args) > 2
77+
_error("Unsupported syntax $c.")
78+
elseif length(c.args) == 2
79+
condition = pop!(c.args)
80+
end # else no condition.
81+
elseif isexpr(c, :ref) || isexpr(c, :vect)
82+
# Parameters appear at the front.
83+
if isexpr(c.args[1], :parameters)
84+
if length(c.args[1].args) != 1
85+
_error("Invalid syntax: $c. Multiple semicolons are not " *
86+
"supported.")
87+
end
88+
condition = popfirst!(c.args).args[1]
89+
end
90+
end
91+
if isexpr(c, :vcat) || isexpr(c, :typed_vcat) || isexpr(c, :ref)
92+
if isexpr(c.args[1], :parameters)
93+
@assert length(c.args[1].args) == 1
94+
condition = popfirst!(c.args).args[1]
95+
end # else no condition.
96+
end
97+
98+
for s in c.args
99+
parse_done = false
100+
if isa(s, Expr)
101+
parse_done, idxvar, _idxset = _try_parse_idx_set(s::Expr)
102+
if parse_done
103+
idxset = esc(_idxset)
104+
end
105+
end
106+
if !parse_done # No index variable specified
107+
idxvar = gensym()
108+
idxset = esc(s)
109+
end
110+
push!(idxvars, idxvar)
111+
push!(idxsets, idxset)
112+
end
113+
return idxvars, idxsets, condition
114+
end
115+
_parse_ref_sets(_error::Function, expr) = (Any[], Any[], :())
116+
117+
"""
118+
_build_ref_sets(_error::Function, expr)
119+
120+
Helper function for macros to construct container objects. Takes an `Expr` that
121+
specifies the container, e.g. `:(x[i=1:3,[:red,:blue]],k=S; i+k <= 6)`, and
122+
returns:
123+
124+
1. `idxvars`: Names for the index variables, e.g. `[:i, gensym(), :k]`
125+
2. `idxsets`: Sets used for indexing, e.g. `[1:3, [:red,:blue], S]`
126+
3. `condition`: Expr containing any conditional imposed on indexing, or `:()` if none is present
127+
"""
128+
function _build_ref_sets(_error::Function, expr)
129+
idxvars, idxsets, condition = _parse_ref_sets(_error, expr)
130+
if any(_expr_is_splat.(idxsets))
131+
_error("cannot use splatting operator `...` in the definition of an index set.")
132+
end
133+
has_dependent = has_dependent_sets(idxvars, idxsets)
134+
if has_dependent || condition != :()
135+
esc_idxvars = esc.(idxvars)
136+
idxfuns = [:(($(esc_idxvars[1:(i - 1)]...),) -> $(idxsets[i])) for i in 1:length(idxvars)]
137+
if condition == :()
138+
indices = :(Containers.NestedIterator(($(idxfuns...),)))
139+
else
140+
condition_fun = :(($(esc_idxvars...),) -> $(esc(condition)))
141+
indices = :(Containers.NestedIterator(($(idxfuns...),), $condition_fun))
142+
end
143+
else
144+
indices = :(Base.Iterators.product(($(_explicit_oneto.(idxsets)...))))
145+
end
146+
return idxvars, indices
147+
end
148+
149+
function container_code(idxvars, indices, code, requested_container)
150+
if isempty(idxvars)
151+
return code
152+
end
153+
if !(requested_container in [:Auto, :Array, :DenseAxisArray, :SparseAxisArray])
154+
# We do this two-step interpolation, first into the string, and then
155+
# into the expression because interpolating into a string inside an
156+
# expression has scoping issues.
157+
error_message = "Invalid container type $requested_container. Must be " *
158+
"Auto, Array, DenseAxisArray, or SparseAxisArray."
159+
return :(error($error_message))
160+
end
161+
if requested_container == :DenseAxisArray
162+
requested_container = :(JuMP.Containers.DenseAxisArray)
163+
elseif requested_container == :SparseAxisArray
164+
requested_container = :(JuMP.Containers.SparseAxisArray)
165+
end
166+
esc_idxvars = esc.(idxvars)
167+
func = :(($(esc_idxvars...),) -> $code)
168+
if requested_container == :Auto
169+
return :(Containers.container($func, $indices))
170+
else
171+
return :(Containers.container($func, $indices, $requested_container))
172+
end
173+
end
174+
function parse_container(_error, var, value, requested_container)
175+
idxvars, indices = _build_ref_sets(_error, var)
176+
return container_code(idxvars, indices, value, requested_container)
177+
end
178+
179+
macro container(args...)
180+
args, kw_args, requested_container = _extract_kw_args(args)
181+
@assert length(args) == 2
182+
@assert isempty(kw_args)
183+
var, value = args
184+
name = var.args[1]
185+
code = parse_container(error, var, esc(value), requested_container)
186+
return :($(esc(name)) = $code)
187+
end

src/Containers/nested_iterator.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
struct NestedIterator{T}
2+
iterators::T # Tuple of functions
3+
condition::Function
4+
end
5+
NestedIterator(iterator) = NestedIterator(iterator, (args...) -> true)
6+
Base.IteratorSize(::Type{<:NestedIterator}) = Base.SizeUnknown()
7+
Base.IteratorEltype(::Type{<:NestedIterator}) = Base.EltypeUnknown()
8+
function next_iterate(it::NestedIterator, i, elems, states, iterator, elem_state)
9+
if elem_state === nothing
10+
return nothing
11+
end
12+
elem, state = elem_state
13+
elems_states = first_iterate(
14+
it, i + 1, (elems..., elem),
15+
(states..., (iterator, state, elem)))
16+
if elems_states !== nothing
17+
return elems_states
18+
end
19+
return next_iterate(it, i, elems, states, iterator, iterate(iterator, state))
20+
end
21+
function first_iterate(it::NestedIterator, i, elems, states)
22+
if i > length(it.iterators)
23+
if it.condition(elems...)
24+
return elems, states
25+
else
26+
return nothing
27+
end
28+
end
29+
iterator = it.iterators[i](elems...)
30+
return next_iterate(it, i, elems, states, iterator, iterate(iterator))
31+
end
32+
function tail_iterate(it::NestedIterator, i, elems, states)
33+
if i > length(it.iterators)
34+
return nothing
35+
end
36+
next = tail_iterate(it, i + 1, (elems..., states[i][3]), states)
37+
if next !== nothing
38+
return next
39+
end
40+
iterator = states[i][1]
41+
next_iterate(it, i, elems, states[1:(i - 1)], iterator, iterate(iterator, states[i][2]))
42+
end
43+
Base.iterate(it::NestedIterator) = first_iterate(it, 1, tuple(), tuple())
44+
Base.iterate(it::NestedIterator, states) = tail_iterate(it, 1, tuple(), states)

src/JuMP.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ import ForwardDiff
2323
include("_Derivatives/_Derivatives.jl")
2424
using ._Derivatives
2525

26+
include("Containers/Containers.jl")
27+
2628
# Exports are at the end of the file.
2729

2830
# Deprecations for JuMP v0.18 -> JuMP v0.19 transition
@@ -460,7 +462,7 @@ end
460462
"""
461463
set_time_limit_sec(model::Model, limit)
462464
463-
Sets the time limit (in seconds) of the solver.
465+
Sets the time limit (in seconds) of the solver.
464466
Can be unset using `unset_time_limit_sec` or with `limit` set to `nothing`.
465467
"""
466468
function set_time_limit_sec(model::Model, limit)
@@ -768,7 +770,6 @@ struct NonlinearParameter <: AbstractJuMPScalar
768770
end
769771

770772
include("copy.jl")
771-
include("Containers/Containers.jl")
772773
include("operators.jl")
773774
include("macros.jl")
774775
include("optimizer_interface.jl")

0 commit comments

Comments
 (0)