diff --git a/sbi/utils/user_input_checks.py b/sbi/utils/user_input_checks.py
index 22c056a13..f407d632f 100644
--- a/sbi/utils/user_input_checks.py
+++ b/sbi/utils/user_input_checks.py
@@ -17,6 +17,7 @@
 from sbi.utils.user_input_checks_utils import (
     CustomPriorWrapper,
     MultipleIndependent,
+    OneDimPriorWrapper,
     PytorchReturnTypeWrapper,
 )
 
@@ -220,6 +221,11 @@ def process_pytorch_prior(prior: Distribution) -> Tuple[Distribution, int, bool]
     # This will fail for float64 priors.
     check_prior_return_type(prior)
 
+    # Wrap the prior if its `.log_prob()` returns an extra trailing dimension,
+    # i.e., shape (10, 1) instead of (10,) for a batch of 10 parameters.
+    if prior.log_prob(prior.sample(torch.Size((10,)))).shape == torch.Size([10, 1]):
+        prior = OneDimPriorWrapper(prior, validate_args=False)
+
     theta_numel = prior.sample().numel()
 
     return prior, theta_numel, False
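
A minimal sketch of what this hunk enables end to end (assuming `process_prior` is importable from `sbi.utils`, as in the library's public API):

```python
import torch
from torch.distributions import Exponential

from sbi.utils import process_prior

# A batched 1D prior whose raw `.log_prob()` would return shape (10, 1).
prior, theta_numel, _ = process_prior(Exponential(torch.tensor([3.0])))

samples = prior.sample(torch.Size((10,)))
samples.shape                  # torch.Size([10, 1])
prior.log_prob(samples).shape  # torch.Size([10]), extra dim removed by the wrapper
```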
diff --git a/sbi/utils/user_input_checks_utils.py b/sbi/utils/user_input_checks_utils.py
index 184d167c5..554a6f8f2 100644
--- a/sbi/utils/user_input_checks_utils.py
+++ b/sbi/utils/user_input_checks_utils.py
@@ -373,3 +373,69 @@ def build_support(
             support = constraints.interval(lower_bound, upper_bound)
 
     return support
+
+
+class OneDimPriorWrapper(Distribution):
+    """Wrap batched 1D distributions to get rid of the batch dim of `.log_prob()`.
+
+    1D pytorch distributions such as `torch.distributions.Exponential`, `.Uniform`, or
+    `.Normal` do not, by default, return __any__ sample or batch dimension. E.g.:
+    ```python
+    dist = torch.distributions.Exponential(torch.tensor(3.0))
+    dist.sample((10,)).shape  # (10,)
+    ```
+
+    `sbi` will raise an error because the sample dimension is missing. A simple
+    solution is to add a batch dimension to `dist` as follows:
+    ```python
+    dist = torch.distributions.Exponential(torch.tensor([3.0]))
+    dist.sample((10,)).shape  # (10, 1)
+    ```
+
+    Unfortunately, this `dist` will also return the batch dimension for `.log_prob()`:
+    ```python
+    dist = torch.distributions.Exponential(torch.tensor([3.0]))
+    samples = dist.sample((10,))
+    dist.log_prob(samples).shape  # (10, 1)
+    ```
+
+    This will lead to unexpected errors in `sbi`. This class wraps such batched 1D
+    distributions and removes the batch dimension from the output of `.log_prob()`.
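+
+    For example, wrapping the batched `Exponential` from above removes the extra
+    dimension (this mirrors the `log_prob` override below):
+    ```python
+    dist = OneDimPriorWrapper(torch.distributions.Exponential(torch.tensor([3.0])))
+    dist.log_prob(dist.sample((10,))).shape  # (10,)
+    ```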
+    """
+
+    def __init__(
+        self,
+        prior: Distribution,
+        validate_args=None,
+    ) -> None:
+        super().__init__(
+            batch_shape=prior.batch_shape,
+            event_shape=prior.event_shape,
+            validate_args=(
+                prior._validate_args if validate_args is None else validate_args
+            ),
+        )
+        self.prior = prior
+
+    def sample(self, *args, **kwargs) -> Tensor:
+        return self.prior.sample(*args, **kwargs)
+
+    def log_prob(self, *args, **kwargs) -> Tensor:
+        """Override the log_prob method to get rid of the additional batch dimension."""
+        return self.prior.log_prob(*args, **kwargs)[..., 0]
+
+    @property
+    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
+        return self.prior.arg_constraints
+
+    @property
+    def support(self):
+        return self.prior.support
+
+    @property
+    def mean(self) -> Tensor:
+        return self.prior.mean
+
+    @property
+    def variance(self) -> Tensor:
+        return self.prior.variance
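
Note on the forwarded properties: `support`, `mean`, and `variance` delegate to the wrapped prior so that downstream utilities keep working, e.g., `mcmc_transform` builds its transform from `.support`. A sketch (assuming `mcmc_transform` is importable from `sbi.utils`, as exercised in the tests below):

```python
import torch
from torch.distributions import Exponential

from sbi.utils import mcmc_transform
from sbi.utils.user_input_checks_utils import OneDimPriorWrapper

wrapped = OneDimPriorWrapper(Exponential(torch.tensor([3.0])), validate_args=False)
transform = mcmc_transform(wrapped)  # reads the forwarded `.support` property
```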
diff --git a/tests/user_input_checks_test.py b/tests/user_input_checks_test.py
index 2766d8513..cd275a7b0 100644
--- a/tests/user_input_checks_test.py
+++ b/tests/user_input_checks_test.py
@@ -10,7 +10,14 @@
 import torch
 from pyknos.mdn.mdn import MultivariateGaussianMDN
 from torch import Tensor, eye, nn, ones, zeros
-from torch.distributions import Beta, Distribution, Gamma, MultivariateNormal, Uniform
+from torch.distributions import (
+    Beta,
+    Distribution,
+    Exponential,
+    Gamma,
+    MultivariateNormal,
+    Uniform,
+)
 
 from sbi.inference import NPE_A, NPE_C, simulate_for_sbi
 from sbi.inference.posteriors.direct_posterior import DirectPosterior
@@ -27,6 +34,7 @@
 from sbi.utils.user_input_checks_utils import (
     CustomPriorWrapper,
     MultipleIndependent,
+    OneDimPriorWrapper,
     PytorchReturnTypeWrapper,
 )
 
@@ -93,6 +101,11 @@ def matrix_simulator(theta):
             BoxUniform(zeros(3, dtype=torch.float64), ones(3, dtype=torch.float64)),
             dict(),
         ),
+        (
+            OneDimPriorWrapper,
+            Exponential(torch.tensor([3.0])),
+            dict(),
+        ),
     ),
 )
 def test_prior_wrappers(wrapper, prior, kwargs):
@@ -118,6 +131,9 @@ def test_prior_wrappers(wrapper, prior, kwargs):
     # Test transform
     mcmc_transform(prior)
 
+    # `.log_prob()` should return one value per parameter set, without a batch dim.
+    assert len(prior.log_prob(prior.sample((10,))).shape) == 1
+
 
 def test_reinterpreted_batch_dim_prior():
     """Test whether the right warning and error are raised for reinterpreted priors."""
@@ -268,7 +284,6 @@ def test_prepare_sbi_problem(simulator: Callable, prior):
         prior: prior as defined by the user (pytorch, scipy, custom)
         x_shape: shape of data as defined by the user.
     """
-
     prior, _, prior_returns_numpy = process_prior(prior)
     simulator = process_simulator(simulator, prior, prior_returns_numpy)
     check_sbi_inputs(simulator, prior)
@@ -308,6 +323,7 @@ def test_prepare_sbi_problem(simulator: Callable, prior):
                 MultivariateNormal(zeros(2), eye(2)),
             ),
         ),
+        (diagonal_linear_gaussian, Exponential(torch.tensor([3.0]))),
     ),
 )
 def test_inference_with_user_sbi_problems(