Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Iterator APIs (via ResumableFunctions.jl) #121

Merged
merged 6 commits into from
Aug 3, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ResumableFunctions = "c5292f4c-5179-55e1-98c5-05642aab7184"
Shapefile = "8e980c4a-a4fe-5da2-b3a7-4b4b0353a2f4"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down
1 change: 1 addition & 0 deletions src/GerryChain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import Shapefile
import LibGEOS
import LibSpatialIndex
using Logging
using ResumableFunctions

export AbstractGraph,
BaseGraph,
Expand Down
79 changes: 69 additions & 10 deletions src/flip.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,17 @@ function update_partition!(
end

"""
flip_chain(graph::BaseGraph,
flip_chain_iter(graph::BaseGraph,
partition::Partition,
pop_constraint::PopulationConstraint,
cont_constraint::ContiguityConstraint,
num_steps::Int,
scores::Array{S, 1};
acceptance_fn::F=always_accept,
no_self_loops::Bool=false)::ChainScoreData where {F<:Function, S<:AbstractScore}
no_self_loops::Bool=false) where {F<:Function,S<:AbstractScore}

Runs a Markov Chain for `num_steps` steps using Flip proposals. Returns
a `ChainScoreData` object which can be queried to retrieve the values of
every score at each step of the chain.
an iterator of `(Partition, score_vals)`.

*Arguments:*
- graph: `BaseGraph`
Expand All @@ -149,8 +148,10 @@ every score at each step of the chain.
function is satisfied. BEWARE - this can create
infinite loops if the acceptance function is never
satisfied!
- progress_bar If this is true, a progress bar will be printed to stdout.
"""
function flip_chain(
function flip_chain_iter end # this is a workaround
InnovativeInventor marked this conversation as resolved.
Show resolved Hide resolved
@resumable function flip_chain_iter(
graph::BaseGraph,
partition::Partition,
pop_constraint::PopulationConstraint,
Expand All @@ -160,10 +161,7 @@ function flip_chain(
acceptance_fn::F = always_accept,
no_self_loops::Bool = false,
progress_bar = true,
)::ChainScoreData where {F<:Function,S<:AbstractScore}
first_scores = score_initial_partition(graph, partition, scores)
chain_scores = ChainScoreData(deepcopy(scores), [first_scores])

) where {F<:Function,S<:AbstractScore}
if progress_bar
iter = ProgressBar(1:num_steps)
else
Expand All @@ -186,10 +184,71 @@ function flip_chain(
end
end
score_vals = score_partition_from_proposal(graph, partition, proposal, scores)
push!(chain_scores.step_values, score_vals)
@yield partition, score_vals
step_completed = true
end
end
end

"""
flip_chain(graph::BaseGraph,
partition::Partition,
pop_constraint::PopulationConstraint,
cont_constraint::ContiguityConstraint,
num_steps::Int,
scores::Array{S, 1};
acceptance_fn::F=always_accept,
no_self_loops::Bool=false)::ChainScoreData where {F<:Function, S<:AbstractScore}

Runs a Markov Chain for `num_steps` steps using Flip proposals. Returns
a `ChainScoreData` object which can be queried to retrieve the values of
every score at each step of the chain.

*Arguments:*
- graph: `BaseGraph`
- partition: `Partition` with the plan information
- pop_constraint: `PopulationConstraint`
- cont_constraint: `ContiguityConstraint`
- num_steps: Number of steps to run the chain for
- scores: Array of `AbstractScore`s to capture at each step
- acceptance_fn: A function generating a probability in [0, 1]
representing the likelihood of accepting the
proposal
- no\\_self\\_loops: If this is true, then a failure to accept a new state
is not considered a self-loop; rather, the chain
simply generates new proposals until the acceptance
function is satisfied. BEWARE - this can create
infinite loops if the acceptance function is never
satisfied!
- progress_bar If this is true, a progress bar will be printed to stdout.
"""
function flip_chain(
graph::BaseGraph,
partition::Partition,
pop_constraint::PopulationConstraint,
cont_constraint::ContiguityConstraint,
num_steps::Int,
scores::Array{S,1};
acceptance_fn::F = always_accept,
no_self_loops::Bool = false,
progress_bar = true,
)::ChainScoreData where {F<:Function,S<:AbstractScore}
first_scores = score_initial_partition(graph, partition, scores)
chain_scores = ChainScoreData(deepcopy(scores), [first_scores])

for (_, score_vals) in flip_chain_iter(
graph,
partition,
pop_constraint,
cont_constraint,
num_steps,
scores;
acceptance_fn,
no_self_loops,
progress_bar,
)
push!(chain_scores.step_values, score_vals)
end

return chain_scores
end
89 changes: 78 additions & 11 deletions src/recom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,19 +249,18 @@ function update_partition!(
end

"""
recom_chain(graph::BaseGraph,
recom_chain_iter(graph::BaseGraph,
partition::Partition,
pop_constraint::PopulationConstraint,
num_steps::Int,
scores::Array{S, 1};
num_tries::Int=3,
acceptance_fn::F=always_accept,
rng::AbstractRNG=Random.default_rng(),
no_self_loops::Bool=false)::ChainScoreData where {F<:Function, S<:AbstractScore}
no_self_loops::Bool=false) where {F<:Function,S<:AbstractScore}

Runs a Markov Chain for `num_steps` steps using ReCom. Returns a `ChainScoreData`
object which can be queried to retrieve the values of every score at each
step of the chain.
Runs a Markov Chain for `num_steps` steps using ReCom. Returns an iterator
of `(Partition, score_vals)`.

*Arguments:*
- graph: `BaseGraph`
Expand All @@ -284,8 +283,10 @@ step of the chain.
function is satisfied. BEWARE - this can create
infinite loops if the acceptance function is never
satisfied!
- progress_bar If this is true, a progress bar will be printed to stdout.
"""
function recom_chain(
function recom_chain_iter end # this is a workaround
@resumable function recom_chain_iter(
graph::BaseGraph,
partition::Partition,
pop_constraint::PopulationConstraint,
Expand All @@ -296,10 +297,7 @@ function recom_chain(
rng::AbstractRNG = Random.default_rng(),
no_self_loops::Bool = false,
progress_bar = true,
)::ChainScoreData where {F<:Function,S<:AbstractScore}
first_scores = score_initial_partition(graph, partition, scores)
chain_scores = ChainScoreData(deepcopy(scores), [first_scores])

) where {F<:Function,S<:AbstractScore}
if progress_bar
iter = ProgressBar(1:num_steps)
else
Expand All @@ -322,10 +320,79 @@ function recom_chain(
end
end
score_vals = score_partition_from_proposal(graph, partition, proposal, scores)
push!(chain_scores.step_values, score_vals)
@yield partition, score_vals
step_completed = true
end
end
end

"""
recom_chain(graph::BaseGraph,
partition::Partition,
pop_constraint::PopulationConstraint,
num_steps::Int,
scores::Array{S, 1};
num_tries::Int=3,
acceptance_fn::F=always_accept,
rng::AbstractRNG=Random.default_rng(),
no_self_loops::Bool=false)::ChainScoreData where {F<:Function, S<:AbstractScore}

Runs a Markov Chain for `num_steps` steps using ReCom. Returns a `ChainScoreData`
object which can be queried to retrieve the values of every score at each
step of the chain.

*Arguments:*
- graph: `BaseGraph`
- partition: `Partition` with the plan information
- pop_constraint: `PopulationConstraint`
- num_steps: Number of steps to run the chain for
- scores: Array of `AbstractScore`s to capture at each step
- num_tries: num times to try getting a balanced cut from a subgraph
before giving up
- acceptance_fn: A function generating a probability in [0, 1]
representing the likelihood of accepting the
proposal. Should accept a `Partition` as input.
- rng: Random number generator. The user can pass in their
own; otherwise, we use the default RNG from Random. Must
implement the [AbstractRNG type](https://docs.julialang.org/en/v1/stdlib/Random/#Random.AbstractRNG)
(e.g. `Random.default_rng()` or `MersenneTwister(1234)`).
- no\\_self\\_loops: If this is true, then a failure to accept a new state
is not considered a self-loop; rather, the chain
simply generates new proposals until the acceptance
function is satisfied. BEWARE - this can create
infinite loops if the acceptance function is never
satisfied!
- progress_bar If this is true, a progress bar will be printed to stdout.
"""
function recom_chain(
graph::BaseGraph,
partition::Partition,
pop_constraint::PopulationConstraint,
num_steps::Int,
scores::Array{S,1};
num_tries::Int = 3,
acceptance_fn::F = always_accept,
rng::AbstractRNG = Random.default_rng(),
no_self_loops::Bool = false,
progress_bar = true,
)::ChainScoreData where {F<:Function,S<:AbstractScore}
first_scores = score_initial_partition(graph, partition, scores)
chain_scores = ChainScoreData(deepcopy(scores), [first_scores])

for (_, score_vals) in recom_chain_iter(
graph,
partition,
pop_constraint,
num_steps,
scores;
num_tries,
acceptance_fn,
rng,
no_self_loops,
progress_bar,
)
push!(chain_scores.step_values, score_vals)
end

return chain_scores
end