Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for user-supplied RNG state in all interfaces #520

Open
wants to merge 9 commits into
base: modular-rng
Choose a base branch
from
2 changes: 2 additions & 0 deletions src/Gen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

module Gen

using Random: AbstractRNG, default_rng

"""
load_generated_functions(__module__=Main)

Expand Down
15 changes: 9 additions & 6 deletions src/dynamic/dynamic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,15 @@ end

accepts_output_grad(gen_fn::DynamicDSLFunction) = gen_fn.accepts_output_grad

mutable struct GFUntracedState
mutable struct GFUntracedState{R<:AbstractRNG}
params::Dict{Symbol,Any}
rng::R
end

function (gen_fn::DynamicDSLFunction)(args...)
state = GFUntracedState(gen_fn.params)
(gen_fn::DynamicDSLFunction)(args...) = gen_fn(default_rng(), args...)

function (gen_fn::DynamicDSLFunction)(rng::AbstractRNG, args...)
state = GFUntracedState(gen_fn.params, rng)
gen_fn.julia_function(state, args...)
end

Expand Down Expand Up @@ -82,13 +85,13 @@ end

# Defaults for untraced execution
@inline traceat(state::GFUntracedState, gen_fn::GenerativeFunction, args, key) =
gen_fn(args...)
gen_fn(state.rng, args...)

@inline traceat(state::GFUntracedState, dist::Distribution, args, key) =
random(dist, args...)
random(state.rng, dist, args...)

@inline splice(state::GFUntracedState, gen_fn::DynamicDSLFunction, args::Tuple) =
gen_fn(args...)
gen_fn(state.rng, args...)

########################
# trainable parameters #
Expand Down
20 changes: 12 additions & 8 deletions src/dynamic/generate.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
mutable struct GFGenerateState
mutable struct GFGenerateState{R<:AbstractRNG}
trace::DynamicDSLTrace
constraints::ChoiceMap
weight::Float64
visitor::AddressVisitor
params::Dict{Symbol,Any}
rng::R
end

function GFGenerateState(gen_fn, args, constraints, params)
function GFGenerateState(gen_fn, args, constraints, params, rng::AbstractRNG)
trace = DynamicDSLTrace(gen_fn, args)
GFGenerateState(trace, constraints, 0., AddressVisitor(), params)
GFGenerateState(trace, constraints, 0., AddressVisitor(), params, rng)
end

function traceat(state::GFGenerateState, dist::Distribution{T},
Expand All @@ -26,7 +27,7 @@ function traceat(state::GFGenerateState, dist::Distribution{T},
if constrained
retval = get_value(state.constraints, key)
else
retval = random(dist, args...)
retval = random(state.rng, dist, args...)
end

# compute logpdf
Expand Down Expand Up @@ -55,7 +56,7 @@ function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U},
constraints = get_submap(state.constraints, key)

# get subtrace
(subtrace, weight) = generate(gen_fn, args, constraints)
(subtrace, weight) = generate(state.rng, gen_fn, args, constraints)

# add to the trace
add_call!(state.trace, key, subtrace)
Expand All @@ -78,9 +79,12 @@ function splice(state::GFGenerateState, gen_fn::DynamicDSLFunction,
retval
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On line 59, the recursive call to generate needs to pass state.rng to the callee function.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


function generate(gen_fn::DynamicDSLFunction, args::Tuple,
constraints::ChoiceMap)
state = GFGenerateState(gen_fn, args, constraints, gen_fn.params)
generate(gen_fn::DynamicDSLFunction, args::Tuple, constraints::ChoiceMap) =
generate(default_rng(), gen_fn, args, constraints)

function generate(rng::AbstractRNG, gen_fn::DynamicDSLFunction, args::Tuple,
constraints::ChoiceMap)
state = GFGenerateState(gen_fn, args, constraints, gen_fn.params, rng)
retval = exec(gen_fn, state, args)
set_retval!(state.trace, retval)
(state.trace, state.weight)
Expand Down
15 changes: 8 additions & 7 deletions src/dynamic/propose.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
mutable struct GFProposeState
mutable struct GFProposeState{R<:AbstractRNG}
choices::DynamicChoiceMap
weight::Float64
visitor::AddressVisitor
params::Dict{Symbol,Any}
rng::R
end

function GFProposeState(params::Dict{Symbol,Any})
GFProposeState(choicemap(), 0., AddressVisitor(), params)
function GFProposeState(params::Dict{Symbol,Any}, rng::AbstractRNG)
GFProposeState(choicemap(), 0., AddressVisitor(), params, rng)
end

function traceat(state::GFProposeState, dist::Distribution{T},
Expand All @@ -17,7 +18,7 @@ function traceat(state::GFProposeState, dist::Distribution{T},
visit!(state.visitor, key)

# sample return value
retval = random(dist, args...)
retval = random(state.rng, dist, args...)

# update assignment
set_value!(state.choices, key, retval)
Expand All @@ -36,7 +37,7 @@ function traceat(state::GFProposeState, gen_fn::GenerativeFunction{T,U},
visit!(state.visitor, key)

# get subtrace
(submap, weight, retval) = propose(gen_fn, args)
(submap, weight, retval) = propose(state.rng, gen_fn, args)

# update assignment
set_submap!(state.choices, key, submap)
Expand All @@ -55,8 +56,8 @@ function splice(state::GFProposeState, gen_fn::DynamicDSLFunction, args::Tuple)
retval
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On line 40, state.rng needs to be passed to the recursive call to propose.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


function propose(gen_fn::DynamicDSLFunction, args::Tuple)
state = GFProposeState(gen_fn.params)
function propose(rng::AbstractRNG, gen_fn::DynamicDSLFunction, args::Tuple)
state = GFProposeState(gen_fn.params, rng)
retval = exec(gen_fn, state, args)
(state.choices, state.weight, retval)
end
21 changes: 11 additions & 10 deletions src/dynamic/regenerate.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
mutable struct GFRegenerateState
mutable struct GFRegenerateState{R<:AbstractRNG}
prev_trace::DynamicDSLTrace
trace::DynamicDSLTrace
selection::Selection
weight::Float64
visitor::AddressVisitor
params::Dict{Symbol,Any}
rng::R
end

function GFRegenerateState(gen_fn, args, prev_trace,
selection, params)
selection, params, rng::AbstractRNG)
visitor = AddressVisitor()
GFRegenerateState(prev_trace, DynamicDSLTrace(gen_fn, args), selection,
0., visitor, params)
0., visitor, params, rng)
end

function traceat(state::GFRegenerateState, dist::Distribution{T},
Expand All @@ -35,11 +36,11 @@ function traceat(state::GFRegenerateState, dist::Distribution{T},

# get return value
if has_previous && in_selection
retval = random(dist, args...)
retval = random(state.rng, dist, args...)
elseif has_previous
retval = prev_retval
else
retval = random(dist, args...)
retval = random(state.rng, dist, args...)
end

# compute logpdf
Expand Down Expand Up @@ -75,9 +76,9 @@ function traceat(state::GFRegenerateState, gen_fn::GenerativeFunction{T,U},
prev_subtrace = prev_call.subtrace
get_gen_fn(prev_subtrace) === gen_fn || gen_fn_changed_error(key)
(subtrace, weight, _) = regenerate(
prev_subtrace, args, map((_) -> UnknownChange(), args), subselection)
state.rng, prev_subtrace, args, map((_) -> UnknownChange(), args), subselection)
else
(subtrace, weight) = generate(gen_fn, args, EmptyChoiceMap())
(subtrace, weight) = generate(state.rng, gen_fn, args, EmptyChoiceMap())
end

# update weight
Expand Down Expand Up @@ -130,10 +131,10 @@ function regenerate_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord},
noise
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On lines 78 and 81, state.rng needs to be passed to the calls to regenerate and generate respectively.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


function regenerate(trace::DynamicDSLTrace, args::Tuple, argdiffs::Tuple,
selection::Selection)
function regenerate(rng::AbstractRNG, trace::DynamicDSLTrace, args::Tuple,
argdiffs::Tuple, selection::Selection)
gen_fn = trace.gen_fn
state = GFRegenerateState(gen_fn, args, trace, selection, gen_fn.params)
state = GFRegenerateState(gen_fn, args, trace, selection, gen_fn.params, rng)
retval = exec(gen_fn, state, args)
set_retval!(state.trace, retval)
visited = state.visitor.visited
Expand Down
15 changes: 8 additions & 7 deletions src/dynamic/simulate.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
mutable struct GFSimulateState
mutable struct GFSimulateState{R<:AbstractRNG}
trace::DynamicDSLTrace
visitor::AddressVisitor
params::Dict{Symbol,Any}
rng::R
end

function GFSimulateState(gen_fn::GenerativeFunction, args::Tuple, params)
function GFSimulateState(gen_fn::GenerativeFunction, args::Tuple, params, rng::AbstractRNG)
trace = DynamicDSLTrace(gen_fn, args)
GFSimulateState(trace, AddressVisitor(), params)
GFSimulateState(trace, AddressVisitor(), params, rng)
end

function traceat(state::GFSimulateState, dist::Distribution{T},
Expand All @@ -16,7 +17,7 @@ function traceat(state::GFSimulateState, dist::Distribution{T},
# check that key was not already visited, and mark it as visited
visit!(state.visitor, key)

retval = random(dist, args...)
retval = random(state.rng, dist, args...)

# compute logpdf
score = logpdf(dist, retval, args...)
Expand All @@ -36,7 +37,7 @@ function traceat(state::GFSimulateState, gen_fn::GenerativeFunction{T,U},
visit!(state.visitor, key)

# get subtrace
subtrace = simulate(gen_fn, args)
subtrace = simulate(state.rng, gen_fn, args)

# add to the trace
add_call!(state.trace, key, subtrace)
Expand All @@ -56,8 +57,8 @@ function splice(state::GFSimulateState, gen_fn::DynamicDSLFunction,
retval
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On line 40, state.rng needs to be passed to the call to simulate.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


function simulate(gen_fn::DynamicDSLFunction, args::Tuple)
state = GFSimulateState(gen_fn, args, gen_fn.params)
function simulate(rng::AbstractRNG, gen_fn::DynamicDSLFunction, args::Tuple)
state = GFSimulateState(gen_fn, args, gen_fn.params, rng)
retval = exec(gen_fn, state, args)
set_retval!(state.trace, retval)
state.trace
Expand Down
17 changes: 9 additions & 8 deletions src/dynamic/update.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
mutable struct GFUpdateState
mutable struct GFUpdateState{R<:AbstractRNG}
prev_trace::DynamicDSLTrace
trace::DynamicDSLTrace
constraints::Any
weight::Float64
visitor::AddressVisitor
params::Dict{Symbol,Any}
discard::DynamicChoiceMap
rng::R
end

function GFUpdateState(gen_fn, args, prev_trace, constraints, params)
function GFUpdateState(gen_fn, args, prev_trace, constraints, params, rng::AbstractRNG)
visitor = AddressVisitor()
discard = choicemap()
trace = DynamicDSLTrace(gen_fn, args)
GFUpdateState(prev_trace, trace, constraints,
0., visitor, params, discard)
0., visitor, params, discard, rng)
end

function traceat(state::GFUpdateState, dist::Distribution{T},
Expand Down Expand Up @@ -48,7 +49,7 @@ function traceat(state::GFUpdateState, dist::Distribution{T},
elseif has_previous
retval = prev_retval
else
retval = random(dist, args...)
retval = random(state.rng, dist, args...)
end

# compute logpdf
Expand Down Expand Up @@ -87,10 +88,10 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U},
prev_call = get_call(state.prev_trace, key)
prev_subtrace = prev_call.subtrace
get_gen_fn(prev_subtrace) == gen_fn || gen_fn_changed_error(key)
(subtrace, weight, _, discard) = update(prev_subtrace,
(subtrace, weight, _, discard) = update(state.rng, prev_subtrace,
args, map((_) -> UnknownChange(), args), constraints)
else
(subtrace, weight) = generate(gen_fn, args, constraints)
(subtrace, weight) = generate(state.rng, gen_fn, args, constraints)
end

# update the weight
Expand Down Expand Up @@ -184,10 +185,10 @@ function add_unvisited_to_discard!(discard::DynamicChoiceMap,
end
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On lines 91 and 94, state.rng needs to be passed to the recursive calls to update and generate respectively.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


function update(trace::DynamicDSLTrace, arg_values::Tuple, arg_diffs::Tuple,
function update(rng::AbstractRNG, trace::DynamicDSLTrace, arg_values::Tuple, arg_diffs::Tuple,
constraints::ChoiceMap)
gen_fn = trace.gen_fn
state = GFUpdateState(gen_fn, arg_values, trace, constraints, gen_fn.params)
state = GFUpdateState(gen_fn, arg_values, trace, constraints, gen_fn.params, rng)
retval = exec(gen_fn, state, arg_values)
set_retval!(state.trace, retval)
visited = get_visited(state.visitor)
Expand Down
Loading
Loading