Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC for zarr support - DO NOT MERGE #91

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,11 @@
readme = readme_file.read()

requirements = [
'dask[array] == 2021.2.0',
'donfig >= 0.4.0',
'numpy >= 1.14.0, <= 1.19.5', # breakage in newer numpy + numerical errors
'numba >= 0.43.0',
'scipy >= 1.2.0',
'threadpoolctl >= 1.0.0',
'dask-ms == 0.2.6',
'zarr >= 2.3.1'
'dask-ms[xarray,zarr,s3]'
]

extras_require = {'testing': ['pytest',
Expand Down
24 changes: 15 additions & 9 deletions tricolour/apps/tricolour/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
ResourceProfiler,
CacheProfiler, visualize)
import numpy as np
from daskms import xds_from_ms, xds_from_table, xds_to_table
from daskms import (xds_from_storage_ms,
xds_from_storage_table,
xds_to_storage_table)
from threadpoolctl import threadpool_limits

from tricolour.apps.tricolour.strat_executor import StrategyExecutor
Expand Down Expand Up @@ -229,10 +231,10 @@ def support_tables(ms):
"""

# Get datasets for sub-tables partitioned by row when variably shaped
support = {t: xds_from_table("::".join((ms, t)), group_cols="__row__")
support = {t: xds_from_storage_table("::".join((ms, t)), group_cols="__row__")
for t in ["FIELD", "POLARIZATION", "SPECTRAL_WINDOW"]}
# These columns have fixed shapes
support.update({t: xds_from_table("::".join((ms, t)))[0]
support.update({t: xds_from_storage_table("::".join((ms, t)))[0]
for t in ["ANTENNA", "DATA_DESCRIPTION"]})

# Reify all values upfront
Expand Down Expand Up @@ -291,7 +293,7 @@ def _main(args):
if args.subtract_model_column is not None:
columns.append(args.subtract_model_column)

xds = list(xds_from_ms(args.ms,
xds = list(xds_from_storage_ms(args.ms,
columns=tuple(columns),
group_cols=group_cols,
index_cols=index_cols,
Expand Down Expand Up @@ -347,6 +349,7 @@ def _main(args):
field_dict = {i: fn for i, fn in enumerate(fieldnames)}

# List which hold our dask compute graphs for each dataset
writable_xds = []
write_computes = []
original_stats = []
final_stats = []
Expand Down Expand Up @@ -386,7 +389,7 @@ def _main(args):
# Generate unflagged defaults if we should ignore existing flags
# otherwise take flags from the dataset
if args.ignore_flags is True:
flags = da.full_like(vis, False, dtype=np.bool)
flags = da.full_like(vis, False, dtype=np.bool_)
log.critical("Completely ignoring measurement set "
"flags as per '-if' request. "
"Strategy WILL NOT or with original flags, even if "
Expand Down Expand Up @@ -471,10 +474,13 @@ def _main(args):
# Create new dataset containing new flags
new_ds = ds.assign(FLAG=(("row", "chan", "corr"), corr_flags))

# Write back to original dataset
writes = xds_to_table(new_ds, args.ms, "FLAG")
# original should also have .compute called because we need stats
write_computes.append(writes)
# Append to list of datasets we intend to write to disk.
writable_xds.append(new_ds)

# Write back to original dataset
write_computes = xds_to_storage_table(
writable_xds, args.ms, columns=("FLAG",), rechunk=True
)

if len(write_computes) > 0:
# Combine stats from all datasets
Expand Down
2 changes: 1 addition & 1 deletion tricolour/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def load_mask(filename, dilate):
# Load mask
mask = np.load(filename)

if mask.dtype[0] != np.bool or mask.dtype[1] != np.float64:
if mask.dtype[0] != np.bool_ or mask.dtype[1] != np.float64:
raise ValueError("Mask %s is not a valid static mask "
"with labelled channel axis "
"[dtype == (bool, float64)]" % filename)
Expand Down
4 changes: 2 additions & 2 deletions tricolour/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _create_window_dask(name, ntime, nchan, nbl, ncorr, token,

graph = HighLevelGraph.from_collections(collection_name, layers, ())
chunks = ((0,),) # One chunk containing single zarr array object
return da.Array(graph, collection_name, chunks, dtype=np.object)
return da.Array(graph, collection_name, chunks, dtype=object)


def create_vis_windows(ntime, nchan, nbl, ncorr, token,
Expand Down Expand Up @@ -343,7 +343,7 @@ def pack_data(time_inv, ubl,
flags, ("row", "chan", "corr"),
vis_win_obj, ("windim",),
flag_win_obj, ("windim",),
dtype=np.bool)
dtype=np.bool_)

# Expose visibility data at it's full resolution
vis_windows = da.blockwise(_packed_windows, _WINDOW_SCHEMA,
Expand Down
8 changes: 4 additions & 4 deletions tricolour/tests/test_flagging_additional.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_apply_static_mask(wsrt_ants, unique_baselines,
accumulation_mode="or")

# Check that first mask's flags are applied
chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool)
chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool_)
chan_sel[[2, 10]] = True

assert np.all(new_flags[:, :, :, chan_sel] == 1)
Expand All @@ -144,7 +144,7 @@ def test_apply_static_mask(wsrt_ants, unique_baselines,
accumulation_mode="or")

# Check that both mask's flags have been applied
chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool)
chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool_)
chan_sel[[2, 10, 4, 11, 5]] = True

assert np.all(new_flags[:, :, :, chan_sel] == 1)
Expand All @@ -157,7 +157,7 @@ def test_apply_static_mask(wsrt_ants, unique_baselines,
accumulation_mode="override")

# Check that only last mask's flags applied
chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool)
chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool_)
chan_sel[[4, 11, 5]] = True

assert np.all(new_flags[:, :, :, chan_sel] == 1)
Expand All @@ -176,7 +176,7 @@ def test_apply_static_mask(wsrt_ants, unique_baselines,
uvrange=uvrange)

# Check that both mask's flags have been applied
chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool)
chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool_)
chan_sel[[2, 10, 4, 11, 5]] = True

# Select baselines based on the uvrange
Expand Down
8 changes: 4 additions & 4 deletions tricolour/window_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,21 +123,21 @@ def window_stats(flag_window, ubls, chan_freqs,
field_name, None,
ddid, None,
nchanbins, None,
meta=np.empty((0,), dtype=np.object))
meta=np.empty((0,), dtype=object))

# Create an empty stats object if the user hasn't supplied one
if prev_stats is None:
def _window_stat_creator():
return WindowStatistics(nchanbins)

prev_stats = da.blockwise(_window_stat_creator, (),
meta=np.empty((), dtype=np.object))
meta=np.empty((), dtype=object))

# Combine per-baseline stats into a single stats object
return da.blockwise(_combine_baseline_window_stats, (),
stats, ("bl",),
prev_stats, (),
meta=np.empty((), dtype=np.object))
meta=np.empty((), dtype=object))


def _combine_window_stats(*args):
Expand Down Expand Up @@ -167,7 +167,7 @@ def combine_window_stats(window_stats):
args = (v for ws in window_stats for v in (ws, ()))

return da.blockwise(_combine_window_stats, (),
*args, dtype=np.object)
*args, dtype=object)


class WindowStatistics(object):
Expand Down