diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 9100ce0cdc990..6680925b0cbb6 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -888,6 +888,7 @@ Other - Bug in :meth:`DataFrame.query` where using duplicate column names led to a ``TypeError``. (:issue:`59950`) - Bug in :meth:`DataFrame.query` which raised an exception or produced incorrect results when expressions contained backtick-quoted column names containing the hash character ``#``, backticks, or characters that fall outside the ASCII range (U+0001..U+007F). (:issue:`59285`) (:issue:`49633`) - Bug in :meth:`DataFrame.query` which raised an exception when querying integer column names using backticks. (:issue:`60494`) +- Bug in :meth:`DataFrame.sample` with ``replace=False`` and ``(n * max(weights) / sum(weights)) > 1``, the method would return biased results. Now raises ``ValueError``. (:issue:`61516`) - Bug in :meth:`DataFrame.shift` where passing a ``freq`` on a DataFrame with no columns did not shift the index correctly. (:issue:`60102`) - Bug in :meth:`DataFrame.sort_index` when passing ``axis="columns"`` and ``ignore_index=True`` and ``ascending=False`` not returning a :class:`RangeIndex` columns (:issue:`57293`) - Bug in :meth:`DataFrame.sort_values` where sorting by a column explicitly named ``None`` raised a ``KeyError`` instead of sorting by the column as expected. (:issue:`61512`) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 8aae4609b1833..ec5e105b24020 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -5815,6 +5815,8 @@ def sample( If weights do not sum to 1, they will be normalized to sum to 1. Missing values in the weights column will be treated as zero. Infinite values not allowed. + When replace = False will not allow ``(n * max(weights) / sum(weights)) > 1``, + in order to avoid biased results. random_state : int, array-like, BitGenerator, np.random.RandomState, np.random.Generator, optional If int, array-like, or BitGenerator, seed for random number generator. If np.random.RandomState or np.random.Generator, use as given. diff --git a/pandas/core/sample.py b/pandas/core/sample.py index 4f12563e3c5e2..463be2f41e47f 100644 --- a/pandas/core/sample.py +++ b/pandas/core/sample.py @@ -150,6 +150,13 @@ def sample( else: raise ValueError("Invalid weights: weights sum to zero") + is_max_weight_dominating = size * max(weights) > 1 + if (is_max_weight_dominating and not replace): + raise ValueError( + "Invalid weights: If `replace`=False," + " total unit probabilities have to be less than 1" + ) + return random_state.choice(obj_len, size=size, replace=replace, p=weights).astype( np.intp, copy=False ) diff --git a/pandas/tests/frame/methods/test_sample.py b/pandas/tests/frame/methods/test_sample.py index a9d56cbfd2b46..80b472795a413 100644 --- a/pandas/tests/frame/methods/test_sample.py +++ b/pandas/tests/frame/methods/test_sample.py @@ -113,9 +113,6 @@ def test_sample_invalid_weight_lengths(self, obj): with pytest.raises(ValueError, match=msg): obj.sample(n=3, weights=[0.5] * 11) - with pytest.raises(ValueError, match="Fewer non-zero entries in p than size"): - obj.sample(n=4, weights=Series([0, 0, 0.2])) - def test_sample_negative_weights(self, obj): # Check won't accept negative weights bad_weights = [-0.1] * 10 @@ -137,6 +134,29 @@ def test_sample_inf_weights(self, obj): with pytest.raises(ValueError, match=msg): obj.sample(n=3, weights=weights_with_ninf) + def test_sample_unit_probabilities_raises(self, obj): + # GH#61516 + high_variance_weights = [1] * 10 + high_variance_weights[0] = 100 + msg = ( + "Invalid weights: If `replace`=False," + " total unit probabilities have to be less than 1" + ) + with pytest.raises(ValueError, match=msg): + obj.sample(n=2, weights=high_variance_weights, replace=False) + + # edge case, n*max(weights)/sum(weights) == 1 + edge_variance_weights = [1] * 10 + edge_variance_weights[0] = 9 + # should not raise + obj.sample(n=2, weights=edge_variance_weights, replace=False) + + low_variance_weights = [1] * 10 + low_variance_weights[0] = 8 + # should not raise + obj.sample(n=2, weights=low_variance_weights, replace=False) + + def test_sample_zero_weights(self, obj): # All zeros raises errors