Skip to content

Commit

Permalink
#151 - reduce bass class code, simplify column and spin logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Bradley Augstein committed Jun 5, 2024
1 parent f92f0d3 commit 0112bf8
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 92 deletions.
2 changes: 1 addition & 1 deletion heracles/_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def transform(
msg = f"unknown field name: {k}"
raise ValueError(msg)

out[k, i] = field.mapper_or_error.transform(m)
out[k, i] = field.mapper.transform(m)

if progress:
subtask.remove()
Expand Down
153 changes: 62 additions & 91 deletions heracles/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,95 +64,45 @@ class Field(metaclass=ABCMeta):
# definition of required and optional columns
__ncol: tuple[int, int]

def __init_subclass__(cls, *, spin: int | None = None) -> None:
''' def __init_subclass__(cls, *, spin: int | None = None) -> None:
"""Initialise spin weight of field subclasses."""
super().__init_subclass__()
if spin is not None:
cls.__spin = spin
uses = cls.uses
if uses is None:
uses = ()
elif isinstance(uses, str):
uses = (uses,)
ncol = len(uses)
nopt = 0
for u in uses[::-1]:
if u.startswith("[") and u.endswith("]"):
nopt += 1
else:
break
cls.__ncol = (ncol - nopt, ncol)
cls.__spin = spin'''


def __init__(
self,
mapper: Mapper | None,
*columns: str,
*columns: str,
weight: str | None,
mask: str | None = None,
) -> None:
"""Initialise the field."""
super().__init__()
self.__mapper = mapper
self.__columns = self._init_columns(*columns) if columns else None
self.__columns = columns if columns else None
self.__weight = weight if weight else None
self.__mask = mask

@classmethod
def _init_columns(cls, *columns: str) -> Columns:
"""Initialise the given set of columns for a specific field
subclass."""
nmin, nmax = cls.__ncol
if not nmin <= len(columns) <= nmax:
uses = cls.uses
if uses is None:
uses = ()
if isinstance(uses, str):
uses = (uses,)
count = f"{nmin}"
if nmax != nmin:
count += f" to {nmax}"
msg = f"field of type '{cls.__name__}' accepts {count} columns"
if uses:
msg += " (" + ", ".join(uses) + ")"
msg += f", received {len(columns)}"
raise ValueError(msg)
return columns + (None,) * (nmax - len(columns))
self.__spin = 0

@property
def mapper(self) -> Mapper | None:
"""Return the mapper used by this field."""
return self.__mapper

@property
def mapper_or_error(self) -> Mapper:
"""Return the mapper used by this field, or raise a :class:`ValueError`
if not set."""
if self.__mapper is None:
msg = "no mapper for field"
raise ValueError(msg)
return self.__mapper
def weight(self) -> str | None:
"""Return the mapper used by this field."""
return self.__weight

@property
def columns(self) -> Columns | None:
"""Return the catalogue columns used by this field."""
return self.__columns

@property
def columns_or_error(self) -> Columns:
"""Return the catalogue columns used by this field, or raise a
:class:`ValueError` if not set."""
if self.__columns is None:
msg = "no columns for field"
raise ValueError(msg)
return self.__columns

@property
def spin(self) -> int:
"""Spin weight of field."""
spin = self.__spin
if spin is None:
clsname = self.__class__.__name__
msg = f"field of type '{clsname}' has undefined spin weight"
raise ValueError(msg)
return spin
return self.__spin

@property
def mask(self) -> str | None:
Expand All @@ -168,7 +118,13 @@ async def __call__(
) -> ArrayLike:
"""Implementation for mapping a catalogue."""
...


def CheckColumns(self, *expected):
if(len(expected)!=len(self.columns)):
error = "Column error. Expected " + str(len(expected)) + " columns"
error += "with a format " + str(expected) + ". Received " + str(self.columns)
raise ValueError(error)


async def _pages(
catalog: Catalog,
Expand All @@ -190,26 +146,25 @@ async def _pages(
await coroutines.sleep()


class Positions(Field, spin=0):
class Positions(Field):
"""Field of positions in a catalogue.
Can produce both overdensity maps and number count maps, depending
on the ``overdensity`` property.
"""

uses = "longitude", "latitude"

def __init__(
self,
mapper: Mapper | None,
*columns: str,
weight: str | None,
overdensity: bool = True,
nbar: float | None = None,
mask: str | None = None,
) -> None:
"""Create a position field."""
super().__init__(mapper, *columns, mask=mask)
super().__init__(mapper, *columns,weight=weight, mask=mask)
self.__overdensity = overdensity
self.__nbar = nbar

Expand Down Expand Up @@ -241,11 +196,15 @@ async def __call__(
msg = "cannot compute density contrast: no visibility in catalog"
raise ValueError(msg)

# get mapper
mapper = self.mapper_or_error
#get mapper
mapper = self.mapper

# get catalogue column definition
col = self.columns_or_error
col = self.columns
self.CheckColumns("longitude", "latitude")

#if(len(col)!=2):
# raise ValueError("Expect 2 colummns, longitude and latitude")

# position map
pos = mapper.create(spin=self.spin)
Expand All @@ -259,7 +218,7 @@ async def __call__(
lon, lat = page.get(*col)
w = np.ones(page.size)

mapper.map_values(lon, lat, pos, w)
self.mapper.map_values(lon, lat, pos, w)

ngal += page.size

Expand Down Expand Up @@ -307,11 +266,9 @@ async def __call__(
return pos


class ScalarField(Field, spin=0):
class ScalarField(Field):
"""Field of real scalar values in a catalogue."""

uses = "longitude", "latitude", "value", "[weight]"

async def __call__(
self,
catalog: Catalog,
Expand All @@ -321,11 +278,13 @@ async def __call__(
"""Map real values from catalogue to HEALPix map."""

# get mapper
mapper = self.mapper_or_error
mapper = self.mapper

# get the column definition of the catalogue
*col, wcol = self.columns_or_error

col = self.columns
self.CheckColumns(self, "longitude", "latitude", "value")

wcol = self.__weight
# scalar field map
val = mapper.create(spin=self.spin)

Expand Down Expand Up @@ -373,16 +332,14 @@ async def __call__(
return val


class ComplexField(Field, spin=0):
class ComplexField(Field):
"""Field of complex values in a catalogue.
The :class:`ComplexField` class has zero spin weight, while
subclasses such as :class:`Spin2Field` have non-zero spin weight.
"""

uses = "longitude", "latitude", "real", "imag", "[weight]"

async def __call__(
self,
catalog: Catalog,
Expand All @@ -392,10 +349,14 @@ async def __call__(
"""Map complex values from catalogue to HEALPix map."""

# get mapper
mapper = self.mapper_or_error
mapper = self.mapper

# get the column definition of the catalogue
*col, wcol = self.columns_or_error
col = self.columns

self.CheckColumns(self, "longitude", "latitude", "real", "imag")

wcol = self.weight

# complex map with real and imaginary part
val = mapper.create(2, spin=self.spin)
Expand Down Expand Up @@ -443,7 +404,7 @@ async def __call__(
return val


class Visibility(Field, spin=0):
class Visibility(Field):
"""Copy visibility map from catalogue at given resolution."""

async def __call__(
Expand All @@ -455,7 +416,7 @@ async def __call__(
"""Create a visibility map from the given catalogue."""

# get mapper
mapper = self.mapper_or_error
mapper = self.mapper

# make sure that catalogue has a visibility
visibility = catalog.visibility
Expand All @@ -479,11 +440,9 @@ async def __call__(
return out


class Weights(Field, spin=0):
class Weights(Field):
"""Field of weight values from a catalogue."""

uses = "longitude", "latitude", "[weight]"

async def __call__(
self,
catalog: Catalog,
Expand All @@ -493,10 +452,12 @@ async def __call__(
"""Map catalogue weights."""

# get mapper
mapper = self.mapper_or_error
mapper = self.mapper

# get the columns for this field
*col, wcol = self.columns_or_error
col = self.columns
self.CheckColumns(self, "longitude", "latitude")
wcol = self.weight

# weight map
wht = mapper.create(spin=self.spin)
Expand Down Expand Up @@ -543,8 +504,18 @@ async def __call__(
return wht


class Spin2Field(ComplexField, spin=2):
class Spin2Field(ComplexField):
"""Spin-2 complex field."""
def __init__(
self,
mapper: Mapper | None,
*columns: str,
weight: str | None,
mask: str | None = None,
) -> None:
"""Initialise the field."""
super().__init__(mapper, *columns,weight=weight, mask=mask)
self.__spin=2


Shears = Spin2Field
Expand Down

0 comments on commit 0112bf8

Please sign in to comment.