Skip to content

Commit

Permalink
Fix pairplot histogram bins and aspect ratio (#1220)
Browse files Browse the repository at this point in the history
  • Loading branch information
gmoss13 authored Aug 16, 2024
1 parent d34c785 commit 593e153
Showing 1 changed file with 19 additions and 22 deletions.
41 changes: 19 additions & 22 deletions sbi/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import collections
import copy
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
from warnings import warn
Expand Down Expand Up @@ -81,24 +82,18 @@ def plt_hist_1d(
diag_kwargs: Dict,
) -> None:
"""Plot 1D histogram."""
if (
"bins" not in diag_kwargs["mpl_kwargs"]
or diag_kwargs["mpl_kwargs"]["bins"] is None
):
hist_kwargs = copy.deepcopy(diag_kwargs["mpl_kwargs"])
if "bins" not in hist_kwargs or hist_kwargs["bins"] is None:
if diag_kwargs["bin_heuristic"] == "Freedman-Diaconis":
# The Freedman-Diaconis heuristic
binsize = 2 * iqr(samples) * len(samples) ** (-1 / 3)
diag_kwargs["mpl_kwargs"]["bins"] = np.arange(
limits[0], limits[1] + binsize, binsize
)
hist_kwargs["bins"] = np.arange(limits[0], limits[1] + binsize, binsize)
else:
# TODO: add more bin heuristics
pass
if isinstance(diag_kwargs["mpl_kwargs"]["bins"], int):
diag_kwargs["mpl_kwargs"]["bins"] = np.linspace(
limits[0], limits[1], diag_kwargs["mpl_kwargs"]["bins"]
)
ax.hist(samples, **diag_kwargs["mpl_kwargs"])
if isinstance(hist_kwargs["bins"], int):
hist_kwargs["bins"] = np.linspace(limits[0], limits[1], hist_kwargs["bins"])
ax.hist(samples, **hist_kwargs)


def plt_kde_1d(
Expand Down Expand Up @@ -133,18 +128,19 @@ def plt_hist_2d(
limits_row: torch.Tensor,
offdiag_kwargs: Dict,
):
hist_kwargs = copy.deepcopy(offdiag_kwargs)
"""Plot 2D histogram."""
if (
"bins" not in offdiag_kwargs["np_hist_kwargs"]
or offdiag_kwargs["np_hist_kwargs"]["bins"] is None
"bins" not in hist_kwargs["np_hist_kwargs"]
or hist_kwargs["np_hist_kwargs"]["bins"] is None
):
if offdiag_kwargs["bin_heuristic"] == "Freedman-Diaconis":
if hist_kwargs["bin_heuristic"] == "Freedman-Diaconis":
# The Freedman-Diaconis heuristic applied to each direction
binsize_col = 2 * iqr(samples_col) * len(samples_col) ** (-1 / 3)
n_bins_col = int((limits_col[1] - limits_col[0]) / binsize_col)
binsize_row = 2 * iqr(samples_row) * len(samples_row) ** (-1 / 3)
n_bins_row = int((limits_row[1] - limits_row[0]) / binsize_row)
offdiag_kwargs["np_hist_kwargs"]["bins"] = [n_bins_col, n_bins_row]
hist_kwargs["np_hist_kwargs"]["bins"] = [n_bins_col, n_bins_row]
else:
# TODO: add more bin heuristics
pass
Expand All @@ -155,7 +151,7 @@ def plt_hist_2d(
[limits_col[0], limits_col[1]],
[limits_row[0], limits_row[1]],
],
**offdiag_kwargs["np_hist_kwargs"],
**hist_kwargs["np_hist_kwargs"],
)
ax.imshow(
hist.T,
Expand All @@ -165,7 +161,7 @@ def plt_hist_2d(
yedges[0],
yedges[-1],
),
**offdiag_kwargs["mpl_kwargs"],
**hist_kwargs["mpl_kwargs"],
)


Expand Down Expand Up @@ -919,13 +915,13 @@ def _get_default_offdiag_kwargs(offdiag: Optional[str], i: int = 0) -> Dict:
offdiag_kwargs = {
"bw_method": "scott",
"bins": 50,
"mpl_kwargs": {"cmap": "viridis", "origin": "lower"},
"mpl_kwargs": {"cmap": "viridis", "origin": "lower", "aspect": "auto"},
}

elif offdiag == "hist" or offdiag == "hist2d":
offdiag_kwargs = {
"bin_heuristic": None, # "Freedman-Diaconis",
"mpl_kwargs": {"cmap": "viridis", "origin": "lower"},
"mpl_kwargs": {"cmap": "viridis", "origin": "lower", "aspect": "auto"},
"np_hist_kwargs": {"bins": 50, "density": False},
}

Expand All @@ -945,13 +941,14 @@ def _get_default_offdiag_kwargs(offdiag: Optional[str], i: int = 0) -> Dict:
"levels": [0.68, 0.95, 0.99],
"percentile": True,
"mpl_kwargs": {
"colors": plt.rcParams["axes.prop_cycle"].by_key()["color"][i * 2] # pyright: ignore[reportOptionalMemberAccess]
"colors": plt.rcParams["axes.prop_cycle"].by_key()["color"][i * 2], # pyright: ignore[reportOptionalMemberAccess]
},
}
elif offdiag == "plot":
offdiag_kwargs = {
"mpl_kwargs": {
"color": plt.rcParams["axes.prop_cycle"].by_key()["color"][i * 2] # pyright: ignore[reportOptionalMemberAccess]
"color": plt.rcParams["axes.prop_cycle"].by_key()["color"][i * 2], # pyright: ignore[reportOptionalMemberAccess]
"aspect": "auto",
}
}
else:
Expand Down

0 comments on commit 593e153

Please sign in to comment.