From fd04b9645a1d5c361a4d676257e9ab8a7354fe1c Mon Sep 17 00:00:00 2001 From: Andrea Zonca Date: Mon, 27 Jul 2020 18:09:37 -0700 Subject: [PATCH 1/3] import from jwst 0.13.7 --- iris_pipeline/resample/__init__.py | 4 + iris_pipeline/resample/gwcs_drizzle.py | 457 ++++++++++++++++++ iris_pipeline/resample/resample.py | 177 +++++++ iris_pipeline/resample/resample_spec.py | 297 ++++++++++++ iris_pipeline/resample/resample_spec_step.py | 126 +++++ iris_pipeline/resample/resample_step.py | 206 ++++++++ iris_pipeline/resample/resample_utils.py | 193 ++++++++ iris_pipeline/resample/tests/__init__.py | 0 .../resample/tests/test_interface.py | 16 + .../resample/tests/test_resample_spec.py | 86 ++++ iris_pipeline/resample/tests/test_utils.py | 34 ++ 11 files changed, 1596 insertions(+) create mode 100644 iris_pipeline/resample/__init__.py create mode 100644 iris_pipeline/resample/gwcs_drizzle.py create mode 100644 iris_pipeline/resample/resample.py create mode 100644 iris_pipeline/resample/resample_spec.py create mode 100755 iris_pipeline/resample/resample_spec_step.py create mode 100755 iris_pipeline/resample/resample_step.py create mode 100644 iris_pipeline/resample/resample_utils.py create mode 100644 iris_pipeline/resample/tests/__init__.py create mode 100644 iris_pipeline/resample/tests/test_interface.py create mode 100644 iris_pipeline/resample/tests/test_resample_spec.py create mode 100644 iris_pipeline/resample/tests/test_utils.py diff --git a/iris_pipeline/resample/__init__.py b/iris_pipeline/resample/__init__.py new file mode 100644 index 0000000..15e44b8 --- /dev/null +++ b/iris_pipeline/resample/__init__.py @@ -0,0 +1,4 @@ +from .resample_step import ResampleStep +from .resample_spec_step import ResampleSpecStep + +__all__ = ['ResampleStep', 'ResampleSpecStep'] diff --git a/iris_pipeline/resample/gwcs_drizzle.py b/iris_pipeline/resample/gwcs_drizzle.py new file mode 100644 index 0000000..26a12b1 --- /dev/null +++ b/iris_pipeline/resample/gwcs_drizzle.py @@ -0,0 +1,457 @@ +import numpy as np + +from drizzle import util +from drizzle import doblot +from drizzle import cdrizzle +from . import resample_utils + +import logging +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) + + +class GWCSDrizzle: + """ + Combine images using the drizzle algorithm + """ + def __init__(self, product, outwcs=None, single=False, + wt_scl="exptime", pixfrac=1.0, kernel="square", + fillval="INDEF"): + """ + Create a new Drizzle output object and set the drizzle parameters. + + Parameters + ---------- + + product : str, optional + A data model containing results from a previous run. The three + extensions SCI, WHT, and CTX contain the combined image, total counts + and image id bitmap, repectively. The WCS of the combined image is + also read from the SCI extension. + + outwcs : `gwcs.WCS` + The world coordinate system (WCS) of the resampled image. If not + provided, the WCS is taken from product. + + wt_scl : str, optional + How each input image should be scaled. The choices are `exptime` + which scales each image by its exposure time, `expsq` which scales + each image by the exposure time squared, or an empty string, which + allows each input image to be scaled individually. + + pixfrac : float, optional + The fraction of a pixel that the pixel flux is confined to. The + default value of 1 has the pixel flux evenly spread across the image. + A value of 0.5 confines it to half a pixel in the linear dimension, + so the flux is confined to a quarter of the pixel area when the square + kernel is used. + + kernel : str, optional + The name of the kernel used to combine the inputs. The choice of + kernel controls the distribution of flux over the kernel. The kernel + names are: "square", "gaussian", "point", "tophat", "turbo", "lanczos2", + and "lanczos3". The square kernel is the default. + + fillval : str, otional + The value a pixel is set to in the output if the input image does + not overlap it. The default value of INDEF does not set a value. + """ + + # Initialize the object fields + self.outsci = None + self.outwht = None + self.outcon = None + + self.outexptime = 0.0 + self.uniqid = 0 + + self.wt_scl = wt_scl + self.kernel = kernel + self.fillval = fillval + self.pixfrac = pixfrac + + self.sciext = "SCI" + self.whtext = "WHT" + self.conext = "CON" + + out_units = "cps" + + self.outexptime = product.meta.resample.product_exposure_time or 0.0 + + self.outsci = product.data + if outwcs: + self.outwcs = outwcs + else: + self.outwcs = product.meta.wcs + + self.outwht = product.wht + self.outcon = product.con + + if self.outcon.ndim == 2: + self.outcon = np.reshape(self.outcon, (1, + self.outcon.shape[0], + self.outcon.shape[1])) + + elif self.outcon.ndim == 3: + pass + + else: + raise ValueError("Drizzle context image has wrong dimensions: \ + {0}".format(product)) + + # Check field values + if not self.outwcs: + raise ValueError("Either an existing file or wcs must be supplied") + + if util.is_blank(self.wt_scl): + self.wt_scl = '' + elif self.wt_scl != "exptime" and self.wt_scl != "expsq": + raise ValueError("Illegal value for wt_scl: %s" % self.wt_scl) + + if out_units == "counts": + np.divide(self.outsci, self.outexptime, self.outsci) + elif out_units != "cps": + raise ValueError("Illegal value for out_units: %s" % out_units) + + def add_image(self, insci, inwcs, inwht=None, + xmin=0, xmax=0, ymin=0, ymax=0, pscale_ratio=1.0, + expin=1.0, in_units="cps", wt_scl=1.0): + """ + Combine an input image with the output drizzled image. + + Instead of reading the parameters from a fits file, you can set + them by calling this lower level method. `Add_fits_file` calls + this method after doing its setup. + + Parameters + ---------- + + insci : array + A 2d numpy array containing the input image to be drizzled. + it is an error to not supply an image. + + inwcs : wcs + The world coordinate system of the input image. This is + used to convert the pixels to the output coordinate system. + + inwht : array, optional + A 2d numpy array containing the pixel by pixel weighting. + Must have the same dimenstions as insci. If none is supplied, + the weghting is set to one. + + xmin : float, optional + This and the following three parameters set a bounding rectangle + on the output image. Only pixels on the output image inside this + rectangle will have their flux updated. Xmin sets the minimum value + of the x dimension. The x dimension is the dimension that varies + quickest on the image. If the value is zero or less, no minimum will + be set in the x dimension. All four parameters are zero based, + counting starts at zero. + + xmax : float, optional + Sets the maximum value of the x dimension on the bounding box + of the ouput image. If the value is zero or less, no maximum will + be set in the x dimension. + + ymin : float, optional + Sets the minimum value in the y dimension on the bounding box. The + y dimension varies less rapidly than the x and represents the line + index on the output image. If the value is zero or less, no minimum + will be set in the y dimension. + + ymax : float, optional + Sets the maximum value in the y dimension. If the value is zero or + less, no maximum will be set in the y dimension. + + expin : float, optional + The exposure time of the input image, a positive number. The + exposure time is used to scale the image if the units are counts and + to scale the image weighting if the drizzle was initialized with + wt_scl equal to "exptime" or "expsq." + + in_units : str, optional + The units of the input image. The units can either be "counts" + or "cps" (counts per second.) If the value is counts, before using + the input image it is scaled by dividing it by the exposure time. + + wt_scl : float, optional + If drizzle was initialized with wt_scl left blank, this value will + set a scaling factor for the pixel weighting. If drizzle was + initialized with wt_scl set to "exptime" or "expsq", the exposure time + will be used to set the weight scaling and the value of this parameter + will be ignored. + """ + insci = insci.astype(np.float32) + + if inwht is None: + inwht = np.ones(insci.shape, dtype=insci.dtype) + else: + inwht = inwht.astype(np.float32) + + if self.wt_scl == "exptime": + wt_scl = expin + elif self.wt_scl == "expsq": + wt_scl = expin * expin + + wt_scl = 1.0 # hard-coded for JWST count-rate data + self.increment_id() + + dodrizzle(insci, inwcs, inwht, self.outwcs, + self.outsci, self.outwht, self.outcon, + expin, in_units, wt_scl, + pscale_ratio=pscale_ratio, uniqid=self.uniqid, + xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax, + pixfrac=self.pixfrac, kernel=self.kernel, + fillval=self.fillval) + + def blot_image(self, blotwcs, interp='poly5', sinscl=1.0): + """ + Resample the output image using an input world coordinate system. + + Parameters + ---------- + + blotwcs : wcs + The world coordinate system to resample on. + + interp : str, optional + The type of interpolation used in the resampling. The + possible values are "nearest" (nearest neighbor interpolation), + "linear" (bilinear interpolation), "poly3" (cubic polynomial + interpolation), "poly5" (quintic polynomial interpolation), + "sinc" (sinc interpolation), "lan3" (3rd order Lanczos + interpolation), and "lan5" (5th order Lanczos interpolation). + + sincscl : float, optional + The scaling factor for sinc interpolation. + """ + + util.set_pscale(blotwcs) + self.outsci = doblot.doblot(self.outsci, self.outwcs, blotwcs, + 1.0, interp=interp, sinscl=sinscl) + + self.outwcs = blotwcs + + def increment_id(self): + """ + Increment the id count and add a plane to the context image if needed + + Drizzle tracks which input images contribute to the output image + by setting a bit in the corresponding pixel in the context image. + The uniqid indicates which bit. So it must be incremented each time + a new image is added. Each plane in the context image can hold 32 bits, + so after each 32 images, a new plane is added to the context. + """ + + # Compute what plane of the context image this input would + # correspond to: + planeid = int(self.uniqid / 32) + + # Add a new plane to the context image if planeid overflows + + if self.outcon.shape[0] == planeid: + plane = np.zeros_like(self.outcon[0]) + plane = plane.reshape((1, plane.shape[0], plane.shape[1])) + self.outcon = np.concatenate((self.outcon, plane)) + + # Increment the id + self.uniqid += 1 + + +def dodrizzle(insci, input_wcs, inwht, + output_wcs, outsci, outwht, outcon, + expin, in_units, wt_scl, + pscale_ratio=1.0, uniqid=1, + xmin=0, xmax=0, ymin=0, ymax=0, + pixfrac=1.0, kernel='square', fillval="INDEF"): + """ + Low level routine for performing 'drizzle' operation on one image. + + The interface is compatible with STScI code. All images are Python + ndarrays, instead of filenames. File handling (input and output) is + performed by the calling routine. + + Parameters + ---------- + + insci : 2d array + A 2d numpy array containing the input image to be drizzled. + + input_wcs : gwcs.WCS object + The world coordinate system of the input image. + + inwht : 2d array + A 2d numpy array containing the pixel by pixel weighting. + Must have the same dimensions as insci. If none is supplied, + the weghting is set to one. + + output_wcs : gwcs.WCS object + The world coordinate system of the output image. + + outsci : 2d array + A 2d numpy array containing the output image produced by + drizzling. On the first call it should be set to zero. + Subsequent calls it will hold the intermediate results + + outwht : 2d array + A 2d numpy array containing the output counts. On the first + call it should be set to zero. On subsequent calls it will + hold the intermediate results. + + outcon : 2d or 3d array, optional + A 2d or 3d numpy array holding a bitmap of which image was an input + for each output pixel. Should be integer zero on first call. + Subsequent calls hold intermediate results. + + expin : float + The exposure time of the input image, a positive number. The + exposure time is used to scale the image if the units are counts. + + in_units : str + The units of the input image. The units can either be "counts" + or "cps" (counts per second.) + + wt_scl : float + A scaling factor applied to the pixel by pixel weighting. + + wcslin_pscale : float, optional + The pixel scale of the input image. Conceptually, this is the + linear dimension of a side of a pixel in the input image, but it + is not limited to this and can be set to change how the drizzling + algorithm operates. + + uniqid : int, optional + The id number of the input image. Should be one the first time + this function is called and incremented by one on each subsequent + call. + + xmin : float, optional + This and the following three parameters set a bounding rectangle + on the input image. Only pixels on the input image inside this + rectangle will have their flux added to the output image. Xmin + sets the minimum value of the x dimension. The x dimension is the + dimension that varies quickest on the image. If the value is zero, + no minimum will be set in the x dimension. All four parameters are + zero based, counting starts at zero. + + xmax : float, optional + Sets the maximum value of the x dimension on the bounding box + of the input image. If the value is zero, no maximum will + be set in the x dimension, the full x dimension of the output + image is the bounding box. + + ymin : float, optional + Sets the minimum value in the y dimension on the bounding box. The + y dimension varies less rapidly than the x and represents the line + index on the input image. If the value is zero, no minimum will be + set in the y dimension. + + ymax : float, optional + Sets the maximum value in the y dimension. If the value is zero, no + maximum will be set in the y dimension, the full x dimension + of the output image is the bounding box. + + pixfrac : float, optional + The fraction of a pixel that the pixel flux is confined to. The + default value of 1 has the pixel flux evenly spread across the image. + A value of 0.5 confines it to half a pixel in the linear dimension, + so the flux is confined to a quarter of the pixel area when the square + kernel is used. + + kernel: str, optional + The name of the kernel used to combine the input. The choice of + kernel controls the distribution of flux over the kernel. The kernel + names are: "square", "gaussian", "point", "tophat", "turbo", "lanczos2", + and "lanczos3". The square kernel is the default. + + fillval: str, optional + The value a pixel is set to in the output if the input image does + not overlap it. The default value of INDEF does not set a value. + + Returns + ------- + A tuple with three values: a version string, the number of pixels + on the input image that do not overlap the output image, and the + number of complete lines on the input image that do not overlap the + output input image. + + """ + + # Insure that the fillval parameter gets properly interpreted for use with tdriz + if util.is_blank(str(fillval)): + fillval = 'INDEF' + else: + fillval = str(fillval) + + if in_units == 'cps': + expscale = 1.0 + else: + expscale = expin + + # Add input weight image if it was not passed in + + if (insci.dtype > np.float32): + insci = insci.astype(np.float32) + + if inwht is None: + inwht = np.ones_like(insci) + + if xmax is None or xmax == xmin: + xmax = insci.shape[1] + if ymax is None or ymax == ymin: + ymax = insci.shape[0] + + # Compute what plane of the context image this input would + # correspond to: + planeid = int((uniqid - 1) / 32) + + # Check if the context image has this many planes + if outcon.ndim == 3: + nplanes = outcon.shape[0] + elif outcon.ndim == 2: + nplanes = 1 + else: + nplanes = 0 + + if nplanes <= planeid: + raise IndexError("Not enough planes in drizzle context image") + + # Alias context image to the requested plane if 3d + if outcon.ndim == 3: + outcon = outcon[planeid] + + # Compute the mapping between the input and output pixel coordinates + # for use in drizzle.cdrizzle.tdriz + pixmap = resample_utils.calc_gwcs_pixmap(input_wcs, output_wcs, insci.shape) + # pixmap[np.isnan(pixmap)] = -10 + # print("Number of NaNs: ", len(np.isnan(pixmap)) / 2) + # inwht[np.isnan(pixmap[:,:,0])] = 0. + + log.debug("Pixmap shape: {}".format(pixmap[:,:,0].shape)) + log.debug("Input Sci shape: {}".format(insci.shape)) + log.debug("Output Sci shape: {}".format(outsci.shape)) + + # y_mid = pixmap.shape[0] // 2 + # x_mid = pixmap.shape[1] // 2 + # print("x slice: ", pixmap[y_mid,:,0]) + # print("y slice: ", pixmap[:,x_mid,1]) + # print("insci: ", insci) + + # Call 'drizzle' to perform image combination + log.info('Drizzling {} --> {}'.format(insci.shape, outsci.shape)) + _vers, nmiss, nskip = cdrizzle.tdriz( + insci, inwht, pixmap, + outsci, outwht, outcon, + uniqid=uniqid, + xmin=xmin, xmax=xmax, + ymin=ymin, ymax=ymax, + scale=pscale_ratio, + pixfrac=pixfrac, + kernel=kernel, + in_units=in_units, + expscale=expscale, + wtscale=wt_scl, + fillstr=fillval + ) + + return _vers, nmiss, nskip diff --git a/iris_pipeline/resample/resample.py b/iris_pipeline/resample/resample.py new file mode 100644 index 0000000..e096eeb --- /dev/null +++ b/iris_pipeline/resample/resample.py @@ -0,0 +1,177 @@ +import logging +from collections import OrderedDict +import numpy as np + +from .. import datamodels + +from . import gwcs_drizzle +from . import resample_utils +from ..model_blender import blendmeta + +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) + +__all__ = ["ResampleData"] + + +class ResampleData: + """ + This is the controlling routine for the resampling process. + It loads and sets the various input data and parameters needed by + the drizzle function and then calls the C-based cdriz.tdriz function + to do the actual resampling. + + Notes + ----- + This routine performs the following operations:: + + 1. Extracts parameter settings from input model, such as pixfrac, + weight type, exposure time (if relevant), and kernel, and merges + them with any user-provided values. + 2. Creates output WCS based on input images and define mapping function + between all input arrays and the output array. + 3. Initializes all output arrays, including WHT and CTX arrays. + 4. Passes all information for each input chip to drizzle function. + 5. Updates output data model with output arrays from drizzle, including + (eventually) a record of metadata from all input models. + """ + def __init__(self, input_models, output=None, **pars): + """ + Parameters + ---------- + input_models : list of objects + list of data models, one for each input image + + output : str + filename for output + """ + self.input_models = input_models + self.drizpars = pars + if output is None: + output = input_models.meta.resample.output + self.output_filename = output + + # Define output WCS based on all inputs, including a reference WCS + self.output_wcs = resample_utils.make_output_wcs(self.input_models) + log.debug('Output mosaic size: {}'.format(self.output_wcs.data_size)) + self.blank_output = datamodels.DrizProductModel(self.output_wcs.data_size) + + # update meta data and wcs + self.blank_output.update(input_models[0]) + self.blank_output.meta.wcs = self.output_wcs + + self.output_models = datamodels.ModelContainer() + + def update_driz_outputs(self): + """ Define output arrays for use with drizzle operations. + """ + numchips = len(self.input_models) + numplanes = (numchips // 32) + 1 + + # Replace CONTEXT array with full set of planes needed for all inputs + outcon = np.zeros((numplanes, self.output_wcs.data_size[0], + self.output_wcs.data_size[1]), dtype=np.int32) + self.blank_output.con = outcon + + def blend_output_metadata(self, output_model): + """Create new output metadata based on blending all input metadata.""" + # Run fitsblender on output product + output_file = output_model.meta.filename + + log.info('Blending metadata for {}'.format(output_file)) + blendmeta.blendmodels(output_model, inputs=self.input_models, + output=output_file) + + def do_drizzle(self): + """ Perform drizzling operation on input images's to create a new output + """ + # Set up information about what outputs we need to create: single or final + # Key: value from metadata for output/observation name + # Value: full filename for output file + driz_outputs = OrderedDict() + + # Look for input configuration parameter telling the code to run + # in single-drizzle mode (mosaic all detectors in a single observation) + if self.drizpars['single']: + driz_outputs = self.input_models.group_names + exposures = self.input_models.models_grouped + group_exptime = [] + for exposure in exposures: + group_exptime.append(exposure[0].meta.exposure.exposure_time) + else: + driz_outputs = [self.output_filename] + exposures = [self.input_models] + + total_exposure_time = 0.0 + for exposure in exposures: + total_exposure_time += exposure[0].meta.exposure.exposure_time + group_exptime = [total_exposure_time] + pointings = len(self.input_models.group_names) + + for obs_product, exposure, texptime in zip(driz_outputs, exposures, + group_exptime): + output_model = self.blank_output.copy() + output_model.meta.filename = obs_product + saved_model_type = output_model.meta.model_type + + if self.drizpars['blendheaders']: + self.blend_output_metadata(output_model) + output_model.meta.model_type = saved_model_type + + exposure_times = {'start': [], 'end': []} + + # Initialize the output with the wcs + driz = gwcs_drizzle.GWCSDrizzle(output_model, + single=self.drizpars['single'], + pixfrac=self.drizpars['pixfrac'], + kernel=self.drizpars['kernel'], + fillval=self.drizpars['fillval']) + + for n, img in enumerate(exposure): + exposure_times['start'].append(img.meta.exposure.start_time) + exposure_times['end'].append(img.meta.exposure.end_time) + + # apply sky subtraction + blevel = img.meta.background.level + if not img.meta.background.subtracted and blevel is not None: + img.data -= blevel + + outwcs_pscale = output_model.meta.wcsinfo.cdelt1 + wcslin_pscale = img.meta.wcsinfo.cdelt1 + + inwht = resample_utils.build_driz_weight(img, + weight_type=self.drizpars['weight_type'], + good_bits=self.drizpars['good_bits']) + driz.add_image(img.data, img.meta.wcs, inwht=inwht, + expin=img.meta.exposure.exposure_time, + pscale_ratio=outwcs_pscale / wcslin_pscale) + + # Update some basic exposure time values based on all the inputs + output_model.meta.exposure.exposure_time = texptime + output_model.meta.exposure.start_time = min(exposure_times['start']) + output_model.meta.exposure.end_time = max(exposure_times['end']) + output_model.meta.resample.product_exposure_time = texptime + output_model.meta.resample.weight_type = self.drizpars['weight_type'] + output_model.meta.resample.pointings = pointings + + self.update_fits_wcs(output_model) + + self.output_models.append(output_model) + + def update_fits_wcs(self, model): + """ + Update FITS WCS keywords of the resampled image. + """ + transform = model.meta.wcs.forward_transform + model.meta.wcsinfo.crpix1 = -transform[0].offset.value + 1 + model.meta.wcsinfo.crpix2 = -transform[1].offset.value + 1 + model.meta.wcsinfo.cdelt1 = transform[3].factor.value + model.meta.wcsinfo.cdelt2 = transform[4].factor.value + model.meta.wcsinfo.ra_ref = transform[6].lon.value + model.meta.wcsinfo.dec_ref = transform[6].lat.value + model.meta.wcsinfo.crval1 = model.meta.wcsinfo.ra_ref + model.meta.wcsinfo.crval2 = model.meta.wcsinfo.dec_ref + model.meta.wcsinfo.pc1_1 = transform[2].matrix.value[0][0] + model.meta.wcsinfo.pc1_2 = transform[2].matrix.value[0][1] + model.meta.wcsinfo.pc2_1 = transform[2].matrix.value[1][0] + model.meta.wcsinfo.pc2_2 = transform[2].matrix.value[1][1] diff --git a/iris_pipeline/resample/resample_spec.py b/iris_pipeline/resample/resample_spec.py new file mode 100644 index 0000000..f6081f8 --- /dev/null +++ b/iris_pipeline/resample/resample_spec.py @@ -0,0 +1,297 @@ +import logging +from collections import OrderedDict + +import numpy as np + +from astropy import coordinates as coord +from astropy import units as u +from astropy.modeling.models import (Mapping, Tabular1D, Linear1D, + Pix2Sky_TAN, RotateNative2Celestial) +from astropy.modeling.fitting import LinearLSQFitter +from gwcs import wcstools, WCS +from gwcs import coordinate_frames as cf + +from .. import datamodels +from . import gwcs_drizzle +from . import resample_utils + + +CRBIT = np.uint32(datamodels.dqflags.pixel['JUMP_DET']) + +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) + + +class ResampleSpecData: + """ + This is the controlling routine for the resampling process. + It loads and sets the various input data and parameters needed by + the drizzle function and then calls the C-based cdriz.tdriz function + to do the actual resampling. + + Notes + ----- + This routine performs the following operations:: + + 1. Extracts parameter settings from input model, such as pixfrac, + weight type, exposure time (if relevant), and kernel, and merges + them with any user-provided values. + 2. Creates output WCS based on input images and define mapping function + between all input arrays and the output array. + 3. Initializes all output arrays, including WHT and CTX arrays. + 4. Passes all information for each input chip to drizzle function. + 5. Updates output data model with output arrays from drizzle, including + (eventually) a record of metadata from all input models. + """ + + def __init__(self, input_models, output=None, **pars): + """ + Parameters + ---------- + input_models : list of objects + list of data models, one for each input image + + output : str + filename for output + """ + self.input_models = input_models + if output is None: + output = input_models.meta.resample.output + + self.drizpars = pars + + self.pscale_ratio = 1. + self.blank_output = None + + # Define output WCS based on all inputs, including a reference WCS + # wcslist = [m.meta.wcs for m in self.input_models] + self.output_wcs = self.build_interpolated_output_wcs() + self.blank_output = datamodels.DrizProductModel(self.data_size) + + self.blank_output.update(datamodels.ImageModel(self.input_models[0]._instance)) + self.blank_output.meta.wcs = self.output_wcs + self.output_models = datamodels.ModelContainer() + + def build_interpolated_output_wcs(self, refmodel=None): + """ + Create a spatial/spectral WCS output frame + + Creates output frame by linearly fitting RA, Dec along the slit and + producing a lookup table to interpolate wavelengths in the dispersion + direction. + + Parameters + ---------- + refmodel : `~jwst.datamodels.DataModel` + The reference input image from which the fiducial WCS is created. + If not specified, the first image in self.input_models is used. + + Returns + ------- + output_wcs : `~gwcs.WCS` object + A gwcs WCS object defining the output frame WCS + """ + if refmodel is None: + refmodel = self.input_models[0] + refwcs = refmodel.meta.wcs + bb = refwcs.bounding_box + + grid = wcstools.grid_from_bounding_box(bb) + ra, dec, lam = np.array(refwcs(*grid)) + lon = np.nanmean(ra) + lat = np.nanmean(dec) + tan = Pix2Sky_TAN() + native2celestial = RotateNative2Celestial(lon, lat, 180) + undist2sky = tan | native2celestial + x_tan, y_tan = undist2sky.inverse(ra, dec) + + spectral_axis = find_dispersion_axis(lam) + spatial_axis = spectral_axis ^ 1 + + # Compute the wavelength array, trimming NaNs from the ends + wavelength_array = np.nanmedian(lam, axis=spectral_axis) + wavelength_array = wavelength_array[~np.isnan(wavelength_array)] + + # Compute RA and Dec up the slit (spatial direction) at the center + # of the dispersion. Use spectral_axis to determine slicing dimension + lam_center_index = int((bb[spectral_axis][1] - bb[spectral_axis][0]) / 2) + if not spectral_axis: + x_tan_array = x_tan.T[lam_center_index] + y_tan_array = y_tan.T[lam_center_index] + else: + x_tan_array = x_tan[lam_center_index] + y_tan_array = y_tan[lam_center_index] + x_tan_array = x_tan_array[~np.isnan(x_tan_array)] + y_tan_array = y_tan_array[~np.isnan(y_tan_array)] + + fitter = LinearLSQFitter() + fit_model = Linear1D() + pix_to_ra = fitter(fit_model, np.arange(x_tan_array.shape[0]), x_tan_array) + pix_to_dec = fitter(fit_model, np.arange(y_tan_array.shape[0]), y_tan_array) + + # Tabular interpolation model, pixels -> lambda + pix_to_wavelength = Tabular1D(lookup_table=wavelength_array, + bounds_error=False, fill_value=None, name='pix2wavelength') + + # Tabular models need an inverse explicitly defined. + # If the wavelength array is decending instead of ascending, both + # points and lookup_table need to be reversed in the inverse transform + # for scipy.interpolate to work properly + points = wavelength_array + lookup_table = np.arange(wavelength_array.shape[0]) + if not np.all(np.diff(wavelength_array) > 0): + points = points[::-1] + lookup_table = lookup_table[::-1] + pix_to_wavelength.inverse = Tabular1D(points=points, + lookup_table=lookup_table, + bounds_error=False, fill_value=None, name='wavelength2pix') + + # For the input mapping, duplicate the spatial coordinate + mapping = Mapping((spatial_axis, spatial_axis, spectral_axis)) + + # Sometimes the slit is perpendicular to the RA or Dec axis. + # For example, if the slit is perpendicular to RA, that means + # the slope of pix_to_ra will be nearly zero, so make sure + # mapping.inverse uses pix_to_dec.inverse. The auto definition + # of mapping.inverse is to use the 2nd spatial coordinate, i.e. Dec. + if np.isclose(pix_to_dec.slope, 0, atol=1e-8): + mapping_tuple = (0, 1) + # Account for vertical or horizontal dispersion on detector + if spatial_axis: + mapping.inverse = Mapping(mapping_tuple[::-1]) + else: + mapping.inverse = Mapping(mapping_tuple) + + # The final transform + transform = mapping | (pix_to_ra & pix_to_dec | undist2sky) & pix_to_wavelength + + det = cf.Frame2D(name='detector', axes_order=(0, 1)) + sky = cf.CelestialFrame(name='sky', axes_order=(0, 1), + reference_frame=coord.ICRS()) + spec = cf.SpectralFrame(name='spectral', axes_order=(2,), + unit=(u.micron,), axes_names=('wavelength',)) + world = cf.CompositeFrame([sky, spec], name='world') + + pipeline = [(det, transform), + (world, None)] + + output_wcs = WCS(pipeline) + + # compute the output array size in WCS axes order, i.e. (x, y) + output_array_size = [0, 0] + output_array_size[spectral_axis] = len(wavelength_array) + output_array_size[spatial_axis] = len(x_tan_array) + + # turn the size into a numpy shape in (y, x) order + self.data_size = tuple(output_array_size[::-1]) + + bounding_box = resample_utils.wcs_bbox_from_shape(self.data_size) + output_wcs.bounding_box = bounding_box + + return output_wcs + + def do_drizzle(self, **pars): + """ Perform drizzling operation on input images's to create a new output + """ + # Set up information about what outputs we need to create: single or final + # Key: value from metadata for output/observation name + # Value: full filename for output file + driz_outputs = OrderedDict() + + # Look for input configuration parameter telling the code to run + # in single-drizzle mode (mosaic all detectors in a single observation?) + if self.drizpars['single']: + driz_outputs = ['{0}_resamp.fits'.format(g) for g in self.input_models.group_names] + model_groups = self.input_models.models_grouped + group_exptime = [] + for group in model_groups: + group_exptime.append(group[0].meta.exposure.exposure_time) + else: + final_output = self.input_models.meta.resample.output + driz_outputs = [final_output] + model_groups = [self.input_models] + + total_exposure_time = 0.0 + for group in self.input_models.models_grouped: + total_exposure_time += group[0].meta.exposure.exposure_time + group_exptime = [total_exposure_time] + + pointings = len(self.input_models.group_names) + # Now, generate each output for all input_models + for obs_product, group, texptime in zip(driz_outputs, model_groups, group_exptime): + output_model = self.blank_output.copy() + output_model.meta.wcs = self.output_wcs + + bb = resample_utils.wcs_bbox_from_shape(output_model.data.shape) + output_model.meta.wcs.bounding_box = bb + output_model.meta.filename = obs_product + + exposure_times = {'start': [], 'end': []} + + outwcs = output_model.meta.wcs + + # Initialize the output with the wcs + driz = gwcs_drizzle.GWCSDrizzle(output_model, + outwcs=outwcs, + single=self.drizpars['single'], + pixfrac=self.drizpars['pixfrac'], + kernel=self.drizpars['kernel'], + fillval=self.drizpars['fillval']) + + for n, img in enumerate(group): + exposure_times['start'].append(img.meta.exposure.start_time) + exposure_times['end'].append(img.meta.exposure.end_time) + + inwht = resample_utils.build_driz_weight(img, + weight_type=self.drizpars['weight_type'], + good_bits=self.drizpars['good_bits']) + if hasattr(img, 'name'): + log.info('Resampling slit {} {}'.format(img.name, self.data_size)) + else: + log.info('Resampling slit {}'.format(self.data_size)) + + in_wcs = img.meta.wcs + driz.add_image(img.data, in_wcs, inwht=inwht, + expin=img.meta.exposure.exposure_time, + pscale_ratio=self.pscale_ratio) + + # Update some basic exposure time values based on all the inputs + output_model.meta.exposure.exposure_time = texptime + output_model.meta.exposure.start_time = min(exposure_times['start']) + output_model.meta.exposure.end_time = max(exposure_times['end']) + output_model.meta.resample.product_exposure_time = texptime + output_model.meta.resample.weight_type = self.drizpars['weight_type'] + output_model.meta.resample.pointings = pointings + + # Update mutlislit slit info on the output_model + for attr in ['name', 'xstart', 'xsize', 'ystart', 'ysize', + 'slitlet_id', 'source_id', 'source_name', 'source_alias', + 'stellarity', 'source_type', 'source_xpos', 'source_ypos', + 'shutter_state']: + try: + val = getattr(img, attr) + except AttributeError: + pass + else: + setattr(output_model, attr, val) + + self.output_models.append(output_model) + + return self.output_models + + +def find_dispersion_axis(wavelength_array): + """ + Find the dispersion axis (0-indexed) of the given 2D wavelength array + """ + diffx = wavelength_array[:, 1:] - wavelength_array[:, 0:-1] + diffy = wavelength_array[1:, :] - wavelength_array[0:-1, :] + dwlx = np.abs(np.nanmean(diffx)) + dwly = np.abs(np.nanmean(diffy)) + if dwlx > dwly: + return 0 + elif dwlx < dwly: + return 1 + else: + raise RuntimeError("Can't find dispersion axis. dx: {}, dy: {}".format( + dwlx, dwly)) diff --git a/iris_pipeline/resample/resample_spec_step.py b/iris_pipeline/resample/resample_spec_step.py new file mode 100755 index 0000000..0facdae --- /dev/null +++ b/iris_pipeline/resample/resample_spec_step.py @@ -0,0 +1,126 @@ +__all__ = ["ResampleSpecStep"] + +from .. import datamodels +from ..datamodels import MultiSlitModel, ModelContainer +from . import resample_spec, ResampleStep +from ..exp_to_source import multislit_to_container +from ..assign_wcs.util import update_s_region_spectral + + +class ResampleSpecStep(ResampleStep): + """ + ResampleSpecStep: Resample input data onto a regular grid using the + drizzle algorithm. + + Parameters + ----------- + input : `~jwst.datamodels.MultSlitModel`, `~jwst.datamodels.ModelContainer`, Association + A singe datamodel, a container of datamodels, or an association file + """ + + def process(self, input): + input = datamodels.open(input) + + # If single DataModel input, wrap in a ModelContainer + if not isinstance(input, ModelContainer): + input_models = datamodels.ModelContainer([input]) + input_models.meta.resample.output = input.meta.filename + self.blendheaders = False + else: + input_models = input + + for reftype in self.reference_file_types: + ref_filename = self.get_reference_file(input_models[0], reftype) + + if ref_filename != 'N/A': + self.log.info('Drizpars reference file: {}'.format(ref_filename)) + kwargs = self.get_drizpars(ref_filename, input_models) + else: + # Deal with NIRSpec which currently has no default drizpars reffile + self.log.info("No NIRSpec DIRZPARS reffile") + kwargs = self._set_spec_defaults() + + self.drizpars = kwargs + + if isinstance(input_models[0], MultiSlitModel): + # result is a MultiProductModel + result = self._process_multislit(input_models) + elif len(input_models[0].data.shape) != 2: + # resample can only handle 2D images, not 3D cubes, etc + raise RuntimeError("Input {} is not a 2D image.".format(input_models[0])) + else: + # result is a DrizProductModel + result = self._process_slit(input_models) + return result + + def _process_multislit(self, input_models): + """ + Resample MultiSlit data + + Parameters + ---------- + input : `~jwst.datamodels.ModelContainer` + A container of `~jwst.datamodels.MultiSlitModel` + + Returns + ------- + result : `~jwst.datamodels.MultiProductModel` + The resampled output, one per source + """ + containers = multislit_to_container(input_models) + result = datamodels.MultiProductModel() + result.update(input_models[0]) + for container in containers.values(): + resamp = resample_spec.ResampleSpecData(container, **self.drizpars) + drizzled_models = resamp.do_drizzle() + + for model in drizzled_models: + model.meta.cal_step.resample = "COMPLETE" + model.meta.asn.pool_name = input_models.meta.pool_name + model.meta.asn.table_name = input_models.meta.table_name + update_s_region_spectral(model) + + # Everything resampled to single output model + if len(drizzled_models) == 1: + result.products.append(drizzled_models[0]) + result.products[-1].bunit_data = container[0].meta.bunit_data + else: + # When each input is resampled to its own output + for model in drizzled_models: + result.products.append(model) + result.products[-1].bunit_data = container[0].meta.bunit_data + + return result + + def _process_slit(self, input_models): + """ + Resample Slit data + + Parameters + ---------- + input : `~jwst.datamodels.ModelContainer` + A container of `~jwst.datamodels.ImageModel` + or `~jwst.datamodels.SlitModel` + + Returns + ------- + result : `~jwst.datamodels.DrizProductModel` + The resampled output, one per source + """ + resamp = resample_spec.ResampleSpecData(input_models, **self.drizpars) + drizzled_models = resamp.do_drizzle() + + for model in drizzled_models: + model.meta.cal_step.resample = "COMPLETE" + model.meta.asn.pool_name = input_models.meta.pool_name + model.meta.asn.table_name = input_models.meta.table_name + update_s_region_spectral(model) + + # Return either the single resampled datamodel, or the container + # of datamodels. + if len(drizzled_models) == 1: + result = drizzled_models[0] + else: + result = drizzled_models + + return result diff --git a/iris_pipeline/resample/resample_step.py b/iris_pipeline/resample/resample_step.py new file mode 100755 index 0000000..8bc5de9 --- /dev/null +++ b/iris_pipeline/resample/resample_step.py @@ -0,0 +1,206 @@ +import logging + +import numpy as np + +from ..stpipe import Step +from ..extern.configobj.validate import Validator +from ..extern.configobj.configobj import ConfigObj +from .. import datamodels +from . import resample +from ..assign_wcs import util + +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) + +__all__ = ["ResampleStep"] + + +class ResampleStep(Step): + """ + Resample input data onto a regular grid using the drizzle algorithm. + + Parameters + ----------- + input : DataModel or Association + Single filename for either a single image or an association table. + """ + + spec = """ + pixfrac = float(default=None) + kernel = string(default=None) + fillval = string(default=None) + weight_type = option('exptime', default=None) + good_bits = integer(min=0, default=6) + single = boolean(default=False) + blendheaders = boolean(default=True) + """ + + reference_file_types = ['drizpars'] + + def process(self, input): + + input = datamodels.open(input) + + # If single input, wrap in a ModelContainer + if not isinstance(input, datamodels.ModelContainer): + input_models = datamodels.ModelContainer([input]) + input_models.meta.resample.output = input.meta.filename + self.blendheaders = False + else: + input_models = input + + # Check that input models are 2D images + if len(input_models[0].data.shape) != 2: + # resample can only handle 2D images, not 3D cubes, etc + raise RuntimeError("Input {} is not a 2D image.".format(input_models[0])) + + # Get drizzle parameters reference file + for reftype in self.reference_file_types: + ref_filename = self.get_reference_file(input_models[0], reftype) + + if ref_filename != 'N/A': + self.log.info('Drizpars reference file: {}'.format(ref_filename)) + kwargs = self.get_drizpars(ref_filename, input_models) + else: + # Deal with NIRSpec which currently has no default drizpars reffile + self.log.info("No NIRSpec DIRZPARS reffile") + kwargs = self._set_spec_defaults() + + # Call the resampling routine + resamp = resample.ResampleData(input_models, **kwargs) + resamp.do_drizzle() + + for model in resamp.output_models: + model.meta.cal_step.resample = "COMPLETE" + util.update_s_region_imaging(model) + model.meta.asn.pool_name = input_models.meta.pool_name + model.meta.asn.table_name = input_models.meta.table_name + + if len(resamp.output_models) == 1: + result = resamp.output_models[0] + else: + result = resamp.output_models + + return result + + def get_drizpars(self, ref_filename, input_models): + """ + Extract drizzle parameters from reference file. + + This method extracts parameters from the drizpars reference file and + uses those to set defaults on the following ResampleStep configuration + parameters: + + pixfrac = float(default=None) + kernel = string(default=None) + fillval = string(default=None) + weight_type = option('exptime', default=None) + + Once the defaults are set from the reference file, if the user has + used a resample.cfg file or run ResampleStep using command line args, + then these will overwerite the defaults pulled from the reference file. + """ + drizpars_table = datamodels.DrizParsModel(ref_filename).data + + num_groups = len(input_models.group_names) + filtname = input_models[0].meta.instrument.filter + row = None + filter_match = False + # look for row that applies to this set of input data models + for n, filt, num in zip( + range(0, len(drizpars_table)), + drizpars_table['filter'], + drizpars_table['numimages'] + ): + # only remember this row if no exact match has already been made for + # the filter. This allows the wild-card row to be anywhere in the + # table; since it may be placed at beginning or end of table. + + if str(filt) == "ANY" and not filter_match and num_groups >= num: + row = n + # always go for an exact match if present, though... + if filtname == filt and num_groups >= num: + row = n + filter_match = True + + # With presence of wild-card rows, code should never trigger this logic + if row is None: + self.log.error("No row found in %s matching input data.", ref_filename) + raise ValueError + + # Define the keys to pull from drizpars reffile table. Note the + # step param 'weight_type' is 'wht_type' in the FITS binary table. + # All values should be None unless the user set them on the command + # line or in the call to the step + drizpars = dict( + pixfrac=self.pixfrac, + kernel=self.kernel, + fillval=self.fillval, + wht_type=self.weight_type + ) + + # For parameters that are set in drizpars table but not set by the + # user, use these. Otherwise, use values set by user. + reffile_drizpars = {k:v for k,v in drizpars.items() if v is None} + user_drizpars = {k:v for k,v in drizpars.items() if v is not None} + + # read in values from that row for each parameter + for k in reffile_drizpars: + if k in drizpars_table.names: + reffile_drizpars[k] = drizpars_table[k][row] + + # Convert the strings in the FITS binary table from np.bytes_ to str + for k,v in reffile_drizpars.items(): + if isinstance(v, np.bytes_): + reffile_drizpars[k] = v.decode('UTF-8') + + all_drizpars = {**reffile_drizpars, **user_drizpars} + + # Convert the 'wht_type' key to a 'weight_type' key + all_drizpars['weight_type'] = all_drizpars.pop('wht_type') + + kwargs = dict( + good_bits=self.good_bits, + single=self.single, + blendheaders=self.blendheaders + ) + + kwargs.update(all_drizpars) + + if 'wht_type' in kwargs: + raise DeprecationWarning('`wht_type` config keyword has changed ' + + 'to `weight_type`; ' + + 'please update calls to ResampleStep and resample.cfg files') + kwargs.pop('wht_type') + + for k,v in kwargs.items(): + self.log.debug(' {}={}'.format(k, v)) + + return kwargs + + @classmethod + def _set_spec_defaults(cls): + """NIRSpec currently has no default drizpars reference file, so default + drizzle parameters are not set properly. This method sets them. + + Remove this class method when a drizpars reffile is delivered. + """ + configspec = cls.load_spec_file() + config = ConfigObj(configspec=configspec) + if config.validate(Validator()): + kwargs = config.dict() + + if kwargs['pixfrac'] is None: + kwargs['pixfrac'] = 1.0 + if kwargs['kernel'] is None: + kwargs['kernel'] = 'square' + if kwargs['fillval'] is None: + kwargs['fillval'] = 'INDEF' + if kwargs['weight_type'] is None: + kwargs['weight_type'] = 'exptime' + + for k,v in kwargs.items(): + if k in ['pixfrac', 'kernel', 'fillval', 'weight_type']: + log.info(' setting: %s=%s', k, repr(v)) + + return kwargs diff --git a/iris_pipeline/resample/resample_utils.py b/iris_pipeline/resample/resample_utils.py new file mode 100644 index 0000000..9a09c49 --- /dev/null +++ b/iris_pipeline/resample/resample_utils.py @@ -0,0 +1,193 @@ +import numpy as np + +from astropy import wcs as fitswcs +from astropy.coordinates import SkyCoord +from astropy.modeling.models import Scale, AffineTransformation2D +from astropy.modeling import Model +from gwcs import WCS, wcstools + +from astropy.nddata.bitmask import interpret_bit_flags + +from ..assign_wcs.util import wcs_from_footprints, wcs_bbox_from_shape + +import logging +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) + + +def make_output_wcs(input_models): + """ Generate output WCS here based on footprints of all input WCS objects + Parameters + ---------- + wcslist : list of gwcs.WCS objects + + Returns + ------- + output_wcs : object + WCS object, with defined domain, covering entire set of input frames + + """ + + # The API needing input_models instead of just wcslist is because + # currently the domain is not defined in any of imaging modes for NIRCam + # NIRISS or MIRI + # + # TODO: change the API to take wcslist instead of input_models and + # remove the following block + wcslist = [i.meta.wcs for i in input_models] + for w, i in zip(wcslist, input_models): + if w.bounding_box is None: + w.bounding_box = wcs_bbox_from_shape(i.data.shape) + naxes = wcslist[0].output_frame.naxes + + if naxes == 3: + # THIS BLOCK CURRENTLY ISN"T USED BY resample_spec + pass + elif naxes == 2: + output_wcs = wcs_from_footprints(input_models) + output_wcs.data_size = shape_from_bounding_box(output_wcs.bounding_box) + + # Check that the output data shape has no zero length dimensions + if not np.product(output_wcs.data_size): + raise ValueError("Invalid output frame shape: " + "{}".format(output_wcs.data_size)) + + return output_wcs + + +def compute_output_transform(refwcs, filename, fiducial): + """Compute a simple FITS-type WCS transform + """ + x0, y0 = refwcs.backward_transform(*fiducial) + x1 = x0 + 1 + y1 = y0 + 1 + ra0, dec0 = refwcs(x0, y0) + ra_xdir, dec_xdir = refwcs(x1, y0) + ra_ydir, dec_ydir = refwcs(x0, y1) + + position0 = SkyCoord(ra=ra0, dec=dec0, unit='deg') + position_xdir = SkyCoord(ra=ra_xdir, dec=dec_xdir, unit='deg') + position_ydir = SkyCoord(ra=ra_ydir, dec=dec_ydir, unit='deg') + offset_xdir = position0.spherical_offsets_to(position_xdir) + offset_ydir = position0.spherical_offsets_to(position_ydir) + + xscale = np.abs(position0.separation(position_xdir).value) + yscale = np.abs(position0.separation(position_ydir).value) + scale = np.sqrt(xscale * yscale) + + c00 = offset_xdir[0].value / scale + c01 = offset_xdir[1].value / scale + c10 = offset_ydir[0].value / scale + c11 = offset_ydir[1].value / scale + pc_matrix = AffineTransformation2D(matrix=[[c00, c01], [c10, c11]]) + cdelt = Scale(scale) & Scale(scale) + + return pc_matrix | cdelt + + +def shape_from_bounding_box(bounding_box): + """ Return a numpy shape based on the provided bounding_box + """ + size = [] + for axs in bounding_box: + delta = axs[1] - axs[0] + size.append(int(delta + 0.5)) + return tuple(reversed(size)) + + +def calc_gwcs_pixmap(in_wcs, out_wcs, shape=None): + """ Return a pixel grid map from input frame to output frame. + """ + if shape: + bb = wcs_bbox_from_shape(shape) + log.debug("Bounding box from data shape: {}".format(bb)) + else: + bb = in_wcs.bounding_box + log.debug("Bounding box from WCS: {}".format(in_wcs.bounding_box)) + + grid = wcstools.grid_from_bounding_box(bb) + pixmap = np.dstack(reproject(in_wcs, out_wcs)(grid[0], grid[1])) + pixmap[np.isnan(pixmap)] = -1 + + return pixmap + + +def reproject(wcs1, wcs2): + """ + Given two WCSs or transforms return a function which takes pixel + coordinates in the first WCS or transform and computes them in the second + one. It performs the forward transformation of ``wcs1`` followed by the + inverse of ``wcs2``. + + Parameters + ---------- + wcs1, wcs2 : `~astropy.wcs.WCS` or `~gwcs.wcs.WCS` or `~astropy.modeling.Model` + WCS objects. + + Returns + ------- + _reproject : func + Function to compute the transformations. It takes x, y + positions in ``wcs1`` and returns x, y positions in ``wcs2``. + """ + + if isinstance(wcs1, fitswcs.WCS): + forward_transform = wcs1.all_pix2world + elif isinstance(wcs1, WCS): + forward_transform = wcs1.forward_transform + elif issubclass(wcs1, Model): + forward_transform = wcs1 + else: + raise TypeError("Expected input to be astropy.wcs.WCS or gwcs.WCS " + "object or astropy.modeling.Model subclass") + + if isinstance(wcs2, fitswcs.WCS): + backward_transform = wcs2.all_world2pix + elif isinstance(wcs2, WCS): + backward_transform = wcs2.backward_transform + elif issubclass(wcs2, Model): + backward_transform = wcs2.inverse + else: + raise TypeError("Expected input to be astropy.wcs.WCS or gwcs.WCS " + "object or astropy.modeling.Model subclass") + + def _reproject(x, y): + sky = forward_transform(x, y) + flat_sky = [] + for axis in sky: + flat_sky.append(axis.flatten()) + det = backward_transform(*tuple(flat_sky)) + det_reshaped = [] + for axis in det: + det_reshaped.append(axis.reshape(x.shape)) + return tuple(det_reshaped) + return _reproject + + +def build_driz_weight(model, weight_type=None, good_bits=None): + """ Create input weighting image + """ + dqmask = build_mask(model.dq, good_bits) + exptime = model.meta.exposure.exposure_time + + if weight_type == 'error': + err_model = np.nan_to_num(model.err) + inwht = (exptime / err_model)**2 * dqmask + log.debug("DEBUG weight mask: {} {}".format(type(inwht), np.sum(inwht))) + # elif weight_type == 'ivm': + # _inwht = img.buildIVMmask(chip._chip,dqarr,pix_ratio) + elif weight_type == 'exptime': + inwht = exptime * dqmask + else: + inwht = np.ones(model.data.shape, dtype=model.data.dtype) + return inwht + + +def build_mask(dqarr, bitvalue): + """ Builds a bit-mask from an input DQ array and a bitvalue flag + """ + bitvalue = interpret_bit_flags(bitvalue) + + if bitvalue is None: + return (np.ones(dqarr.shape, dtype=np.uint8)) + return np.logical_not(np.bitwise_and(dqarr, ~bitvalue)).astype(np.uint8) diff --git a/iris_pipeline/resample/tests/__init__.py b/iris_pipeline/resample/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/iris_pipeline/resample/tests/test_interface.py b/iris_pipeline/resample/tests/test_interface.py new file mode 100644 index 0000000..6a50838 --- /dev/null +++ b/iris_pipeline/resample/tests/test_interface.py @@ -0,0 +1,16 @@ +import pytest + +from ... import datamodels +from .. import ResampleSpecStep, ResampleStep + + +@pytest.mark.parametrize('resample_class', [ResampleSpecStep, ResampleStep]) +def test_multi_integration_input(resample_class): + cube = datamodels.CubeModel((5, 100, 100)) + cube.meta.instrument.name = 'MIRI' + cube.meta.observation.date = '2018-09-07' + cube.meta.observation.time = '10:32:20.181' + + # Resample can't handle cubes, so it should fail + with pytest.raises(RuntimeError): + resample_class().call(cube) diff --git a/iris_pipeline/resample/tests/test_resample_spec.py b/iris_pipeline/resample/tests/test_resample_spec.py new file mode 100644 index 0000000..4e637f5 --- /dev/null +++ b/iris_pipeline/resample/tests/test_resample_spec.py @@ -0,0 +1,86 @@ +import numpy as np +from numpy.testing import assert_allclose + +from gwcs.wcstools import grid_from_bounding_box + +from ...datamodels import ImageModel +from jwst.assign_wcs import AssignWcsStep +from jwst.resample import ResampleSpecStep + + +wcsinfo = { + 'dec_ref': -0.00601415671349804, + 'ra_ref': -0.02073605215697509, + 'roll_ref': -0.0, + 'v2_ref': -453.5134, + 'v3_ref': -373.4826, + 'v3yangle': 0.0, + 'vparity': -1 +} + + +instrument = { + 'detector': 'MIRIMAGE', + 'filter': 'P750L', + 'name': 'MIRI' +} + + +observation = { + 'date': '2019-01-01', + 'time': '17:00:00'} + + +subarray = { + 'fastaxis': 1, + 'name': 'SUBPRISM', + 'slowaxis': 2, + 'xsize': 72, + 'xstart': 1, + 'ysize': 416, + 'ystart': 529 +} + + +exposure = { + 'duration': 11.805952, + 'end_time': 58119.85416, + 'exposure_time': 11.776, + 'frame_time': 0.11776, + 'group_time': 0.11776, + 'groupgap': 0, + 'integration_time': 11.776, + 'nframes': 1, + 'ngroups': 100, + 'nints': 1, + 'nresets_between_ints': 0, + 'nsamples': 1, + 'readpatt': 'FAST', + 'sample_time': 10.0, + 'start_time': 58119.8333, + 'type': 'MIR_LRS-SLITLESS', + 'zero_frame': False} + + +def test_spatial_transform(): + """ + Calling the backwards WCS transform gives the same results + for ``negative RA`` and ``negative RA + 360``. + """ + im = ImageModel() + im.meta.wcsinfo._instance.update(wcsinfo) + im.meta.instrument._instance.update(instrument) + im.meta.exposure._instance.update(exposure) + im.meta.observation._instance.update(observation) + im.meta.subarray._instance.update(subarray) + + im = AssignWcsStep.call(im) + im.data = np.random.rand(416, 72) + im.error = np.random.rand(416, 72) + im.dq = np.random.rand(416, 72) + + im = ResampleSpecStep.call(im) + x, y =grid_from_bounding_box(im.meta.wcs.bounding_box) + ra, dec, lam = im.meta.wcs(x, y) + ra1 = np.where(ra < 0, 360 + ra, ra) + assert_allclose(im.meta.wcs.invert(ra, dec, lam), im.meta.wcs.invert(ra1, dec, lam)) diff --git a/iris_pipeline/resample/tests/test_utils.py b/iris_pipeline/resample/tests/test_utils.py new file mode 100644 index 0000000..460a134 --- /dev/null +++ b/iris_pipeline/resample/tests/test_utils.py @@ -0,0 +1,34 @@ +import numpy as np +import pytest + +from jwst.resample.resample_spec import find_dispersion_axis + + +def test_find_dispersion_axis(): + """ + Test the find_dispersion_axis() function + """ + wavelengths = np.arange(100) * 0.1 + np.exp(0.1) * 13.0 + # [14.36722193, 14.46722193, 14.56722193, ... 24.16722193, 24.26722193] + + wavelengths_horizontal = np.tile(wavelengths, 15).reshape(15, 100) + assert find_dispersion_axis(wavelengths_horizontal) == 0 + + wavelengths_vertical = np.repeat(wavelengths, 15).reshape(100, 15) + assert find_dispersion_axis(wavelengths_vertical) == 1 + + # Make sure it works for decreasing wavelengths + assert find_dispersion_axis(np.fliplr(wavelengths_horizontal)) == 0 + assert find_dispersion_axis(np.flipud(wavelengths_vertical)) == 1 + + # Make sure it works if there are NaNs + wavelengths_horizontal[:,0] = np.nan + assert find_dispersion_axis(wavelengths_horizontal) == 0 + + wavelengths_vertical[:,0] = np.nan + assert find_dispersion_axis(wavelengths_vertical) == 1 + + # Make sure if wavelengths don't change it produces an error + wavelengths_zeros = np.zeros((15, 100)) + with pytest.raises(RuntimeError): + find_dispersion_axis(wavelengths_zeros) From a30d7f14a0d3119f9dbe213e70f4696496cd5574 Mon Sep 17 00:00:00 2001 From: Andrea Zonca Date: Mon, 27 Jul 2020 18:17:04 -0700 Subject: [PATCH 2/3] import eveything from jwst --- iris_pipeline/__init__.py | 1 + iris_pipeline/resample/resample.py | 5 +++-- iris_pipeline/resample/resample_spec.py | 3 ++- iris_pipeline/resample/resample_spec_step.py | 10 ++++++---- iris_pipeline/resample/resample_step.py | 10 +++++----- iris_pipeline/resample/resample_utils.py | 2 +- 6 files changed, 18 insertions(+), 13 deletions(-) diff --git a/iris_pipeline/__init__.py b/iris_pipeline/__init__.py index cfe85d3..53e5bbc 100644 --- a/iris_pipeline/__init__.py +++ b/iris_pipeline/__init__.py @@ -37,5 +37,6 @@ class UnsupportedPythonError(Exception): from .parse_subarray_map import * from .merge_subarrays import * from .assign_wcs import * +from .resample import * from .datamodels import monkeypatch_jwst_datamodels diff --git a/iris_pipeline/resample/resample.py b/iris_pipeline/resample/resample.py index e096eeb..a2dc51f 100644 --- a/iris_pipeline/resample/resample.py +++ b/iris_pipeline/resample/resample.py @@ -2,11 +2,12 @@ from collections import OrderedDict import numpy as np -from .. import datamodels +# try first importing everything from JWST +from jwst import datamodels from . import gwcs_drizzle from . import resample_utils -from ..model_blender import blendmeta +from jwst.model_blender import blendmeta log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) diff --git a/iris_pipeline/resample/resample_spec.py b/iris_pipeline/resample/resample_spec.py index f6081f8..268978c 100644 --- a/iris_pipeline/resample/resample_spec.py +++ b/iris_pipeline/resample/resample_spec.py @@ -11,7 +11,8 @@ from gwcs import wcstools, WCS from gwcs import coordinate_frames as cf -from .. import datamodels +# try first importing everything from JWST +from jwst import datamodels from . import gwcs_drizzle from . import resample_utils diff --git a/iris_pipeline/resample/resample_spec_step.py b/iris_pipeline/resample/resample_spec_step.py index 0facdae..440950a 100755 --- a/iris_pipeline/resample/resample_spec_step.py +++ b/iris_pipeline/resample/resample_spec_step.py @@ -1,10 +1,12 @@ __all__ = ["ResampleSpecStep"] -from .. import datamodels -from ..datamodels import MultiSlitModel, ModelContainer +# try first importing everything from JWST +from jwst import datamodels +from jwst.datamodels import MultiSlitModel, ModelContainer + from . import resample_spec, ResampleStep -from ..exp_to_source import multislit_to_container -from ..assign_wcs.util import update_s_region_spectral +from jwst.exp_to_source import multislit_to_container +from jwst.assign_wcs.util import update_s_region_spectral class ResampleSpecStep(ResampleStep): diff --git a/iris_pipeline/resample/resample_step.py b/iris_pipeline/resample/resample_step.py index 8bc5de9..a7d9b13 100755 --- a/iris_pipeline/resample/resample_step.py +++ b/iris_pipeline/resample/resample_step.py @@ -2,12 +2,12 @@ import numpy as np -from ..stpipe import Step -from ..extern.configobj.validate import Validator -from ..extern.configobj.configobj import ConfigObj -from .. import datamodels +from jwst.stpipe import Step +from jwst.extern.configobj.validate import Validator +from jwst.extern.configobj.configobj import ConfigObj +from jwst import datamodels from . import resample -from ..assign_wcs import util +from jwst.assign_wcs import util log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) diff --git a/iris_pipeline/resample/resample_utils.py b/iris_pipeline/resample/resample_utils.py index 9a09c49..0ea02bf 100644 --- a/iris_pipeline/resample/resample_utils.py +++ b/iris_pipeline/resample/resample_utils.py @@ -8,7 +8,7 @@ from astropy.nddata.bitmask import interpret_bit_flags -from ..assign_wcs.util import wcs_from_footprints, wcs_bbox_from_shape +from jwst.assign_wcs.util import wcs_from_footprints, wcs_bbox_from_shape import logging log = logging.getLogger(__name__) From 00ded4ab444e3d67216c9d54d8fbc10d1cf3cd9f Mon Sep 17 00:00:00 2001 From: Andrea Zonca Date: Mon, 27 Jul 2020 18:40:47 -0700 Subject: [PATCH 3/3] feat: import all datamodels into root package --- iris_pipeline/__init__.py | 2 +- iris_pipeline/datamodels/__init__.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/iris_pipeline/__init__.py b/iris_pipeline/__init__.py index 53e5bbc..4016e87 100644 --- a/iris_pipeline/__init__.py +++ b/iris_pipeline/__init__.py @@ -39,4 +39,4 @@ class UnsupportedPythonError(Exception): from .assign_wcs import * from .resample import * -from .datamodels import monkeypatch_jwst_datamodels +from .datamodels import * diff --git a/iris_pipeline/datamodels/__init__.py b/iris_pipeline/datamodels/__init__.py index c8a61c6..8e0977c 100644 --- a/iris_pipeline/datamodels/__init__.py +++ b/iris_pipeline/datamodels/__init__.py @@ -14,6 +14,7 @@ __all__ = [ + "monkeypatch_jwst_datamodels", "IRISImageModel", "TMTRampModel", "TMTMaskModel", @@ -25,7 +26,7 @@ "TMTReferenceFileModel", ] -_all_models = __all__ +_all_models = __all__[1:] _local_dict = locals() _defined_models = {k: _local_dict[k] for k in _all_models}