2 changes: 2 additions & 0 deletions doc/source/data/api/aggregate.rst
@@ -27,3 +27,5 @@ compute aggregations.
Unique
MissingValuePercentage
ZeroPercentage
ApproximateQuantile

101 changes: 101 additions & 0 deletions python/ray/data/aggregate.py
@@ -1189,3 +1189,104 @@ def finalize(self, accumulator: List[int]) -> Optional[float]:
        if accumulator[1] == 0:
            return None
        return (accumulator[0] / accumulator[1]) * 100.0


@PublicAPI(stability="alpha")
class ApproximateQuantile(AggregateFnV2):
    def _require_datasketches(self):
        try:
            from datasketches import kll_floats_sketch  # type: ignore[import]
        except ImportError as exc:
            raise ImportError(
                "ApproximateQuantile requires the `datasketches` package. "
                "Install it with `pip install datasketches`."
            ) from exc
        return kll_floats_sketch

    def __init__(
        self,
        on: str,
        quantiles: List[float],
        k: int = 800,
        alias_name: Optional[str] = None,
    ):
        """
        Computes the approximate quantiles of a column using a DataSketches KLL
        floats sketch (kll_floats_sketch). See
        https://datasketches.apache.org/docs/KLL/KLLSketch.html.

        The accuracy of the KLL quantile sketch is a function of the configured k,
        which also determines the overall size of the sketch.
        The KLL sketch has absolute rank error. For example, a specified rank accuracy
        of 1% at the median (rank = 0.50) means that the true quantile (if you could
        compute it exactly) should lie between getQuantile(0.49) and getQuantile(0.51).
        The same 1% error at a rank of 0.95 means that the true quantile should lie
        between getQuantile(0.94) and getQuantile(0.96). In other words, the error is
        a fixed +/- epsilon across the entire range of ranks.

        Typical single-sided rank error by k (applies to getQuantile/getRank):

        - k=100 → ~2.61%
        - k=200 → ~1.33%
        - k=400 → ~0.68%
        - k=800 → ~0.35%

        See https://datasketches.apache.org/docs/KLL/KLLAccuracyAndSize.html for
        details on accuracy and size.

        Null values in the target column are ignored when constructing the sketch.

        Example:

            .. testcode::

                import ray
                from ray.data.aggregate import ApproximateQuantile

                # Create a dataset with some values
                ds = ray.data.from_items(
                    [{"value": 20.0}, {"value": 40.0}, {"value": 60.0},
                     {"value": 80.0}, {"value": 100.0}]
                )

                result = ds.aggregate(
                    ApproximateQuantile(on="value", quantiles=[0.1, 0.5, 0.9])
                )
                # Result: {'approx_quantile(value)': [20.0, 60.0, 100.0]}

        Args:
            on: The name of the column to calculate the quantiles on. Must be a
                numeric column.
            quantiles: The list of quantiles to compute. Each quantile must be
                between 0 and 1 inclusive; for example, quantiles=[0.5] computes
                the median.
            k: Controls the accuracy and memory footprint of the sketch; a higher k
                yields lower error but uses more memory. Defaults to 800. See
                https://datasketches.apache.org/docs/KLL/KLLAccuracyAndSize.html for
                details on accuracy and size.
            alias_name: Optional name for the resulting column. If not provided,
                defaults to "approx_quantile({column_name})".
        """
        self._require_datasketches()
        self._quantiles = quantiles
        self._k = k

Contributor:
instead of k, let's use capacity_per_level

@owenowenisme (Member, Author) replied on Oct 11, 2025:

capacity_per_level does not feel accurate to me, and I think we don't need to hide the detail of k, since users will need to consult the datasketches docs anyway.

I added a link to the k parameter description to point users to the docs for more info.
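
For readers weighing a value of k, here is a small standalone sketch (not part of this PR, using only the datasketches calls already exercised by this class) that shows how k trades estimate accuracy against serialized size:

```python
import random

from datasketches import kll_floats_sketch

random.seed(0)
values = [random.random() for _ in range(100_000)]  # true median ~0.5

for k in (100, 200, 400, 800):
    sk = kll_floats_sketch(k=k)
    for v in values:
        sk.update(v)
    (median_est,) = sk.get_quantiles([0.5])
    # Larger k: the median estimate lands closer to 0.5, but the sketch
    # serializes to more bytes.
    print(f"k={k}: median~{median_est:.4f}, bytes={len(sk.serialize())}")
```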

        super().__init__(
            alias_name if alias_name else f"approx_quantile({str(on)})",
            on=on,
            ignore_nulls=True,
            zero_factory=lambda: self.zero(k).serialize(),
        )

Review comment:
Bug: Quantile Initialization Error and Inconsistent Parameter Usage

The ApproximateQuantile constructor passes on and ignore_nulls as positional arguments to super().__init__, which expects keyword-only arguments, causing a TypeError. Additionally, the zero_factory lambda uses the k parameter from __init__ instead of the self._k instance variable, which could lead to inconsistent behavior if self._k is modified after initialization.



    def zero(self, k: int):
        sketch_cls = self._require_datasketches()
        return sketch_cls(k=k)

    def aggregate_block(self, block: Block) -> bytes:
        block_acc = BlockAccessor.for_block(block)
        table = block_acc.to_arrow()
        column = table.column(self.get_target_column())
        sketch = self.zero(self._k)
        for value in column:
            # we ignore nulls here
            if value.as_py() is not None:

Contributor:
do we need an as_py() conversion here? What type is this value?

@owenowenisme (Member, Author) replied on Oct 11, 2025:

This is because we get the following error when the value is None:

    def test_approximate_quantile_ignores_nulls(self, ray_start_regular_shared_2_cpus):
        data = [
            {"id": 1, "value": 5.0},
            {"id": 2, "value": None},
            {"id": 3, "value": 15.0},
            {"id": 4, "value": None},
            {"id": 5, "value": 25.0},
        ]
        ds = ray.data.from_items(data)

        result = ds.aggregate(ApproximateQuantile(on="value", quantiles=[0.5]))
        assert result["approx_quantile(value)"] == [15.0]
TypeError: float() argument must be a string or a number, not 'pyarrow.lib.NullScalar'
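
For context, a minimal standalone illustration (not part of this PR) of an alternative that drops nulls at the Arrow level with pyarrow.compute.drop_null, so the loop never sees a NullScalar:

```python
import pyarrow as pa
import pyarrow.compute as pc

# A column with nulls, similar to the failing case above.
column = pa.chunked_array([[5.0, None, 15.0, None, 25.0]])

# drop_null removes the null slots entirely, so only real floats remain.
for value in pc.drop_null(column).to_pylist():
    print(float(value))  # 5.0, 15.0, 25.0
```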

                sketch.update(float(value.as_py()))
        return sketch.serialize()

    def combine(self, current_accumulator: bytes, new: bytes) -> bytes:
        combined = self.zero(self._k)
        sketch_cls = self._require_datasketches()
        combined.merge(sketch_cls.deserialize(current_accumulator))
        combined.merge(sketch_cls.deserialize(new))
        return combined.serialize()

    def finalize(self, accumulator: bytes) -> List[float]:
        sketch_cls = self._require_datasketches()
        return sketch_cls.deserialize(accumulator).get_quantiles(self._quantiles)
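
To make the bytes-based accumulator concrete, here is a small standalone sketch (not part of this PR, using only the datasketches calls shown in combine and finalize above) of the serialize → deserialize → merge → get_quantiles round trip that happens when two partial results are combined:

```python
from datasketches import kll_floats_sketch

# Two partial sketches, as two blocks would produce in aggregate_block().
left, right = kll_floats_sketch(k=800), kll_floats_sketch(k=800)
for v in (1.0, 2.0, 3.0):
    left.update(v)
for v in (10.0, 20.0, 30.0):
    right.update(v)

# Partial results travel between tasks as serialized bytes.
left_bytes, right_bytes = left.serialize(), right.serialize()

# combine(): deserialize both sides and merge them into a fresh sketch.
combined = kll_floats_sketch(k=800)
combined.merge(kll_floats_sketch.deserialize(left_bytes))
combined.merge(kll_floats_sketch.deserialize(right_bytes))

# finalize(): read the requested quantiles off the merged sketch.
print(combined.get_quantiles([0.5]))  # approximate median of all six values
```
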
3 changes: 3 additions & 0 deletions python/ray/data/stats.py
@@ -6,6 +6,7 @@

from ray.data.aggregate import (
    AggregateFnV2,
    ApproximateQuantile,
    Count,
    Max,
    Mean,
@@ -31,6 +32,7 @@ def numerical_aggregators(column: str) -> List[AggregateFnV2]:
    - min
    - max
    - std
    - approximate_quantile
    - missing_value_percentage
    - zero_percentage

@@ -46,6 +48,7 @@ def numerical_aggregators(column: str) -> List[AggregateFnV2]:
        Min(on=column, ignore_nulls=True),
        Max(on=column, ignore_nulls=True),
        Std(on=column, ignore_nulls=True, ddof=0),
        ApproximateQuantile(on=column, quantiles=[0.5]),
        MissingValuePercentage(on=column),
        ZeroPercentage(on=column, ignore_nulls=True),
    ]
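
As a quick illustration of how the new aggregator plugs into this default numerical profile (a sketch; the import path for numerical_aggregators is assumed from this file, and the function may not be a public API):

```python
import ray

from ray.data.stats import numerical_aggregators  # path assumed from this diff

ds = ray.data.range(100)  # one numeric column named "id"

# Runs count/mean/min/max/std plus the new median sketch and the percentages.
result = ds.aggregate(*numerical_aggregators("id"))
print(result["approx_quantile(id)"])  # approximate median of 0..99
```
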
87 changes: 86 additions & 1 deletion python/ray/data/tests/test_custom_agg.py
@@ -2,7 +2,11 @@
import pytest

import ray
from ray.data.aggregate import MissingValuePercentage, ZeroPercentage
from ray.data.aggregate import (
    ApproximateQuantile,
    MissingValuePercentage,
    ZeroPercentage,
)
from ray.data.tests.conftest import * # noqa
from ray.tests.conftest import * # noqa

@@ -276,6 +280,87 @@ def test_zero_percentage_negative_values(self, ray_start_regular_shared_2_cpus):
assert result["zero_pct(value)"] == expected


class TestApproximateQuantile:
    """Test cases for ApproximateQuantile aggregation."""

    def test_approximate_quantile_basic(self, ray_start_regular_shared_2_cpus):
        """Test basic approximate quantile calculation."""
        data = [
            {
                "id": 1,
                "value": 10,
            },
            {"id": 2, "value": 0},
            {"id": 3, "value": 30},
            {"id": 4, "value": 0},
            {"id": 5, "value": 50},
        ]
        ds = ray.data.from_items(data)

        result = ds.aggregate(
            ApproximateQuantile(on="value", quantiles=[0.1, 0.5, 0.9])
        )
        expected = [0.0, 10.0, 50.0]
        assert result["approx_quantile(value)"] == expected

    def test_approximate_quantile_ignores_nulls(self, ray_start_regular_shared_2_cpus):
        data = [
            {"id": 1, "value": 5.0},
            {"id": 2, "value": None},
            {"id": 3, "value": 15.0},
            {"id": 4, "value": None},
            {"id": 5, "value": 25.0},
        ]
        ds = ray.data.from_items(data)

        result = ds.aggregate(ApproximateQuantile(on="value", quantiles=[0.5]))
        assert result["approx_quantile(value)"] == [15.0]

    def test_approximate_quantile_custom_alias(self, ray_start_regular_shared_2_cpus):
        data = [
            {"id": 1, "value": 1.0},
            {"id": 2, "value": 3.0},
            {"id": 3, "value": 5.0},
            {"id": 4, "value": 7.0},
            {"id": 5, "value": 9.0},
        ]
        ds = ray.data.from_items(data)

        quantiles = [0.0, 1.0]
        result = ds.aggregate(
            ApproximateQuantile(
                on="value", quantiles=quantiles, alias_name="value_range"
            )
        )

        assert result["value_range"] == [1.0, 9.0]
        assert len(result["value_range"]) == len(quantiles)

    def test_approximate_quantile_groupby(self, ray_start_regular_shared_2_cpus):
        data = [
            {"group": "A", "value": 1.0},
            {"group": "A", "value": 2.0},
            {"group": "A", "value": 3.0},
            {"group": "B", "value": 10.0},
            {"group": "B", "value": 20.0},
            {"group": "B", "value": 30.0},
        ]
        ds = ray.data.from_items(data)

        result = (
            ds.groupby("group")
            .aggregate(ApproximateQuantile(on="value", quantiles=[0.5]))
            .take_all()
        )

        result_by_group = {
            row["group"]: row["approx_quantile(value)"] for row in result
        }

        assert result_by_group["A"] == [2.0]
        assert result_by_group["B"] == [20.0]

if __name__ == "__main__":
    import sys

39 changes: 21 additions & 18 deletions python/ray/data/tests/test_dataset_stats.py
@@ -4,6 +4,7 @@

import ray
from ray.data.aggregate import (
ApproximateQuantile,
Count,
Max,
Mean,
@@ -51,8 +52,8 @@ def test_numerical_columns_detection(self):
assert len(feature_aggs.vector_columns) == 0

# Check that we have the right number of aggregators
# 3 numerical columns * 7 aggregators each + 1 string column * 2 aggregators = 23 total
assert len(feature_aggs.aggregators) == 23
# 3 numerical columns * 8 aggregators each + 1 string column * 2 aggregators = 26 total
assert len(feature_aggs.aggregators) == 26

def test_categorical_columns_detection(self):
"""Test that string columns are correctly identified as categorical."""
@@ -74,8 +75,8 @@ def test_categorical_columns_detection(self):
assert "value" in feature_aggs.numerical_columns
assert "category" not in feature_aggs.numerical_columns

# Check aggregator count: 1 numerical * 7 + 2 categorical * 2 = 11
assert len(feature_aggs.aggregators) == 11
# Check aggregator count: 1 numerical * 8 + 2 categorical * 2 = 12
assert len(feature_aggs.aggregators) == 12

def test_vector_columns_detection(self):
"""Test that list columns are correctly identified as vector columns."""
@@ -97,8 +98,8 @@ def test_vector_columns_detection(self):
assert "scalar" in feature_aggs.numerical_columns
assert "text" in feature_aggs.str_columns

# Check aggregator count: 1 numerical * 7 + 1 categorical * 2 + 1 vector * 2 = 11
assert len(feature_aggs.aggregators) == 11
# Check aggregator count: 1 numerical * 8 + 1 categorical * 2 + 1 vector * 2 = 12
assert len(feature_aggs.aggregators) == 12

def test_mixed_column_types(self):
"""Test dataset with all column types mixed together."""
@@ -130,8 +131,8 @@ def test_mixed_column_types(self):
# bool_val should be treated as numerical (integer-like)
assert "bool_val" in feature_aggs.numerical_columns

# Check aggregator count: 3 numerical * 7 + 1 categorical * 2 + 1 vector * 2 = 25
assert len(feature_aggs.aggregators) == 25
# Check aggregator count: 3 numerical * 8 + 1 categorical * 2 + 1 vector * 2 = 28
assert len(feature_aggs.aggregators) == 28

def test_column_filtering(self):
"""Test that only specified columns are included when columns parameter is provided."""
@@ -151,8 +152,8 @@ def test_column_filtering(self):
assert "col3" in feature_aggs.vector_columns
assert "col4" not in feature_aggs.numerical_columns

# Check aggregator count: 1 numerical * 7 + 1 vector * 2 = 9
assert len(feature_aggs.aggregators) == 9
# Check aggregator count: 1 numerical * 8 + 1 vector * 2 = 10
assert len(feature_aggs.aggregators) == 10

def test_empty_dataset_schema(self):
"""Test behavior with empty dataset that has no schema."""
@@ -199,8 +200,8 @@ def test_unsupported_column_types(self):
assert "unsupported_binary" not in feature_aggs.str_columns
assert "unsupported_binary" not in feature_aggs.vector_columns

# Check aggregator count: 1 numerical * 7 + 1 categorical * 2 = 9
assert len(feature_aggs.aggregators) == 9
# Check aggregator count: 1 numerical * 8 + 1 categorical * 2 = 10
assert len(feature_aggs.aggregators) == 10

def test_aggregator_types_verification(self):
"""Test that the correct aggregator types are generated for each column type."""
@@ -215,16 +216,17 @@ def test_aggregator_types_verification(self):
# Check that we have the right types of aggregators
agg_names = [agg.name for agg in feature_aggs.aggregators]

# Numerical aggregators should include all 7 types
# Numerical aggregators should include all 8 types
num_agg_names = [name for name in agg_names if "num" in name]
assert len(num_agg_names) == 7
assert len(num_agg_names) == 8
assert any("count" in name.lower() for name in num_agg_names)
assert any("mean" in name.lower() for name in num_agg_names)
assert any("min" in name.lower() for name in num_agg_names)
assert any("max" in name.lower() for name in num_agg_names)
assert any("std" in name.lower() for name in num_agg_names)
assert any("missing" in name.lower() for name in num_agg_names)
assert any("zero" in name.lower() for name in num_agg_names)
assert any("approx_quantile" in name.lower() for name in num_agg_names)

# Categorical aggregators should include count and missing percentage
cat_agg_names = [name for name in agg_names if "cat" in name]
@@ -246,7 +248,7 @@ def test_aggregator_instances_verification(self):

# Find aggregators for the numerical column
num_aggs = [agg for agg in feature_aggs.aggregators if "num" in agg.name]
assert len(num_aggs) == 7
assert len(num_aggs) == 8

# Check that we have the right aggregator types
agg_types = [type(agg) for agg in num_aggs]
@@ -257,6 +259,7 @@ def test_aggregator_instances_verification(self):
assert Std in agg_types
assert MissingValuePercentage in agg_types
assert ZeroPercentage in agg_types
assert ApproximateQuantile in agg_types

# Find aggregators for the categorical column
cat_aggs = [agg for agg in feature_aggs.aggregators if "cat" in agg.name]
@@ -352,8 +355,8 @@ def test_large_dataset_performance(self):
assert "category" in feature_aggs.str_columns
assert "vector" in feature_aggs.vector_columns

# Check aggregator count: 2 numerical * 7 + 1 categorical * 2 + 1 vector * 2 = 18
assert len(feature_aggs.aggregators) == 18
# Check aggregator count: 2 numerical * 8 + 1 categorical * 2 + 1 vector * 2 = 20
assert len(feature_aggs.aggregators) == 20


class TestIndividualAggregatorFunctions:
@@ -363,7 +366,7 @@ def test_numerical_aggregators(self):
"""Test numerical_aggregators function."""
aggs = numerical_aggregators("test_column")

assert len(aggs) == 7
assert len(aggs) == 8
assert all(hasattr(agg, "get_target_column") for agg in aggs)
assert all(agg.get_target_column() == "test_column" for agg in aggs)

1 change: 1 addition & 0 deletions python/requirements/ml/data-test-requirements.txt
@@ -23,3 +23,4 @@ pyiceberg[sql-sqlite]==0.9.0
clickhouse-connect
pybase64
hudi==0.4.0
datasketches
3 changes: 3 additions & 0 deletions python/requirements_compiled.txt
@@ -421,6 +421,8 @@ datasets==3.6.0
# -r python/requirements/ml/data-test-requirements.txt
# -r python/requirements/ml/train-requirements.txt
# evaluate
datasketches==5.2.0
# via -r python/requirements/ml/data-test-requirements.txt
debugpy==1.8.0
# via ipykernel
decorator==5.1.1
@@ -1247,6 +1249,7 @@ numpy==1.26.4
# cupy-cuda12x
# dask
# datasets
# datasketches
# decord
# deepspeed
# dm-control