Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Roman prism and grism bands #104

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions skycatalogs/objects/base_object.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Sequence, Iterable
from collections import namedtuple
from packaging import version
import os
import itertools
import numpy as np
Expand All @@ -8,6 +9,12 @@
from galsim.roman import shortwave_bands as roman_shortwave_bands
from galsim.roman import getBandpasses as roman_getBandpasses

if version.parse(galsim.version) >= version.parse('2.6.0'):
from galsim.roman import non_imaging_bands as roman_non_imaging_bands
else:
# galsim 2.6.0 and earlier do not have non_imaging_bands
roman_non_imaging_bands = []

from skycatalogs.utils.translate_utils import form_object_string
from skycatalogs.utils.config_utils import Config

Expand All @@ -22,7 +29,9 @@
'load_lsst_bandpasses', 'load_roman_bandpasses']

LSST_BANDS = ('ugrizy')
ROMAN_BANDS = roman_shortwave_bands+roman_longwave_bands
ROMAN_BANDS = roman_shortwave_bands + roman_longwave_bands \
+ roman_non_imaging_bands


# global for easy access for code run within mp

Expand Down Expand Up @@ -73,13 +82,15 @@ def load_lsst_bandpasses():
return lsst_bandpasses


def load_roman_bandpasses():
def load_roman_bandpasses(**kwargs):
'''
Read in Roman bandpasses from standard place, trim, and store in global dict
Returns: The dict
'''
global roman_bandpasses
roman_bandpasses = roman_getBandpasses()
if version.parse(galsim.version) < version.parse('2.6.0'):
kwargs.pop("include_non_imaging_bands", None)
roman_bandpasses = roman_getBandpasses(**kwargs)
return roman_bandpasses


Expand Down
6 changes: 5 additions & 1 deletion skycatalogs/objects/snana_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
'H158': 'H',
'F184': 'F',
'K213': 'K',
'W146': 'W'}
'W146': 'W',
'SNPrism': 'S',
'Grism_0thOrder': 'G0',
'Grism_1stOrder': 'G1',
}

class SnanaObject(BaseObject):
_type_name = 'snana'
Expand Down
2 changes: 1 addition & 1 deletion skycatalogs/skyCatalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ def open_catalog(config_file, mp=False, skycatalog_root=None, verbose=False):
'''
# Get bandpasses in case we need to compute fluxes
_ = load_lsst_bandpasses()
_ = load_roman_bandpasses()
_ = load_roman_bandpasses(include_non_imaging_bands=True)

from skycatalogs.utils.config_utils import open_config_file

Expand Down
34 changes: 27 additions & 7 deletions skycatalogs/utils/parquet_schema_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import pyarrow as pa
import logging
from galsim import version as galsim_version
from packaging import version

__all__ = ['make_galaxy_schema', 'make_galaxy_flux_schema',
'make_pointsource_schema', 'make_star_flux_schema']


def _add_roman_fluxes(fields):
def _add_roman_fluxes(fields, include_all_bands=False):
fields += [pa.field('roman_flux_W146', pa.float32(), True),
pa.field('roman_flux_R062', pa.float32(), True),
pa.field('roman_flux_Z087', pa.float32(), True),
Expand All @@ -14,6 +16,12 @@ def _add_roman_fluxes(fields):
pa.field('roman_flux_H158', pa.float32(), True),
pa.field('roman_flux_F184', pa.float32(), True),
pa.field('roman_flux_K213', pa.float32(), True)]

if include_all_bands and (version.parse(galsim_version) >= version.parse("2.6.0")):
fields += [pa.field('roman_flux_SNPrism', pa.float32(), True),
pa.field('roman_flux_Grism_0thOrder', pa.float32(), True),
pa.field('roman_flux_Grism_1stOrder', pa.float32(), True)]

return fields


Expand Down Expand Up @@ -101,7 +109,8 @@ def make_galaxy_schema(logname, sed_subdir=False, knots=True,


def make_galaxy_flux_schema(logname, galaxy_type='cosmodc2',
include_roman_flux=False):
include_roman_flux=False,
include_nonimaging_roman_bands=False):
'''
Will make a separate parquet file with lsst flux for each band
and galaxy id for joining with the main galaxy file
Expand All @@ -117,11 +126,15 @@ def make_galaxy_flux_schema(logname, galaxy_type='cosmodc2',
pa.field('lsst_flux_z', pa.float32(), True),
pa.field('lsst_flux_y', pa.float32(), True)]
if include_roman_flux:
fields = _add_roman_fluxes(fields)
fields = _add_roman_fluxes(
fields,
include_all_bands=include_nonimaging_roman_bands
)
return pa.schema(fields)


def make_star_flux_schema(logname, include_roman_flux=False):
def make_star_flux_schema(logname, include_roman_flux=False,
include_nonimaging_roman_bands=False):
'''
Will make a separate parquet file with lsst flux for each band
and id for joining with the main star file
Expand All @@ -136,7 +149,10 @@ def make_star_flux_schema(logname, include_roman_flux=False):
pa.field('lsst_flux_z', pa.float32(), True),
pa.field('lsst_flux_y', pa.float32(), True)]
if include_roman_flux:
fields = _add_roman_fluxes(fields)
fields = _add_roman_fluxes(
fields,
include_all_bands=include_nonimaging_roman_bands,
)
return pa.schema(fields)


Expand Down Expand Up @@ -172,7 +188,8 @@ def make_pointsource_schema():
return pa.schema(fields)


def make_pointsource_flux_schema(logname, include_roman_flux=False):
def make_pointsource_flux_schema(logname, include_roman_flux=False,
include_nonimaging_roman_bands=False):
'''
Will make a separate parquet file with lsst flux for each band
and id for joining with the main star file.
Expand All @@ -190,5 +207,8 @@ def make_pointsource_flux_schema(logname, include_roman_flux=False):
pa.field('lsst_flux_y', pa.float32(), True),
pa.field('mjd', pa.float64(), True)]
if include_roman_flux:
fields = _add_roman_fluxes(fields)
fields = _add_roman_fluxes(
fields,
include_all_bands=include_nonimaging_roman_bands,
)
return pa.schema(fields)