Skip to content

Commit

Permalink
Merge pull request #310 from astronomy-commons/add-arrow-schema-to-ca…
Browse files Browse the repository at this point in the history
…talog

Store arrow schema when reading catalogs
  • Loading branch information
camposandro authored Jul 25, 2024
2 parents 047600e + 6977e02 commit 2c625ce
Show file tree
Hide file tree
Showing 11 changed files with 165 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, Tuple, Union

import pandas as pd
import pyarrow as pa
from mocpy import MOC
from typing_extensions import TypeAlias

Expand Down Expand Up @@ -35,11 +36,14 @@ def __init__(
join_pixels: JoinPixelInputTypes,
catalog_path=None,
moc: MOC | None = None,
schema: pa.Schema | None = None,
storage_options: Union[Dict[Any, Any], None] = None,
) -> None:
if not catalog_info.catalog_type == CatalogType.ASSOCIATION:
raise ValueError("Catalog info `catalog_type` must be 'association'")
super().__init__(catalog_info, pixels, catalog_path, moc=moc, storage_options=storage_options)
super().__init__(
catalog_info, pixels, catalog_path, moc=moc, schema=schema, storage_options=storage_options
)
self.join_info = self._get_partition_join_info_from_pixels(join_pixels)

def get_join_pixels(self) -> pd.DataFrame:
Expand Down
12 changes: 10 additions & 2 deletions src/hipscat/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import pyarrow as pa
from mocpy import MOC
from typing_extensions import TypeAlias

Expand Down Expand Up @@ -46,6 +47,7 @@ def __init__(
pixels: PixelInputTypes,
catalog_path: str = None,
moc: MOC | None = None,
schema: pa.Schema | None = None,
storage_options: Union[Dict[Any, Any], None] = None,
) -> None:
"""Initializes a Catalog
Expand All @@ -56,16 +58,22 @@ def __init__(
list of HealpixPixel, `PartitionInfo object`, or a `PixelTree` object
catalog_path: If the catalog is stored on disk, specify the location of the catalog
Does not load the catalog from this path, only store as metadata
storage_options: dictionary that contains abstract filesystem credentials
moc (mocpy.MOC): MOC object representing the coverage of the catalog
schema (pa.Schema): The pyarrow schema for the catalog
storage_options: dictionary that contains abstract filesystem credentials
"""
if catalog_info.catalog_type not in self.HIPS_CATALOG_TYPES:
raise ValueError(
f"Catalog info `catalog_type` must be one of "
f"{', '.join([t.value for t in self.HIPS_CATALOG_TYPES])}"
)
super().__init__(
catalog_info, pixels, catalog_path=catalog_path, moc=moc, storage_options=storage_options
catalog_info,
pixels,
catalog_path=catalog_path,
moc=moc,
schema=schema,
storage_options=storage_options,
)

def filter_by_cone(self, ra: float, dec: float, radius_arcsec: float) -> Catalog:
Expand Down
33 changes: 30 additions & 3 deletions src/hipscat/catalog/healpix_dataset/healpix_dataset.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from __future__ import annotations

import dataclasses
import warnings
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import pandas as pd
import pyarrow as pa
from mocpy import MOC
from typing_extensions import Self, TypeAlias

import hipscat.pixel_math.healpix_shim as hp
from hipscat.catalog.dataset import BaseCatalogInfo, Dataset
from hipscat.catalog.partition_info import PartitionInfo
from hipscat.io import FilePointer, file_io, paths
from hipscat.io.file_io import read_parquet_metadata
from hipscat.pixel_math import HealpixPixel
from hipscat.pixel_tree import PixelAlignment, PixelAlignmentType
from hipscat.pixel_tree.moc_filter import filter_by_moc
Expand Down Expand Up @@ -39,6 +42,7 @@ def __init__(
pixels: PixelInputTypes,
catalog_path: str = None,
moc: MOC | None = None,
schema: pa.Schema | None = None,
storage_options: Union[Dict[Any, Any], None] = None,
) -> None:
"""Initializes a Catalog
Expand All @@ -49,13 +53,15 @@ def __init__(
list of HealpixPixel, `PartitionInfo object`, or a `PixelTree` object
catalog_path: If the catalog is stored on disk, specify the location of the catalog
Does not load the catalog from this path, only store as metadata
storage_options: dictionary that contains abstract filesystem credentials
moc (mocpy.MOC): MOC object representing the coverage of the catalog
schema (pa.Schema): The pyarrow schema for the catalog
storage_options: dictionary that contains abstract filesystem credentials
"""
super().__init__(catalog_info, catalog_path=catalog_path, storage_options=storage_options)
self.partition_info = self._get_partition_info_from_pixels(pixels)
self.pixel_tree = self._get_pixel_tree_from_pixels(pixels)
self.moc = moc
self.schema = schema

def get_healpix_pixels(self) -> List[HealpixPixel]:
"""Get healpix pixel objects for all pixels contained in the catalog.
Expand Down Expand Up @@ -101,6 +107,7 @@ def _read_kwargs(
) -> dict:
kwargs = super()._read_kwargs(catalog_base_dir, storage_options=storage_options)
kwargs["moc"] = cls._read_moc_from_point_map(catalog_base_dir, storage_options)
kwargs["schema"] = cls._read_schema_from_metadata(catalog_base_dir, storage_options)
return kwargs

@classmethod
Expand All @@ -118,10 +125,30 @@ def _read_moc_from_point_map(
orders = np.full(ipix.shape, order)
return MOC.from_healpix_cells(ipix, orders, order)

@classmethod
def _read_schema_from_metadata(
    cls, catalog_base_dir: FilePointer, storage_options: dict | None = None
) -> pa.Schema | None:
    """Reads the schema information stored in the _common_metadata or _metadata files.

    The `_common_metadata` file is preferred; the `_metadata` file is only used
    when `_common_metadata` does not exist.

    Args:
        catalog_base_dir: pointer to the base directory of the catalog
        storage_options: dictionary that contains abstract filesystem credentials

    Returns:
        The arrow schema converted from the parquet metadata, or None (after
        emitting a warning) when neither metadata file exists.
    """
    schema_file = paths.get_common_metadata_pointer(catalog_base_dir)
    if not file_io.does_file_or_directory_exist(schema_file, storage_options=storage_options):
        # Fall back to `_metadata`. Its existence is only checked when needed, so
        # no extra filesystem round-trip is made once `_common_metadata` is found.
        schema_file = paths.get_parquet_metadata_pointer(catalog_base_dir)
        if not file_io.does_file_or_directory_exist(schema_file, storage_options=storage_options):
            # The trailing space keeps the two sentences from running together
            # in the emitted warning text.
            warnings.warn(
                "_common_metadata or _metadata files not found for this catalog. "
                "The arrow schema will not be set."
            )
            return None
    metadata = read_parquet_metadata(schema_file, storage_options=storage_options)
    return metadata.schema.to_arrow_schema()

@classmethod
def _check_files_exist(cls, catalog_base_dir: FilePointer, storage_options: dict = None):
super()._check_files_exist(catalog_base_dir, storage_options=storage_options)

partition_info_file = paths.get_partition_info_pointer(catalog_base_dir)
metadata_file = paths.get_parquet_metadata_pointer(catalog_base_dir)
if not (
Expand Down Expand Up @@ -170,7 +197,7 @@ def filter_by_moc(self, moc: MOC) -> Self:
filtered_tree = filter_by_moc(self.pixel_tree, moc)
filtered_moc = self.moc.intersection(moc) if self.moc is not None else None
filtered_catalog_info = dataclasses.replace(self.catalog_info, total_rows=None)
return self.__class__(filtered_catalog_info, filtered_tree, moc=filtered_moc)
return self.__class__(filtered_catalog_info, filtered_tree, moc=filtered_moc, schema=self.schema)

def align(
self, other_cat: Self, alignment_type: PixelAlignmentType = PixelAlignmentType.INNER
Expand Down
12 changes: 10 additions & 2 deletions src/hipscat/catalog/margin_cache/margin_catalog.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import pyarrow as pa
from mocpy import MOC
from typing_extensions import Self, TypeAlias

Expand Down Expand Up @@ -30,6 +31,7 @@ def __init__(
pixels: PixelInputTypes,
catalog_path: str = None,
moc: MOC | None = None,
schema: pa.Schema | None = None,
storage_options: dict | None = None,
) -> None:
"""Initializes a Margin Catalog
Expand All @@ -40,13 +42,19 @@ def __init__(
list of HealpixPixel, `PartitionInfo object`, or a `PixelTree` object
catalog_path: If the catalog is stored on disk, specify the location of the catalog
Does not load the catalog from this path, only store as metadata
storage_options: dictionary that contains abstract filesystem credentials
moc (mocpy.MOC): MOC object representing the coverage of the catalog
schema (pa.Schema): The pyarrow schema for the catalog
storage_options: dictionary that contains abstract filesystem credentials
"""
if catalog_info.catalog_type != CatalogType.MARGIN:
raise ValueError(f"Catalog info `catalog_type` must equal {CatalogType.MARGIN}")
super().__init__(
catalog_info, pixels, catalog_path=catalog_path, moc=moc, storage_options=storage_options
catalog_info,
pixels,
catalog_path=catalog_path,
moc=moc,
schema=schema,
storage_options=storage_options,
)

def filter_by_moc(self, moc: MOC) -> Self:
Expand Down
71 changes: 71 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List

import pandas as pd
import pyarrow as pa
import pytest

from hipscat.catalog.association_catalog.association_catalog_info import AssociationCatalogInfo
Expand Down Expand Up @@ -171,6 +172,76 @@ def index_catalog_info_with_extra() -> dict:
}


@pytest.fixture
def small_sky_schema() -> pa.Schema:
    """Arrow schema that the `small_sky` catalog is expected to expose when loaded."""
    columns = [
        ("id", pa.int64()),
        ("ra", pa.float64()),
        ("dec", pa.float64()),
        ("ra_error", pa.int64()),
        ("dec_error", pa.int64()),
        ("Norder", pa.uint8()),
        ("Dir", pa.uint64()),
        ("Npix", pa.uint64()),
        ("_hipscat_index", pa.uint64()),
    ]
    # `pa.schema` accepts (name, type) tuples as well as `pa.Field` objects.
    return pa.schema(columns)


@pytest.fixture
def small_sky_source_schema() -> pa.Schema:
    """Arrow schema that the `small_sky_source` catalog is expected to expose when loaded."""
    columns = [
        ("source_id", pa.int64()),
        ("source_ra", pa.float64()),
        ("source_dec", pa.float64()),
        ("mjd", pa.float64()),
        ("mag", pa.float64()),
        ("band", pa.string()),
        ("object_id", pa.int64()),
        ("object_ra", pa.float64()),
        ("object_dec", pa.float64()),
        ("Norder", pa.uint8()),
        ("Dir", pa.uint64()),
        ("Npix", pa.uint64()),
        ("_hipscat_index", pa.uint64()),
    ]
    # `pa.schema` accepts (name, type) tuples as well as `pa.Field` objects.
    return pa.schema(columns)


@pytest.fixture
def association_catalog_schema() -> pa.Schema:
    """Arrow schema that the association test catalog is expected to expose when loaded."""
    # Every column of this schema is a 64-bit signed integer.
    column_names = ("Norder", "Npix", "join_Norder", "join_Npix")
    return pa.schema([(name, pa.int64()) for name in column_names])


@pytest.fixture
def margin_catalog_schema() -> pa.Schema:
    """Arrow schema that the margin test catalog is expected to expose when loaded."""
    columns = [
        ("id", pa.int64()),
        ("ra", pa.float64()),
        ("dec", pa.float64()),
        ("ra_error", pa.int64()),
        ("dec_error", pa.int64()),
        ("Norder", pa.uint8()),
        ("Dir", pa.uint64()),
        ("Npix", pa.uint64()),
        ("_hipscat_index", pa.uint64()),
        ("margin_Norder", pa.uint8()),
        ("margin_Dir", pa.uint64()),
        ("margin_Npix", pa.uint64()),
    ]
    # `pa.schema` accepts (name, type) tuples as well as `pa.Field` objects.
    return pa.schema(columns)


@pytest.fixture
def dataset_path(test_data_dir) -> str:
    """Location of the `info_only/dataset` directory under the test data directory."""
    info_only_dir = test_data_dir / "info_only"
    return info_only_dir / "dataset"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os

import pandas as pd
import pyarrow as pa
import pytest

from hipscat.catalog import CatalogType
Expand Down Expand Up @@ -49,7 +50,9 @@ def test_different_join_pixels_type(association_catalog_info, association_catalo
pd.testing.assert_frame_equal(catalog.get_join_pixels(), association_catalog_join_pixels)


def test_read_from_file(association_catalog_path, association_catalog_join_pixels):
def test_read_from_file(
association_catalog_path, association_catalog_join_pixels, association_catalog_schema
):
catalog = read_from_hipscat(association_catalog_path)

assert isinstance(catalog, AssociationCatalog)
Expand All @@ -66,6 +69,9 @@ def test_read_from_file(association_catalog_path, association_catalog_join_pixel
assert info.join_catalog == "small_sky_order1"
assert info.join_column == "id"

assert isinstance(catalog.schema, pa.Schema)
assert catalog.schema.equals(association_catalog_schema)


def test_empty_directory(tmp_path, association_catalog_info_data, association_catalog_join_pixels):
"""Test loading empty or incomplete data"""
Expand Down Expand Up @@ -121,5 +127,6 @@ def test_csv_round_trip(tmp_path, association_catalog_info_data, association_cat
part_info = PartitionJoinInfo(association_catalog_join_pixels)
part_info.write_to_csv(catalog_path=catalog_path)

catalog = read_from_hipscat(catalog_path)
with pytest.warns(UserWarning, match="_common_metadata or _metadata files not found"):
catalog = read_from_hipscat(catalog_path)
pd.testing.assert_frame_equal(catalog.get_join_pixels(), association_catalog_join_pixels)
6 changes: 5 additions & 1 deletion tests/hipscat/catalog/margin_cache/test_margin_catalog.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os

import pyarrow as pa
import pytest

from hipscat.catalog import CatalogType, MarginCatalog, PartitionInfo
Expand Down Expand Up @@ -32,7 +33,7 @@ def test_wrong_catalog_info_type(catalog_info, margin_catalog_pixels):
MarginCatalog(catalog_info, margin_catalog_pixels)


def test_read_from_file(margin_catalog_path, margin_catalog_pixels):
def test_read_from_file(margin_catalog_path, margin_catalog_pixels, margin_catalog_schema):
catalog = read_from_hipscat(margin_catalog_path)

assert isinstance(catalog, MarginCatalog)
Expand All @@ -50,6 +51,9 @@ def test_read_from_file(margin_catalog_path, margin_catalog_pixels):
assert info.primary_catalog == "small_sky_order1"
assert info.margin_threshold == 7200

assert isinstance(catalog.schema, pa.Schema)
assert catalog.schema.equals(margin_catalog_schema)


# pylint: disable=duplicate-code
def test_empty_directory(tmp_path, margin_cache_catalog_info_data, margin_catalog_pixels):
Expand Down
11 changes: 9 additions & 2 deletions tests/hipscat/catalog/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import astropy.units as u
import numpy as np
import pyarrow as pa
import pytest
from mocpy import MOC

Expand Down Expand Up @@ -79,14 +80,17 @@ def test_get_pixels_list(catalog_info, catalog_pixels):
assert pixels == catalog_pixels


def test_load_catalog_small_sky(small_sky_dir):
def test_load_catalog_small_sky(small_sky_dir, small_sky_schema):
    """Load the single-pixel `small_sky` catalog and check its name, pixel count and schema."""
    catalog = read_from_hipscat(small_sky_dir)

    assert isinstance(catalog, Catalog)
    assert catalog.catalog_name == "small_sky"
    healpix_pixels = catalog.get_healpix_pixels()
    assert len(healpix_pixels) == 1

    # The schema read from the catalog's metadata must match the expected fixture.
    loaded_schema = catalog.schema
    assert isinstance(loaded_schema, pa.Schema)
    assert loaded_schema.equals(small_sky_schema)


def test_load_catalog_small_sky_order1(small_sky_order1_dir):
"""Instantiate a catalog with 4 pixels"""
Expand All @@ -109,14 +113,17 @@ def test_load_catalog_small_sky_order1_moc(small_sky_order1_dir):
assert np.all(cat.moc.flatten() == np.where(counts_skymap > 0))


def test_load_catalog_small_sky_source(small_sky_source_dir):
def test_load_catalog_small_sky_source(small_sky_source_dir, small_sky_source_schema):
    """Load the 14-pixel `small_sky_source` catalog and check its name, pixel count and schema."""
    catalog = read_from_hipscat(small_sky_source_dir)

    assert isinstance(catalog, Catalog)
    assert catalog.catalog_name == "small_sky_source"
    healpix_pixels = catalog.get_healpix_pixels()
    assert len(healpix_pixels) == 14

    # The schema read from the catalog's metadata must match the expected fixture.
    loaded_schema = catalog.schema
    assert isinstance(loaded_schema, pa.Schema)
    assert loaded_schema.equals(small_sky_source_schema)


def test_max_coverage_order(small_sky_order1_catalog):
assert small_sky_order1_catalog.get_max_coverage_order() >= small_sky_order1_catalog.moc.max_order
Expand Down
Loading

0 comments on commit 2c625ce

Please sign in to comment.