Skip to content

Commit

Permalink
First commit for making QuantityTableCoordinate support pixel corners.
Browse files Browse the repository at this point in the history
  • Loading branch information
DanRyanIrish committed May 29, 2024
1 parent ce4d0b1 commit c28ad88
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 22 deletions.
9 changes: 5 additions & 4 deletions ndcube/extra_coords/extra_coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def add(self,
name: str | Iterable[str],
array_dimension: int | Iterable[int],
lookup_table: Any,
points: Iterable[float] = None,
physical_types: str | Iterable[str] = None,
**kwargs):
"""
Expand Down Expand Up @@ -204,7 +205,7 @@ def from_lookup_tables(cls, names, pixel_dimensions, lookup_tables, physical_typ

return extra_coords

def add(self, name, array_dimension, lookup_table, physical_types=None, **kwargs):
def add(self, name, array_dimension, lookup_table, points=None, physical_types=None, **kwargs):
# docstring in ABC

if self._wcs is not None:
Expand All @@ -217,13 +218,13 @@ def add(self, name, array_dimension, lookup_table, physical_types=None, **kwargs
if isinstance(lookup_table, BaseTableCoordinate):
coord = lookup_table
elif isinstance(lookup_table, Time):
coord = TimeTableCoordinate(lookup_table, physical_types=physical_types, **kwargs)
coord = TimeTableCoordinate(lookup_table, points=points, physical_types=physical_types, **kwargs)
elif isinstance(lookup_table, SkyCoord):
coord = SkyCoordTableCoordinate(lookup_table, physical_types=physical_types, **kwargs)
elif isinstance(lookup_table, (list, tuple)):
coord = QuantityTableCoordinate(*lookup_table, physical_types=physical_types, **kwargs)
coord = QuantityTableCoordinate(*lookup_table, points=points, physical_types=physical_types, **kwargs)
elif isinstance(lookup_table, u.Quantity):
coord = QuantityTableCoordinate(lookup_table, physical_types=physical_types, **kwargs)
coord = QuantityTableCoordinate(lookup_table, points=points, physical_types=physical_types, **kwargs)
else:
raise TypeError(f"The input type {type(lookup_table)} isn't supported")

Expand Down
89 changes: 71 additions & 18 deletions ndcube/extra_coords/table_coord.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def _generate_generic_frame(naxes, unit, names=None, physical_types=None):
axes_names=names, name=name, axis_physical_types=physical_types)


def _generate_tabular(lookup_table, interpolation='linear', points_unit=u.pix, **kwargs):
def _generate_tabular(lookup_table, points=None, interpolation='linear', points_unit=u.pix, **kwargs):
"""
Generate a Tabular model class and instance.
"""
Expand All @@ -139,7 +139,11 @@ def _generate_tabular(lookup_table, interpolation='linear', points_unit=u.pix, *
TabularND = tabular_model(ndim, name=f"Tabular{ndim}D")

# The integer location is at the centre of the pixel.
points = [(np.arange(size) - 0) * points_unit for size in lookup_table.shape]
if points is None:
points = [(np.arange(size) - 0) * points_unit for size in lookup_table.shape]
else:
points = points.to(points_unit)
[points] * ndim
if len(points) == 1:
points = points[0]

Expand All @@ -160,27 +164,28 @@ def _generate_tabular(lookup_table, interpolation='linear', points_unit=u.pix, *
return t


def _generate_compound_model(*lookup_tables, mesh=True):
def _generate_compound_model(*lookup_tables, points=None, mesh=True):
"""
Takes a set of quantities and returns a ND compound model.
"""
model = _generate_tabular(lookup_tables[0])
model = _generate_tabular(lookup_tables[0], points=points)
for lt in lookup_tables[1:]:
model = model & _generate_tabular(lt)
model = model & _generate_tabular(lt, points=points)

if mesh:
return model

# If we are not meshing the inputs duplicate the inputs across all models
mapping = list(range(lookup_tables[0].ndim)) * len(lookup_tables)
#mapping = list(range(lookup_tables[0].ndim)) * len(lookup_tables)
mapping = list(points) * len(lookup_tables)
return models.Mapping(mapping) | model


def _model_from_quantity(lookup_tables, mesh=False):
def _model_from_quantity(lookup_tables, points=None, mesh=False):
if len(lookup_tables) > 1:
return _generate_compound_model(*lookup_tables, mesh=mesh)
return _generate_compound_model(*lookup_tables, points=points, mesh=mesh)

return _generate_tabular(lookup_tables[0])
return _generate_tabular(lookup_tables[0], points=points)


class BaseTableCoordinate(abc.ABC):
Expand All @@ -196,8 +201,9 @@ class BaseTableCoordinate(abc.ABC):
coordinates, meaning it can have multiple gWCS frames.
"""

def __init__(self, *tables, mesh=False, names=None, physical_types=None):
def __init__(self, *tables, points=None, mesh=False, names=None, physical_types=None):
self.table = tables
self.points = points
self.mesh = mesh
self.names = names if not isinstance(names, str) else [names]
self.physical_types = physical_types if not isinstance(physical_types, str) else [physical_types]
Expand Down Expand Up @@ -290,6 +296,9 @@ class QuantityTableCoordinate(BaseTableCoordinate):
multiple 1-D Quantities can be provided representing the different
dimensions
points: `~astropy.units.Quantity` in pixel units.
The points in the grid grid to which the values in the tables correspond.
names: `str` or `list` of `str`
Custom names for the components of the QuantityTableCoord. If provided,
a name must be given for each input Quantity.
Expand All @@ -301,7 +310,7 @@ class QuantityTableCoordinate(BaseTableCoordinate):
a physical type must be given for each component.
"""

def __init__(self, *tables, names=None, physical_types=None):
def __init__(self, *tables, points=None, names=None, physical_types=None):
if not all([isinstance(t, u.Quantity) for t in tables]):
raise TypeError("All tables must be astropy Quantity objects")
if not all([t.unit.is_equivalent(tables[0].unit) for t in tables]):
Expand All @@ -312,6 +321,11 @@ def __init__(self, *tables, names=None, physical_types=None):
raise ValueError(
"Currently all tables must be 1-D. If you need >1D support, please "
"raise an issue at https://github.con/sunpy/ndcube/issues")
if points is not None:
if not points.unit.is_equivalent(u.pix):
raise u.UnitsError("Points must have pixel units.")
if points.shape != tables[0].shape:
raise ValueError("Points must be same shape as table(s).")

if isinstance(names, str):
names = [names]
Expand All @@ -324,7 +338,7 @@ def __init__(self, *tables, names=None, physical_types=None):

self.unit = tables[0].unit

super().__init__(*tables, mesh=True, names=names, physical_types=physical_types)
super().__init__(*tables, points=points, mesh=True, names=names, physical_types=physical_types)

def _slice_table(self, i, table, item, new_components, whole_slice):
"""
Expand Down Expand Up @@ -364,16 +378,43 @@ def __getitem__(self, item):
if not (len(item) == len(self.table) or len(item) == self.table[0].ndim):
raise ValueError("Can not slice with incorrect length")

# Convert item to table item based on points grid.
new_item = []
for idx in item:
if isinstance(idx, Integral):
new_idx = np.where(self.points == idx)
if new_idx == ():
raise NotImplementedError("Indexing QuantityTableCoordinate at inter-grid locations not supported.")
else:
new_item.append(new_idx[0][0])
elif isinstance(idx, slice):
new_start = np.where(self.points > slice.start - 1)
new_start = 0 if new_start == () else new_start[0][0]
new_stop = np.where(self.points >= slice.stop)
new_stop = len(self.points) if new_stop == () else new_stop[0][0]
new_item.append(slice(new_start, new_stop))
else:
new_idx = []
for i in idx:
new_i = np.where(self.points == i)
if new_i == ():
raise NotImplementedError("Indexing QuantityTableCoordinate at inter-grid locations not supported.")
else:
new_idx.append(new_i[0][0])
new_item.append(np.asarray(new_idx))
new_item = tuple(new_item)

new_components = defaultdict(list)
new_components["dropped_world_dimensions"] = copy.deepcopy(self._dropped_world_dimensions)

for i, (ele, table) in enumerate(zip(item, self.table)):
self._slice_table(i, table, ele, new_components, whole_slice=item)
self._slice_table(i, table, ele, new_components, whole_slice=new_item)

points = None if self.points is None else self.points[new_item]
names = new_components["names"] or None
physical_types = new_components["physical_types"] or None

ret_table = type(self)(*new_components["tables"], names=names, physical_types=physical_types)
ret_table = type(self)(*new_components["tables"], points=points, names=names, physical_types=physical_types)
ret_table._dropped_world_dimensions = new_components["dropped_world_dimensions"]
return ret_table

Expand All @@ -396,7 +437,7 @@ def model(self):
"""
Generate the Astropy Model for this LookupTable.
"""
return _model_from_quantity(self.table, True)
return _model_from_quantity(self.table, self.points, True)

@property
def ndim(self):
Expand All @@ -418,7 +459,7 @@ def shape(self):
"""
return tuple(len(t) for t in self.table)

def interpolate(self, *new_array_grids, **kwargs):
def interpolate(self, *new_array_grids, new_points=None, **kwargs):
"""
Interpolate QuantityTableCoordinate to new array index grids.
Expand All @@ -431,6 +472,10 @@ def interpolate(self, *new_array_grids, **kwargs):
represent a single location in the pixel grid. Therefore, array grids
must all have the same shape.
new_points: `~astropy.units.Quantity` in pixel units of `str`
The new pixel grid points to which the nely interpolating values will correspond.
Default=None implies they will correspond to pixel centers.
kwargs
All remaining kwargs are passed to underlying interpolation function.
Expand All @@ -455,8 +500,16 @@ def interpolate(self, *new_array_grids, **kwargs):
new_tables = [
np.interp(new_grid, old_grid, t.value, **kwargs) * t.unit
for new_grid, old_grid, t in zip(new_array_grids, old_array_grids, self.table)]
if new_points is None:
new_points = list(range(len(new_array_grids))) * u.pix
elif not isinstance(new_points, u.Quantity):
raise TypeError("new_points must be an astropy Quantity.")
elif not new_points.unit.is_equivalent(u.pix):
raise u.UnitError("new_points must have pixel units.")
elif new_points.shape == new_tables[0].shape):
raise ValueError("new_points must be same shape as tables.")
# Rebuild return interpolated coord.
new_coord = type(self)(*new_tables, names=self.names, physical_types=self.physical_types)
new_coord = type(self)(*new_tables, points=new_points, names=self.names, physical_types=self.physical_types)
new_coord._dropped_world_dimensions = self._dropped_world_dimensions
return new_coord

Expand Down Expand Up @@ -577,7 +630,7 @@ def model(self):
"""
Generate the Astropy Model for this LookupTable.
"""
return _model_from_quantity(self._sliced_components, mesh=self.mesh)
return _model_from_quantity(self._sliced_components, self.points, mesh=self.mesh)

@property
def ndim(self):
Expand Down

0 comments on commit c28ad88

Please sign in to comment.