Skip to content

Commit

Permalink
Iterator APIs (via ResumableFunctions.jl) (#121)
Browse files Browse the repository at this point in the history
* Make recom and flip chains an iterable a la GerryChain Python

Although this is a breaking change, this brings GerryChain Julia
more in line with GerryChain Python and allows the user to save
and export state more easily. This API is different from the one
proposed in #94 in that it returns a Partition object and the
scores.

* Add recom_chain and flip_chain functions for backwards compat.

The iterators are now exposed in `recom_chain_iter` and
`flip_chain_iter` respectively to avoid any breaking changes.

* Update type signatures of the iterator APIs

* Point to issue on GitHub with the bug in ResumableFunctions.jl

* Export iter API

* Add note about `Partition` being mutable and changing in-place
  • Loading branch information
InnovativeInventor authored Aug 3, 2021
1 parent 6893a4f commit bee809d
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 21 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ResumableFunctions = "c5292f4c-5179-55e1-98c5-05642aab7184"
Shapefile = "8e980c4a-a4fe-5da2-b3a7-4b4b0353a2f4"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down
3 changes: 3 additions & 0 deletions src/GerryChain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import Shapefile
import LibGEOS
import LibSpatialIndex
using Logging
using ResumableFunctions

export AbstractGraph,
BaseGraph,
Expand Down Expand Up @@ -42,9 +43,11 @@ export AbstractGraph,
# recom
update_partition!,
recom_chain,
recom_chain_iter,

# flip
flip_chain,
flip_chain_iter,

# scores
DistrictAggregate,
Expand Down
81 changes: 71 additions & 10 deletions src/flip.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,19 @@ function update_partition!(
end

"""
flip_chain(graph::BaseGraph,
flip_chain_iter(graph::BaseGraph,
partition::Partition,
pop_constraint::PopulationConstraint,
cont_constraint::ContiguityConstraint,
num_steps::Int,
scores::Array{S, 1};
acceptance_fn::F=always_accept,
no_self_loops::Bool=false)::ChainScoreData where {F<:Function, S<:AbstractScore}
no_self_loops::Bool=false) where {F<:Function,S<:AbstractScore}
Runs a Markov Chain for `num_steps` steps using Flip proposals. Returns
a `ChainScoreData` object which can be queried to retrieve the values of
every score at each step of the chain.
an iterator of `(Partition, score_vals)`. Note that `Partition` is mutable and
will change in-place with each iteration -- a `deepcopy()` is needed if you wish
to interact with the `Partition` object outside of the for loop.
*Arguments:*
- graph: `BaseGraph`
Expand All @@ -149,8 +150,10 @@ every score at each step of the chain.
function is satisfied. BEWARE - this can create
infinite loops if the acceptance function is never
satisfied!
- progress_bar If this is true, a progress bar will be printed to stdout.
"""
function flip_chain(
function flip_chain_iter end # this is a workaround (https://github.com/BenLauwens/ResumableFunctions.jl/issues/45)
@resumable function flip_chain_iter(
graph::BaseGraph,
partition::Partition,
pop_constraint::PopulationConstraint,
Expand All @@ -160,10 +163,7 @@ function flip_chain(
acceptance_fn::F = always_accept,
no_self_loops::Bool = false,
progress_bar = true,
)::ChainScoreData where {F<:Function,S<:AbstractScore}
first_scores = score_initial_partition(graph, partition, scores)
chain_scores = ChainScoreData(deepcopy(scores), [first_scores])

) where {F<:Function,S<:AbstractScore}
if progress_bar
iter = ProgressBar(1:num_steps)
else
Expand All @@ -186,10 +186,71 @@ function flip_chain(
end
end
score_vals = score_partition_from_proposal(graph, partition, proposal, scores)
push!(chain_scores.step_values, score_vals)
@yield partition, score_vals
step_completed = true
end
end
end

"""
flip_chain(graph::BaseGraph,
partition::Partition,
pop_constraint::PopulationConstraint,
cont_constraint::ContiguityConstraint,
num_steps::Int,
scores::Array{S, 1};
acceptance_fn::F=always_accept,
no_self_loops::Bool=false)::ChainScoreData where {F<:Function, S<:AbstractScore}
Runs a Markov Chain for `num_steps` steps using Flip proposals. Returns
a `ChainScoreData` object which can be queried to retrieve the values of
every score at each step of the chain.
*Arguments:*
- graph: `BaseGraph`
- partition: `Partition` with the plan information
- pop_constraint: `PopulationConstraint`
- cont_constraint: `ContiguityConstraint`
- num_steps: Number of steps to run the chain for
- scores: Array of `AbstractScore`s to capture at each step
- acceptance_fn: A function generating a probability in [0, 1]
representing the likelihood of accepting the
proposal
- no\\_self\\_loops: If this is true, then a failure to accept a new state
is not considered a self-loop; rather, the chain
simply generates new proposals until the acceptance
function is satisfied. BEWARE - this can create
infinite loops if the acceptance function is never
satisfied!
- progress_bar If this is true, a progress bar will be printed to stdout.
"""
function flip_chain(
graph::BaseGraph,
partition::Partition,
pop_constraint::PopulationConstraint,
cont_constraint::ContiguityConstraint,
num_steps::Int,
scores::Array{S,1};
acceptance_fn::F = always_accept,
no_self_loops::Bool = false,
progress_bar = true,
)::ChainScoreData where {F<:Function,S<:AbstractScore}
first_scores = score_initial_partition(graph, partition, scores)
chain_scores = ChainScoreData(deepcopy(scores), [first_scores])

for (_, score_vals) in flip_chain_iter(
graph,
partition,
pop_constraint,
cont_constraint,
num_steps,
scores;
acceptance_fn,
no_self_loops,
progress_bar,
)
push!(chain_scores.step_values, score_vals)
end

return chain_scores
end
91 changes: 80 additions & 11 deletions src/recom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,19 +249,20 @@ function update_partition!(
end

"""
recom_chain(graph::BaseGraph,
recom_chain_iter(graph::BaseGraph,
partition::Partition,
pop_constraint::PopulationConstraint,
num_steps::Int,
scores::Array{S, 1};
num_tries::Int=3,
acceptance_fn::F=always_accept,
rng::AbstractRNG=Random.default_rng(),
no_self_loops::Bool=false)::ChainScoreData where {F<:Function, S<:AbstractScore}
no_self_loops::Bool=false) where {F<:Function,S<:AbstractScore}
Runs a Markov Chain for `num_steps` steps using ReCom. Returns a `ChainScoreData`
object which can be queried to retrieve the values of every score at each
step of the chain.
Runs a Markov Chain for `num_steps` steps using ReCom. Returns an iterator
of `(Partition, score_vals)`. Note that `Partition` is mutable and will change
in-place with each iteration -- a `deepcopy()` is needed if you wish to interact
with the `Partition` object outside of the for loop.
*Arguments:*
- graph: `BaseGraph`
Expand All @@ -284,8 +285,10 @@ step of the chain.
function is satisfied. BEWARE - this can create
infinite loops if the acceptance function is never
satisfied!
- progress_bar If this is true, a progress bar will be printed to stdout.
"""
function recom_chain(
function recom_chain_iter end # this is a workaround (https://github.com/BenLauwens/ResumableFunctions.jl/issues/45)
@resumable function recom_chain_iter(
graph::BaseGraph,
partition::Partition,
pop_constraint::PopulationConstraint,
Expand All @@ -296,10 +299,7 @@ function recom_chain(
rng::AbstractRNG = Random.default_rng(),
no_self_loops::Bool = false,
progress_bar = true,
)::ChainScoreData where {F<:Function,S<:AbstractScore}
first_scores = score_initial_partition(graph, partition, scores)
chain_scores = ChainScoreData(deepcopy(scores), [first_scores])

) where {F<:Function,S<:AbstractScore}
if progress_bar
iter = ProgressBar(1:num_steps)
else
Expand All @@ -322,10 +322,79 @@ function recom_chain(
end
end
score_vals = score_partition_from_proposal(graph, partition, proposal, scores)
push!(chain_scores.step_values, score_vals)
@yield partition, score_vals
step_completed = true
end
end
end

"""
recom_chain(graph::BaseGraph,
partition::Partition,
pop_constraint::PopulationConstraint,
num_steps::Int,
scores::Array{S, 1};
num_tries::Int=3,
acceptance_fn::F=always_accept,
rng::AbstractRNG=Random.default_rng(),
no_self_loops::Bool=false)::ChainScoreData where {F<:Function, S<:AbstractScore}
Runs a Markov Chain for `num_steps` steps using ReCom. Returns a `ChainScoreData`
object which can be queried to retrieve the values of every score at each
step of the chain.
*Arguments:*
- graph: `BaseGraph`
- partition: `Partition` with the plan information
- pop_constraint: `PopulationConstraint`
- num_steps: Number of steps to run the chain for
- scores: Array of `AbstractScore`s to capture at each step
- num_tries: num times to try getting a balanced cut from a subgraph
before giving up
- acceptance_fn: A function generating a probability in [0, 1]
representing the likelihood of accepting the
proposal. Should accept a `Partition` as input.
- rng: Random number generator. The user can pass in their
own; otherwise, we use the default RNG from Random. Must
implement the [AbstractRNG type](https://docs.julialang.org/en/v1/stdlib/Random/#Random.AbstractRNG)
(e.g. `Random.default_rng()` or `MersenneTwister(1234)`).
- no\\_self\\_loops: If this is true, then a failure to accept a new state
is not considered a self-loop; rather, the chain
simply generates new proposals until the acceptance
function is satisfied. BEWARE - this can create
infinite loops if the acceptance function is never
satisfied!
- progress_bar If this is true, a progress bar will be printed to stdout.
"""
function recom_chain(
graph::BaseGraph,
partition::Partition,
pop_constraint::PopulationConstraint,
num_steps::Int,
scores::Array{S,1};
num_tries::Int = 3,
acceptance_fn::F = always_accept,
rng::AbstractRNG = Random.default_rng(),
no_self_loops::Bool = false,
progress_bar = true,
)::ChainScoreData where {F<:Function,S<:AbstractScore}
first_scores = score_initial_partition(graph, partition, scores)
chain_scores = ChainScoreData(deepcopy(scores), [first_scores])

for (_, score_vals) in recom_chain_iter(
graph,
partition,
pop_constraint,
num_steps,
scores;
num_tries,
acceptance_fn,
rng,
no_self_loops,
progress_bar,
)
push!(chain_scores.step_values, score_vals)
end

return chain_scores
end

0 comments on commit bee809d

Please sign in to comment.