Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to create a PixelDistribution from a sky footprint. #806

Merged
merged 2 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/toast/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ install(FILES
rng.py
qarray.py
fft.py
footprint.py
healpix.py
weather.py
schedule.py
Expand Down
160 changes: 160 additions & 0 deletions src/toast/footprint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright (c) 2024-2025 by the parties listed in the AUTHORS file.
# All rights reserved. Use of this source code is governed by
# a BSD-style license that can be found in the LICENSE file.

import numpy as np
import healpy as hp
import astropy.io.fits as af

from .pixels import PixelData, PixelDistribution
from .pixels_io_healpix import read_healpix
from .pixels_io_wcs import read_wcs_fits


def footprint_distribution(
healpix_nside=None,
healpix_nside_submap=None,
healpix_submap_file=None,
healpix_coverage_file=None,
wcs_coverage_file=None,
comm=None,
):
"""Create a PixelDistribution from a pre-defined sky footprint.

Usually a PixelDistribution object is created by passing through the detector
pointing and determining the locally hit submaps. However, this can be expensive
if the data must be loaded from disk and if there is insufficient memory to hold
the detector data in a persistent way.

This function provides a way for building a PixelDistribution where all processes
have the full footprint locally, regardless of whether their local detector
pointing hits all submaps. For high resolution sky products with many processes
per node, use of shared memory may be required.

Only certain combinations of options are supported:

1. If `wcs_coverage_file` is specified, that is taken to be the WCS projection
of the coverage. The number of pixels is set by the extent of the WCS, NOT
the actual pixel values. The number of submaps is set to one. All healpix
options should be None.
2. If `healpix_coverage_file` is specified, the NSIDE of the file is used to
define the number of pixels and the non-zero pixel values along with
`healpix_nside_submap` is used to compute the nonzero submaps in this coverage.
The same hit submaps are used across all processes.
3. If `healpix_submap_file` is specified, non-zero values represent the hit
submaps. `healpix_nside` is then used to define the NSIDE and the number of
pixels.
4. If neither file is specified, `healpix_nside` is used to define the NSIDE and
number of pixels. `healpix_nside_submap` is used to compute the number of
submaps. All submaps are considered hit in this case.

Args:
healpix_nside (int): If specified, the NSIDE of the coverage map.
healpix_nside_submap (int): If specified, the NSIDE of the submaps.
healpix_coverage_file (str): The path to a coverage map.
healpix_submap_file (str): The path to a map with the submaps to use.
wcs_coverage_file (str): The path to a WCS coverage map in the primary HDU.
comm (MPI.Comm): The MPI communicator or None.

Returns:
(PixelDistribution): The output pixel distribution.

"""
rank = 0
if comm is not None:
rank = comm.rank

if wcs_coverage_file is not None:
# Load a WCS projection
if (
healpix_nside is not None
or healpix_nside_submap is not None
or healpix_coverage_file is not None
or healpix_submap_file is not None
):
msg = "If loading a wcs coverage file, all other options should be None"
raise RuntimeError(msg)
n_pix = None
if rank == 0:
hdulist = af.open(wcs_coverage_file)
n_pix = np.prod(hdulist[0].data.shape)
hdulist.close()
del hdulist
if comm is not None:
n_pix = comm.bcast(n_pix, root=0)
n_submap = 1
local_submaps = [0]
elif healpix_coverage_file is not None:
if healpix_nside_submap is None:
msg = "You must specify the submap NSIDE to use with the coverage file"
raise RuntimeError(msg)
n_pix = None
n_submap = None
local_submaps = None
if rank == 0:
hpix_data = read_healpix(healpix_coverage_file, field=(0,), nest=True)
nside = hp.get_nside(hpix_data)
n_pix = 12 * nside**2
n_submap = 12 * healpix_nside_submap**2

# Find hit pixels
hit_pixels = np.logical_and(
hpix_data != 0,
hp.mask_good(hpix_data),
)
unhit_pixels = np.logical_not(hit_pixels)

# Set map data to one or zero so we can find hit submaps
hpix_data[hit_pixels] = 1
hpix_data[unhit_pixels] = 0

# Degrade to submap resolution
submap_data = hp.ud_grade(
hpix_data, healpix_nside_submap, order_in="NEST", order_out="NEST"
)

# Find hit submaps
hit_submaps = submap_data > 0
local_submaps = np.arange(12 * healpix_nside_submap**2, dtype=np.int32)[
hit_submaps
]
if comm is not None:
n_pix = comm.bcast(n_pix, root=0)
n_submap = comm.bcast(n_submap, root=0)
local_submaps = comm.bcast(local_submaps, root=0)
elif healpix_submap_file is not None:
if healpix_nside is None:
msg = "You must specify the coverage NSIDE to use with the submap file"
raise RuntimeError(msg)
n_pix = None
n_submap = None
local_submaps = None
if rank == 0:
submap_data = read_healpix(healpix_submap_file, field=(0,), nest=True)
nside_submap = hp.npix2nside(len(submap_data))
n_submap = 12 * nside_submap**2
n_pix = 12 * healpix_nside**2

# Find hit submaps
hit_submaps = np.logical_and(
submap_data != 0,
hp.mask_good(submap_data),
)
local_submaps = np.arange(n_submap, dtype=np.int32)[hit_submaps]
if comm is not None:
n_pix = comm.bcast(n_pix, root=0)
n_submap = comm.bcast(n_submap, root=0)
local_submaps = comm.bcast(local_submaps, root=0)
else:
if healpix_nside is None:
msg = "No files specified, you must set healpix_nside"
raise RuntimeError(msg)
if healpix_nside_submap is None:
msg = "No files specified, you must set healpix_nside_submap"
raise RuntimeError(msg)
n_pix = 12 * healpix_nside**2
n_submap = 12 * healpix_nside_submap**2
local_submaps = np.arange(n_submap, dtype=np.int32)
return PixelDistribution(
n_pix=n_pix, n_submap=n_submap, local_submaps=local_submaps, comm=comm
)
1 change: 1 addition & 0 deletions src/toast/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ install(FILES
math_misc.py
qarray.py
fft.py
footprint.py
healpix.py
config.py
observation.py
Expand Down
129 changes: 129 additions & 0 deletions src/toast/tests/footprint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright (c) 2024-2025 by the parties listed in the AUTHORS file.
# All rights reserved. Use of this source code is governed by
# a BSD-style license that can be found in the LICENSE file.

import os

import astropy.io.fits as af
import healpy as hp
import numpy as np
import numpy.testing as nt
from astropy import units as u

from .. import ops
from ..footprint import footprint_distribution
from ._helpers import create_outdir
from .mpi import MPITestCase


class FootprintTest(MPITestCase):
def setUp(self):
fixture_name = os.path.splitext(os.path.basename(__file__))[0]
self.outdir = create_outdir(self.comm, fixture_name)
self.wcs_proj_dims = (1000, 500)
self.nside = 128
self.nside_submap = 16

def tearDown(self):
pass

def _create_wcs_coverage(self, outfile):
res_deg = (0.01, 0.01)
dims = self.wcs_proj_dims
center_deg = (130.0, -30.0)
wcs, wcs_shape = ops.PixelsWCS.create_wcs(
coord="EQU",
proj="CAR",
center_deg=center_deg,
bounds_deg=None,
res_deg=res_deg,
dims=dims,
)
if self.comm is None or self.comm.rank == 0:
pixdata = np.ones((1, wcs_shape[1], wcs_shape[0]), dtype=np.float32)
header = wcs.to_header()
hdu = af.PrimaryHDU(data=pixdata, header=header)
hdu.writeto(outfile)
return wcs, wcs_shape

def _create_healpix_coverage(self, nside, nside_submap, outfile, is_submap=False):
n_submap = 12 * nside_submap**2
hit_submaps = None
if self.comm is None or self.comm.rank == 0:
# Randomly select some submaps
subvals = [True, False]
hit_submaps = np.random.choice(subvals, size=(n_submap,)).astype(bool)
if self.comm is not None:
hit_submaps = self.comm.bcast(hit_submaps, root=0)
if self.comm is None or self.comm.rank == 0:
sub_pixels = np.zeros(n_submap, dtype=np.int32)
sub_pixels[hit_submaps] = 1
if is_submap:
# Write it out and we are done
hp.write_map(outfile, sub_pixels, nest=True)
else:
# Compute the full-resolution map and write that
pixels = hp.ud_grade(
sub_pixels, nside, order_in="NEST", order_out="NEST"
)
hp.write_map(outfile, pixels, nest=True)
return hit_submaps

def test_wcs(self):
footfile = os.path.join(self.outdir, "wcs_footprint.fits")
wcs, wcs_shape = self._create_wcs_coverage(footfile)
dist = footprint_distribution(wcs_coverage_file=footfile, comm=self.comm)

# Check that the distribution has expected properties
n_pix = np.prod(wcs_shape)
self.assertTrue(dist.n_submap == 1)
self.assertTrue(n_pix == dist.n_pix)
self.assertTrue(n_pix == dist.n_pix_submap)
self.assertTrue(dist.local_submaps[0] == 0)

def test_healpix(self):
n_submap = 12 * self.nside_submap**2
n_pix = 12 * self.nside**2
n_pix_submap = n_pix // n_submap

# Create a distribution from healpix footprint file
footfile = os.path.join(self.outdir, "healpix_footprint.fits")
hit_submaps = self._create_healpix_coverage(
self.nside, self.nside_submap, footfile, is_submap=False
)
dist = footprint_distribution(
healpix_coverage_file=footfile,
healpix_nside_submap=self.nside_submap,
comm=self.comm,
)
self.assertTrue(dist.n_submap == n_submap)
self.assertTrue(dist.n_pix == n_pix)
self.assertTrue(dist.n_pix_submap == n_pix_submap)
check_submaps = np.arange(n_submap, dtype=np.int64)[hit_submaps]
self.assertTrue(np.array_equal(dist.local_submaps, check_submaps))

# Create a distribution from healpix submap footprint file
footfile = os.path.join(self.outdir, "healpix_submap_footprint.fits")
hit_submaps = self._create_healpix_coverage(
self.nside, self.nside_submap, footfile, is_submap=True
)
dist = footprint_distribution(
healpix_submap_file=footfile, healpix_nside=self.nside, comm=self.comm
)
self.assertTrue(dist.n_submap == n_submap)
self.assertTrue(dist.n_pix == n_pix)
self.assertTrue(dist.n_pix_submap == n_pix_submap)
check_submaps = np.arange(n_submap, dtype=np.int64)[hit_submaps]
self.assertTrue(np.array_equal(dist.local_submaps, check_submaps))

# Now check manual creation of a full-sky healpix footprint
dist = footprint_distribution(
healpix_nside=self.nside,
healpix_nside_submap=self.nside_submap,
comm=self.comm,
)
self.assertTrue(dist.n_submap == n_submap)
self.assertTrue(dist.n_pix == n_pix)
self.assertTrue(dist.n_pix_submap == n_pix_submap)
check_submaps = np.arange(n_submap, dtype=np.int64)
self.assertTrue(np.array_equal(dist.local_submaps, check_submaps))
3 changes: 1 addition & 2 deletions src/toast/tests/ops_pointing_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def check_hits(self, prefix, pixels, data):
)

def test_wcs(self):
return
# Test basic creation of WCS projections and plotting
res_deg = (0.01, 0.01)
dims = self.proj_dims
Expand Down Expand Up @@ -161,7 +160,6 @@ def test_wcs(self):
plot_wcs_maps(hitfile=outfile)

def test_projections(self):
return
centers = list()
for lon in [130.0, 180.0]:
for lat in [-40.0, 0.0]:
Expand All @@ -184,6 +182,7 @@ def test_projections(self):
# Verify that we can change the projection traits in various ways.
# First use non-auto_bounds to create one boresight pointing per
# pixel.
pixels.auto_bounds = False
pixels.center = center
pixels.bounds = ()
pixels.resolution = (0.02 * u.degree, 0.02 * u.degree)
Expand Down
2 changes: 2 additions & 0 deletions src/toast/tests/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from . import dist as test_dist
from . import env as test_env
from . import fft as test_fft
from . import footprint as test_footprint
from . import healpix as test_healpix
from . import instrument as test_instrument
from . import intervals as test_intervals
Expand Down Expand Up @@ -181,6 +182,7 @@ def test(name=None, verbosity=2):
suite.addTest(loader.loadTestsFromModule(test_instrument))
suite.addTest(loader.loadTestsFromModule(test_pixels))
suite.addTest(loader.loadTestsFromModule(test_weather))
suite.addTest(loader.loadTestsFromModule(test_footprint))

suite.addTest(loader.loadTestsFromModule(test_observation))
suite.addTest(loader.loadTestsFromModule(test_dist))
Expand Down