diff --git a/sbi/analysis/plot.py b/sbi/analysis/plot.py index 45b072cea..9822fe590 100644 --- a/sbi/analysis/plot.py +++ b/sbi/analysis/plot.py @@ -554,42 +554,71 @@ def handle_nan_infs(samples: List[np.ndarray]) -> List[np.ndarray]: return samples +def convert_to_list_of_numpy( + arr: Union[List[np.ndarray], List[torch.Tensor], np.ndarray, torch.Tensor], +) -> List[np.ndarray]: + """Converts a list of torch.Tensor to a list of np.ndarray.""" + if not isinstance(arr, list): + arr = ensure_numpy(arr) + return [arr] + return [ensure_numpy(a) for a in arr] + + +def infer_limits( + samples: List[np.ndarray], + dim: int, + points: Optional[List[np.ndarray]] = None, + eps: float = 0.1, +) -> List[List[float]]: + """Infer limits for the plot. + + Args: + samples: List of set of samples. + dim: Dimension of the samples. + points: List of points. + eps: Relative margin for the limits. + """ + limits = [] + for d in range(dim): + # get min and max across all sets of samples + min_val = min(np.min(sample[:, d]) for sample in samples) + max_val = max(np.max(sample[:, d]) for sample in samples) + # include points in the limits + if points is not None: + min_val = min(min_val, min(np.min(point[:, d]) for point in points)) + max_val = max(max_val, max(np.max(point[:, d]) for point in points)) + # add margin + max_min_range = max_val - min_val + epsilon_range = eps * max_min_range + limits.append([min_val - epsilon_range, max_val + epsilon_range]) + return limits + + def prepare_for_plot( samples: Union[List[np.ndarray], List[torch.Tensor], np.ndarray, torch.Tensor], - limits: Optional[Union[List, torch.Tensor, np.ndarray]], + limits: Optional[Union[List, torch.Tensor, np.ndarray]] = None, + points: Optional[ + Union[List[np.ndarray], List[torch.Tensor], np.ndarray, torch.Tensor] + ] = None, ) -> Tuple[List[np.ndarray], int, torch.Tensor]: """ Ensures correct formatting for samples and limits, and returns dimension of the samples. """ - # Prepare samples - if not isinstance(samples, list): - samples = ensure_numpy(samples) - samples = [samples] - else: - samples = [ensure_numpy(sample) for sample in samples] + samples = convert_to_list_of_numpy(samples) + if points is not None: + points = convert_to_list_of_numpy(points) - # check if nans and infs samples = handle_nan_infs(samples) - # Dimensionality of the problem. dim = samples[0].shape[1] - # Prepare limits. Infer them from samples if they had not been passed. - if limits == [] or limits is None: - limits = [] - for d in range(dim): - min = +np.inf - max = -np.inf - for sample in samples: - min_ = np.min(sample[:, d]) - min = min_ if min_ < min else min - max_ = np.max(sample[:, d]) - max = max_ if max_ > max else max - limits.append([min, max]) + if limits is None or limits == []: + limits = infer_limits(samples, dim, points) else: limits = [limits[0] for _ in range(dim)] if len(limits) == 1 else limits + limits = torch.as_tensor(limits) return samples, dim, limits @@ -737,7 +766,7 @@ def pairplot( ) return fig, axes - samples, dim, limits = prepare_for_plot(samples, limits) + samples, dim, limits = prepare_for_plot(samples, limits, points) # prepate figure kwargs fig_kwargs_filled = _get_default_fig_kwargs() diff --git a/tests/plot_test.py b/tests/plot_test.py index 8505956b5..9f21553ca 100644 --- a/tests/plot_test.py +++ b/tests/plot_test.py @@ -15,7 +15,7 @@ @pytest.mark.parametrize("samples", (torch.randn(100, 1),)) -@pytest.mark.parametrize("limits", ([(-1, 1)],)) +@pytest.mark.parametrize("limits", ([(-1, 1)], None)) def test_pairplot1D(samples, limits): fig, axs = pairplot(**{k: v for k, v in locals().items() if v is not None}) assert isinstance(fig, Figure) @@ -24,7 +24,7 @@ def test_pairplot1D(samples, limits): @pytest.mark.parametrize("samples", (torch.randn(100, 2),)) -@pytest.mark.parametrize("limits", ([(-1, 1)],)) +@pytest.mark.parametrize("limits", ([(-1, 1)], None)) def test_nan_inf(samples, limits): samples[0, 0] = np.nan samples[5, 1] = np.inf @@ -37,7 +37,7 @@ def test_nan_inf(samples, limits): @pytest.mark.parametrize("samples", (torch.randn(100, 2), [torch.randn(100, 3)] * 2)) @pytest.mark.parametrize("points", (torch.ones(1, 3),)) -@pytest.mark.parametrize("limits", ([(-3, 3)],)) +@pytest.mark.parametrize("limits", ([(-3, 3)], None)) @pytest.mark.parametrize("subset", (None, [0, 1])) @pytest.mark.parametrize("upper", ("scatter",)) @pytest.mark.parametrize(