TF data - enable storing to HDF5 file
AlanLoh committed Feb 21, 2024
1 parent 1300117 commit ef991cd
Showing 5 changed files with 223 additions and 61 deletions.
25 changes: 21 additions & 4 deletions docs/io/tf_reading.rst
@@ -3,8 +3,6 @@
UnDySPuTeD Time-Frequency Data
==============================

blabala see tests :class:`~nenupy.io.tf.Spectra` or :meth:`~nenupy.astro.astro_tools.altaz_to_radec`


Reading a spectra file
-----------------------
@@ -137,6 +135,10 @@ Adding custom steps
(4 - Rebin in frequency)
5 - Compute Stokes parameters
.. warning::

This is a warning.


Getting the data
----------------
@@ -155,10 +157,25 @@ Getting the data
.. note::

There is a hardcoded size limit of 2 GB on the data output (i.e., after rebinning and all other pipeline operations), to prevent memory issues.
Users willing to bypass this limit may explicitly ask for it using the `ignore_data_size` argument of :meth:`~nenupy.io.tf.Spectra.get`:
Users willing to bypass this limit may explicitly ask for it using the ``ignore_volume_warning`` property of :attr:`~nenupy.io.tf.Spectra.pipeline`.
This property can also be updated directly via the :meth:`~nenupy.io.tf.Spectra.get` method:

.. code-block:: python

    >>> sp.get(tmin="2023-05-27T08:40:00", tmax="2023-05-27T18:00:00", ignore_data_size=True)
    >>> sp.get(
            tmin="2023-05-27T08:40:00", tmax="2023-05-27T18:00:00",
            ignore_volume_warning=True
        )

Saving the data
---------------

.. code-block:: python
    :emphasize-lines: 2

    >>> sp.get(
            file_name="/my/path/filename.hdf5",
            stokes="I",
            tmin="2023-05-27T08:41:30"
        )
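
Once written, the file can be inspected with ``h5py``. The snippet below is a minimal sketch, assuming the layout produced by this commit (one dataset per lower-cased polarization under a ``data`` group, with the time and frequency scales under ``data/axes``); the file path is illustrative:

.. code-block:: python

    >>> import h5py
    >>> with h5py.File("/my/path/filename.hdf5", "r") as rf:
    ...     metadata = dict(rf.attrs)            # observation-level keywords
    ...     stokes_i = rf["data/i"][:]           # Stokes I dynamic spectrum
    ...     time_jd = rf["data/axes/time"][:]    # time axis, in Julian Dates
    ...     freq = rf["data/axes/frequency"][:]  # frequency axis, in the stored unit
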
2 changes: 1 addition & 1 deletion nenupy/__init__.py
@@ -5,7 +5,7 @@
__copyright__ = "Copyright 2023, nenupy"
__credits__ = ["Alan Loh"]
__license__ = "MIT"
__version__ = "2.6.9"
__version__ = "2.6.10"
__maintainer__ = "Alan Loh"
__email__ = "[email protected]"

26 changes: 14 additions & 12 deletions nenupy/io/io_tools.py
@@ -557,7 +557,7 @@ def plot(self, fig_ax=None, **kwargs):
# Save or show the figure
figname = kwargs.get("figname", "")
if figname != "":
plt.savefig(
fig.savefig(
figname,
dpi=300,
bbox_inches="tight",
@@ -1075,18 +1075,20 @@ def _plot_dynamic_spectrum(self, data, ax, fig, **kwargs):
ax.set_xlim(xlim)
ax.set_ylim(ylim)

cax = inset_axes(
ax,
width='3%',
height='100%',
loc='lower left',
bbox_to_anchor=(1.03, 0., 1, 1),
bbox_transform=ax.transAxes,
borderpad=0,
)
cbar = plt.colorbar(im, cax=cax)
# cbar = plt.colorbar(im, pad=0.03)
cbar = plt.colorbar(im, pad=0.03)#format='%.1e')
cbar.set_label(kwargs.get("colorbar_label", "dB" if kwargs.get("decibel", True) else "Amp"))
# cax = inset_axes(
# ax,
# width='3%',
# height='100%',
# loc='lower left',
# bbox_to_anchor=(1.03, 0., 1, 1),
# bbox_transform=ax.transAxes,
# borderpad=0,
# )
# cbar = plt.colorbar(im, cax=cax)
# # cbar = plt.colorbar(im, pad=0.03)
# cbar.set_label(kwargs.get("colorbar_label", "dB" if kwargs.get("decibel", True) else "Amp"))

# X label
ax.xaxis.set_major_formatter(
152 changes: 108 additions & 44 deletions nenupy/io/tf.py
@@ -815,11 +815,13 @@ def info(self) -> None:
)
print(message)

def get(self, **pipeline_kwargs) -> SData:
def get(self, file_name: str = None, **pipeline_kwargs) -> SData:
r"""Perform data selection and pipeline computation.
Parameters
----------
file_name : `str`, default: ``None``
If not ``None`` (the default), name of the HDF5 file (extension '.hdf5') to create and store the result in.
**pipeline_kwargs
Any :attr:`~nenupy.io.tf.Spectra.pipeline` parameter passed as keyword argument from the list below.
Changes applied here are not kept once the method has resolved.
Expand Down Expand Up @@ -888,62 +890,46 @@ def get(self, **pipeline_kwargs) -> SData:
beam=self.pipeline.parameters["beam"],
)

# Run the pipeline
# Run the pipeline - within Dask
log.info(self.pipeline.info(return_str=True))
time_unix, frequency_hz, data = self.pipeline.run(
frequency_hz=frequency_hz, time_unix=time_unix, data=data
)

# Abort the process if the projected data volume is larger than the threshold
projected_data_volume = data.nbytes * u.byte
if (projected_data_volume >= DATA_VOLUME_SECURITY_THRESHOLD) and (
not self.pipeline.parameters["ignore_volume_warning"]
):
log.warning(
f"Data processing will produce {projected_data_volume.to(u.Gibyte)}."
f"The pipeline is interrupted because the volume threshold is {DATA_VOLUME_SECURITY_THRESHOLD.to(u.Gibyte)}."
time = Time(time_unix, format="unix", precision=7)
frequency = frequency_hz * u.Hz

if (file_name is None) or (file_name == ""):
# Simply compute the data
# The data volume security is ON
data = self._data_to_numpy_array(data) # compute() the Dask array
result = self._to_sdata(
data=data,
time=time,
frequency=frequency
)
return

log.info(
f"Computing the data (estimated volume: {projected_data_volume.to(u.Mibyte):.2})..."
)
with ProgressBar():
data = data.compute()
log.info(
f"\tData of shape (time, frequency, (polarization)) = {data.shape} produced."
)
self.pipeline.parameters = parameters_copy # Reset the parameters
return result

else:
# Save the result of the pipeline in a file
# No security on the resulting data volume
utils.store_dask_tf_data(
file_name=file_name,
data=data,
time=time,
frequency=frequency,
polarization=["XX", "XY", "YX", "YY"] if not self.pipeline.contains("Compute Stokes parameters") else self.pipeline.parameters["stokes"]
)
self.pipeline.parameters = parameters_copy # Reset the parameters
return

except Exception:
# Restore the parameters to their original default values
self.pipeline.parameters = parameters_copy
raise

# If other data product than Stokes are made
if not self.pipeline.contains("Compute Stokes parameters"):
# Make sure there are 3 dimensions (time, frequency, polarization)
# If there are more, the dimensions > 2 are all merged together.
# If there are 2 dimensions, an empty third is added
data = data.reshape(*data.shape[:2], -1)
result = SData(
data=data,
time=Time(time_unix, format="unix", precision=7),
freq=frequency_hz * u.Hz,
polar=["XX", "XY", "YX", "YY"],
)
else:
# If the data are regular Stokes parameters
result = SData(
data=data,
time=Time(time_unix, format="unix", precision=7),
freq=frequency_hz * u.Hz,
polar=self.pipeline.parameters["stokes"],
)

self.pipeline.parameters = parameters_copy

return result

def select_raw_data(
self,
tmin_unix: float,
@@ -1215,3 +1201,81 @@ def _to_dask_tf(self, data: np.ndarray, mask: np.ndarray) -> da.Array:
)

return data

def _data_to_numpy_array(self, data: da.Array) -> np.ndarray:
"""_summary_
Parameters
----------
data : da.Array
_description_
Returns
-------
np.ndarray
_description_
"""

# Abort the process if the projected data volume is larger than the threshold
projected_data_volume = data.nbytes * u.byte
if (projected_data_volume >= DATA_VOLUME_SECURITY_THRESHOLD) and (
not self.pipeline.parameters["ignore_volume_warning"]
):
log.warning(
f"Data processing will produce {projected_data_volume.to(u.Gibyte)}."
f"The pipeline is interrupted because the volume threshold is {DATA_VOLUME_SECURITY_THRESHOLD.to(u.Gibyte)}."
)
return

log.info(
f"Computing the data (estimated volume: {projected_data_volume.to(u.Mibyte):.2})..."
)

with ProgressBar():
data = data.compute()

log.info(
f"\tData of shape (time, frequency, (polarization)) = {data.shape} produced."
)

return data

def _to_sdata(self, data: np.ndarray, time: Time, frequency: u.Quantity) -> SData:
"""_summary_
Parameters
----------
data : np.ndarray
_description_
time : Time
_description_
frequency : u.Quantity
_description_
Returns
-------
SData
_description_
"""

# If other data product than Stokes are made
if not self.pipeline.contains("Compute Stokes parameters"):
# Make sure there are 3 dimensions (time, frequency, polarization)
# If there are more, the dimensions > 2 are all merged together.
# If there are 2 dimensions, an empty third is added
data = data.reshape(*data.shape[:2], -1)
return SData(
data=data,
time=time,
freq=frequency,
polar=["XX", "XY", "YX", "YY"],
)
else:
# If the data are regular Stokes parameters
return SData(
data=data,
time=time,
freq=frequency,
polar=self.pipeline.parameters["stokes"],
)
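
# Usage sketch (illustrative, not executed): the two `get` code paths above,
# with an assumed Spectra instance `sp` and hypothetical file paths.
#
#   sp = Spectra("/path/to/observation.spectra")
#
#   # 1) Without file_name, the Dask graph is computed in memory (the data
#   #    volume check applies) and an SData object is returned:
#   result = sp.get(stokes="I")
#
#   # 2) With file_name, the result is streamed chunk by chunk into an HDF5
#   #    file (no volume check) and the method returns None:
#   sp.get(file_name="/my/path/filename.hdf5", stokes="I")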

79 changes: 79 additions & 0 deletions nenupy/io/tf_utils.py
@@ -13,13 +13,15 @@
import numpy as np
import os
import dask.array as da
from dask.diagnostics import ProgressBar
import astropy.units as u
from astropy.coordinates import SkyCoord
from astropy.time import Time, TimeDelta
from typing import Union, List, Tuple, Any
from functools import partial
from abc import ABC, abstractmethod
import copy
import h5py
import logging

log = logging.getLogger(__name__)
@@ -44,6 +46,7 @@
"reshape_to_subbands",
"sort_beam_edges",
"spectra_data_to_matrix",
"store_dask_tf_data",
"TFPipelineParameters"
]

@@ -727,6 +730,82 @@ def spectra_data_to_matrix(fft0: da.Array, fft1: da.Array) -> da.Array:
)
return da.stack([row1, row2], axis=-1)

# ============================================================= #
# -------------------- store_dask_tf_data --------------------- #
def _time_to_keywords(prefix: str, time: Time) -> dict:
"""Returns a dictionnary of keywords in the HDF5 format."""
return {
f"{prefix.upper()}_MJD": time.mjd,
f"{prefix.upper()}_TAI": time.tai.isot,
f"{prefix.upper()}_UTC": time.isot + "Z",
}
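
# Illustrative output (values approximate, assuming astropy defaults):
#
#   >>> _time_to_keywords("observation_start", Time("2023-05-27T08:40:00"))
#   {'OBSERVATION_START_MJD': 60091.3611...,
#    'OBSERVATION_START_TAI': '2023-05-27T08:40:37.000',
#    'OBSERVATION_START_UTC': '2023-05-27T08:40:00.000Z'}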

def store_dask_tf_data(file_name: str, data: da.Array, time: Time, frequency: u.Quantity, polarization: np.ndarray, stored_frequency_unit: str = "MHz", **metadata) -> None:

log.info(f"Storing the data in {file_name}...")

# Check that the file_name has the correct extension
if not file_name.lower().endswith(".hdf5"):
raise ValueError(f"HDF5 files must end with '.hdf5', got {file_name} instead.")

stored_freq_quantity = u.Unit(stored_frequency_unit)
frequency_min = frequency.min()
frequency_max = frequency.max()

with h5py.File(file_name, "w") as wf:

# Update main attributes
wf.attrs.update(metadata)
wf.attrs["SOFTWARE_NAME"] = "nenupy"
# wf.attrs["SOFTWARE_VERSION"] = nenupy.__version__
wf.attrs["SOFTWARE_MAINTAINER"] = "[email protected]"
wf.attrs.update(_time_to_keywords("OBSERVATION_START", time[0]))
wf.attrs.update(_time_to_keywords("OBSERVATION_END", time[-1]))
wf.attrs["OBSERVATION_FREQUENCY_MIN"] = frequency_min.to_value(stored_freq_quantity)
wf.attrs["OBSERVATION_FREQUENCY_MAX"] = frequency_max.to_value(stored_freq_quantity)
wf.attrs["OBSERVATION_FREQUENCY_CENTER"] = (
((frequency_max + frequency_min) / 2).to_value(stored_freq_quantity)
)
wf.attrs["OBSERVATION_FREQUENCY_UNIT"] = stored_frequency_unit

# Ravel all dimensions beyond (time, frequency) into a single polarization axis
data = np.reshape(data, data.shape[:2] + (-1,))

data_group = wf.create_group("data")
coordinates_group = data_group.create_group("axes")

# Set time and frequency axes
data_group.attrs.update(_time_to_keywords("TIME_START", time[0]))
data_group.attrs.update(_time_to_keywords("TIME_END", time[-1]))
data_group.attrs["FREQUENCY_MIN"] = (frequency_min.to_value(stored_freq_quantity))
data_group.attrs["FREQUENCY_MAX"] = (frequency_max.to_value(stored_freq_quantity))
data_group.attrs["FREQUENCY_UNIT"] = stored_frequency_unit
coordinates_group["frequency"] = frequency.to_value(stored_freq_quantity)
coordinates_group["frequency"].make_scale(f"Frequency ({stored_frequency_unit})")
coordinates_group["time"] = time.jd
coordinates_group["time"].make_scale("Time (JD)")

log.info("\tTime and frequency axes written.")

for pi in range(data.shape[-1]):
current_polar = polarization[pi]
log.info(f"\tDealing with polarization '{current_polar}'...")
data_i = data[:, :, pi]

dataset = data_group.create_dataset(
name=f"{current_polar.lower()}",
shape=data_i.shape,
dtype=data_i.dtype
)
with ProgressBar():
da.store(data_i, dataset, compute=True, return_stored=False)

# data_i has shape (time, frequency): label the dimensions accordingly
dataset.dims[0].label = "time"
dataset.dims[0].attach_scale(coordinates_group["time"])
dataset.dims[1].label = "frequency"
dataset.dims[1].attach_scale(coordinates_group["frequency"])

log.info(f"\t'{file_name}' written.")
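
# A minimal, hedged usage sketch of the helper above on synthetic data;
# every name and value below is illustrative, not part of the nenupy API.
if __name__ == "__main__":
    n_time, n_freq = 100, 64
    example_data = da.random.random((n_time, n_freq, 4), chunks=(50, n_freq, 4))
    example_time = Time("2023-05-27T08:40:00") + np.arange(n_time) * u.s
    example_frequency = np.linspace(20, 80, n_freq) * u.MHz
    store_dask_tf_data(
        file_name="/tmp/example.hdf5",
        data=example_data,
        time=example_time,
        frequency=example_frequency,
        polarization=np.array(["XX", "XY", "YX", "YY"]),
        instrument="NenuFAR",  # extra keyword stored in the file's root attributes
    )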

# ============================================================= #
# ------------------------ _Parameter ------------------------- #
