Skip to content

Commit

Permalink
Merge pull request #345 from Havaldar/ess-normalization
Browse files Browse the repository at this point in the history
Normalize Acceptance Rate Plot by ESS
  • Loading branch information
yannikschaelte authored Sep 11, 2020
2 parents 283085c + 9b70038 commit 8a9b990
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
18 changes: 15 additions & 3 deletions pyabc/visualization/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from ..storage import History
from .util import to_lists_or_default
from ..weighted_statistics import effective_sample_size


def plot_sample_numbers(
Expand Down Expand Up @@ -263,7 +264,8 @@ def plot_acceptance_rates_trajectory(
yscale: str = 'lin',
size: tuple = None,
ax: mpl.axes.Axes = None,
colors: List[str] = None):
colors: List[str] = None,
normalize_by_ess: bool = False):
"""
Plot of acceptance rates over all iterations, i.e. one trajectory
per history.
Expand All @@ -288,6 +290,9 @@ def plot_acceptance_rates_trajectory(
The size of the plot in inches.
ax: matplotlib.axes.Axes, optional
The axis object to use.
normalize_by_ess: bool, optional (default = False)
Indicator to use effective sample size for the acceptance rate in
place of the population size.
Returns
-------
Expand All @@ -312,9 +317,16 @@ def plot_acceptance_rates_trajectory(
# note: the first entry of time -1 is trivial and is thus ignored here
h_info = history.get_all_populations()
times.append(np.array(h_info['t'])[1:])
samples.append(np.array(h_info['samples'])[1:])
pop_sizes.append(np.array(
if normalize_by_ess:
ess = np.zeros(len(h_info['t']) - 1)
for t in np.array(h_info['t'])[1:]:
w = history.get_weighted_distances(t=t)['w']
ess[t-1] = effective_sample_size(w)
pop_sizes.append(ess)
else:
pop_sizes.append(np.array(
history.get_nr_particles_per_population().values[1:]))
samples.append(np.array(h_info['samples'])[1:])

# compute acceptance rates
rates = []
Expand Down
6 changes: 5 additions & 1 deletion test/visualization/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ def test_acceptance_rates_trajectory():
histories, labels, yscale='log', rotation=76)
_, ax = plt.subplots()
pyabc.visualization.plot_acceptance_rates_trajectory(
histories, labels, yscale='log10', rotation=76, size=(10, 5), ax=ax)
histories, labels, yscale='log10', rotation=76, size=(10, 5), ax=ax
normalize_by_ess=True)
pyabc.visualization.plot_acceptance_rates_trajectory(
histories, labels, yscale='log10', rotation=76, size=(10, 5), ax=ax
normalize_by_ess=False)
plt.close()


Expand Down

0 comments on commit 8a9b990

Please sign in to comment.