Introduce aggregate_column_statistics
delucchi-cmu committed Oct 18, 2024
1 parent 632292a commit d5ce329
Showing 2 changed files with 103 additions and 0 deletions.
89 changes: 89 additions & 0 deletions src/hats/io/parquet_metadata.py
@@ -7,6 +7,7 @@
from typing import List

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.dataset as pds
import pyarrow.parquet as pq
@@ -190,3 +191,91 @@ def read_row_group_fragments(metadata_file: str):

for frag in dataset.get_fragments():
yield from frag.row_groups
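For orientation (not part of this diff): a minimal sketch of how the generator above is typically consumed. Each yielded item is a pyarrow RowGroupInfo, and the `metadata_file` path here is an assumed input.

# Hypothetical consumption sketch: sum row counts across all row groups
# yielded by read_row_group_fragments for a given _metadata file.
num_rows = sum(rg.num_rows for rg in read_row_group_fragments(metadata_file))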


def aggregate_column_statistics(
    metadata_file: str | Path | UPath,
    exclude_hats_columns: bool = True,
    exclude_columns: List[str] | None = None,
    include_columns: List[str] | None = None,
):
    """Read footer statistics in parquet metadata, and report on global min/max values.

    Args:
        metadata_file (str | Path | UPath): path to `_metadata` file
        exclude_hats_columns (bool): exclude HATS spatial and partitioning fields
            from the statistics. Defaults to True.
        exclude_columns (List[str]): additional columns to exclude from the statistics.
            Defaults to None.
        include_columns (List[str]): if specified, only return statistics for the column
            names provided. Defaults to None, and returns all non-hats columns.

    Returns:
        pd.DataFrame with a row per column, reporting the min_value, max_value,
        and null_count aggregated over all row groups.
    """
total_metadata = file_io.read_parquet_metadata(metadata_file)
num_row_groups = total_metadata.num_row_groups
first_row_group = total_metadata.row_group(0)

if include_columns is None:
include_columns = []

if exclude_columns is None:
exclude_columns = []
if exclude_hats_columns:
exclude_columns.extend(["Norder", "Dir", "Npix", "_healpix_29"])

    column_names = [
        first_row_group.column(col).path_in_schema for col in range(first_row_group.num_columns)
    ]
    good_column_indexes = [
        index
        for index, name in enumerate(column_names)
        if (len(include_columns) == 0 or name in include_columns) and name not in exclude_columns
    ]
column_names = [column_names[i] for i in good_column_indexes]
extrema = [
(
first_row_group.column(col).statistics.min,
first_row_group.column(col).statistics.max,
first_row_group.column(col).statistics.null_count,
)
for col in good_column_indexes
]

for row_group_index in range(1, num_row_groups):
row_group = total_metadata.row_group(row_group_index)
row_stats = [
(
row_group.column(col).statistics.min,
row_group.column(col).statistics.max,
row_group.column(col).statistics.null_count,
)
for col in good_column_indexes
]
        ## Awkward, but avoids extra copies and comparisons against None.
extrema = [
(
(
min(extrema[col][0], row_stats[col][0])
if row_stats[col][0] is not None
else extrema[col][0]
),
(
max(extrema[col][1], row_stats[col][1])
if row_stats[col][1] is not None
else extrema[col][1]
),
extrema[col][2] + row_stats[col][2],
)
            for col in range(len(good_column_indexes))
]

stats_lists = np.array(extrema).T

frame = pd.DataFrame(
{
"column_names": column_names,
"min_value": stats_lists[0],
"max_value": stats_lists[1],
"null_count": stats_lists[2],
}
)
return frame
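
For context (not part of the commit), a minimal usage sketch of the new helper. It assumes a HATS catalog directory containing a `_metadata` file; the `catalog_dir` path is illustrative, while `paths.get_parquet_metadata_pointer` is the same pointer helper used in the tests below.

from hats.io import paths
from hats.io.parquet_metadata import aggregate_column_statistics

# Assumed catalog location; any HATS catalog with a _metadata file works.
catalog_dir = "path/to/small_sky_order1"
metadata_file = paths.get_parquet_metadata_pointer(catalog_dir)

# One row per surviving column: global min, global max, and total null count.
stats = aggregate_column_statistics(metadata_file, include_columns=["ra", "dec"])
print(stats)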
14 changes: 14 additions & 0 deletions tests/hats/io/test_parquet_metadata.py
@@ -8,6 +8,7 @@

from hats.io import file_io, paths
from hats.io.parquet_metadata import (
aggregate_column_statistics,
get_healpix_pixel_from_metadata,
read_row_group_fragments,
row_group_stat_single_value,
@@ -144,6 +145,19 @@ def test_row_group_fragments_with_dir(small_sky_order1_dir):
assert num_row_groups == 4


def test_aggregate_column_statistics(small_sky_order1_dir):
partition_info_file = paths.get_parquet_metadata_pointer(small_sky_order1_dir)

result_frame = aggregate_column_statistics(partition_info_file)
assert len(result_frame) == 5

result_frame = aggregate_column_statistics(partition_info_file, exclude_hats_columns=False)
assert len(result_frame) == 9

result_frame = aggregate_column_statistics(partition_info_file, include_columns=["ra", "dec"])
assert len(result_frame) == 2


def test_row_group_stats(small_sky_dir):
partition_info_file = paths.get_parquet_metadata_pointer(small_sky_dir)
first_row_group = next(read_row_group_fragments(partition_info_file))
