From cbacc25536a1013c7fe1071ddf5c11551853a327 Mon Sep 17 00:00:00 2001
From: JSKenyon
Date: Wed, 18 Oct 2023 13:55:02 +0200
Subject: [PATCH] Initial hacky commit of zarr support.

---
 setup.py                                    |  5 +----
 tricolour/apps/tricolour/app.py             | 24 +++++++++++++--------
 tricolour/mask.py                           |  2 +-
 tricolour/packing.py                        |  4 ++--
 tricolour/tests/test_flagging_additional.py |  8 +++----
 tricolour/window_statistics.py              |  8 +++----
 6 files changed, 27 insertions(+), 24 deletions(-)

diff --git a/setup.py b/setup.py
index 4d95568..d2935d6 100644
--- a/setup.py
+++ b/setup.py
@@ -9,14 +9,11 @@
 readme = readme_file.read()
 
 requirements = [
-    'dask[array] == 2021.2.0',
     'donfig >= 0.4.0',
-    'numpy >= 1.14.0, <= 1.19.5',  # breakage in newer numpy + numerical errors
     'numba >= 0.43.0',
     'scipy >= 1.2.0',
     'threadpoolctl >= 1.0.0',
-    'dask-ms == 0.2.6',
-    'zarr >= 2.3.1'
+    'dask-ms[xarray,zarr,s3]'
 ]
 
 extras_require = {'testing': ['pytest',
diff --git a/tricolour/apps/tricolour/app.py b/tricolour/apps/tricolour/app.py
index e420a7f..a54ef29 100644
--- a/tricolour/apps/tricolour/app.py
+++ b/tricolour/apps/tricolour/app.py
@@ -21,7 +21,9 @@
                               ResourceProfiler,
                               CacheProfiler, visualize)
 import numpy as np
-from daskms import xds_from_ms, xds_from_table, xds_to_table
+from daskms import (xds_from_storage_ms,
+                    xds_from_storage_table,
+                    xds_to_storage_table)
 from threadpoolctl import threadpool_limits
 
 from tricolour.apps.tricolour.strat_executor import StrategyExecutor
@@ -229,10 +231,10 @@ def support_tables(ms):
     """
     # Get datasets for sub-tables partitioned by row when variably shaped
-    support = {t: xds_from_table("::".join((ms, t)), group_cols="__row__")
+    support = {t: xds_from_storage_table("::".join((ms, t)), group_cols="__row__")
                for t in ["FIELD", "POLARIZATION", "SPECTRAL_WINDOW"]}
 
     # These columns have fixed shapes
-    support.update({t: xds_from_table("::".join((ms, t)))[0]
+    support.update({t: xds_from_storage_table("::".join((ms, t)))[0]
                     for t in ["ANTENNA", "DATA_DESCRIPTION"]})
 
     # Reify all values upfront
@@ -291,7 +293,7 @@ def _main(args):
     if args.subtract_model_column is not None:
         columns.append(args.subtract_model_column)
 
-    xds = list(xds_from_ms(args.ms,
+    xds = list(xds_from_storage_ms(args.ms,
                            columns=tuple(columns),
                            group_cols=group_cols,
                            index_cols=index_cols,
@@ -347,6 +349,7 @@
     field_dict = {i: fn for i, fn in enumerate(fieldnames)}
 
     # List which hold our dask compute graphs for each dataset
+    writable_xds = []
     write_computes = []
     original_stats = []
     final_stats = []
@@ -386,7 +389,7 @@
         # Generate unflagged defaults if we should ignore existing flags
         # otherwise take flags from the dataset
         if args.ignore_flags is True:
-            flags = da.full_like(vis, False, dtype=np.bool)
+            flags = da.full_like(vis, False, dtype=np.bool_)
             log.critical("Completely ignoring measurement set "
                          "flags as per '-if' request. "
                          "Strategy WILL NOT or with original flags, even if "
@@ -471,10 +474,13 @@
 
         # Create new dataset containing new flags
         new_ds = ds.assign(FLAG=(("row", "chan", "corr"), corr_flags))
 
-        # Write back to original dataset
-        writes = xds_to_table(new_ds, args.ms, "FLAG")
-        # original should also have .compute called because we need stats
-        write_computes.append(writes)
+        # Append to list of datasets we intend to write to disk.
+        writable_xds.append(new_ds)
+
+    # Write back to original dataset
+    write_computes = xds_to_storage_table(
+        writable_xds, args.ms, columns=("FLAG",), rechunk=True
+    )
 
     if len(write_computes) > 0:
         # Combine stats from all datasets
diff --git a/tricolour/mask.py b/tricolour/mask.py
index da1ab02..17d9f98 100644
--- a/tricolour/mask.py
+++ b/tricolour/mask.py
@@ -60,7 +60,7 @@ def load_mask(filename, dilate):
     # Load mask
     mask = np.load(filename)
 
-    if mask.dtype[0] != np.bool or mask.dtype[1] != np.float64:
+    if mask.dtype[0] != np.bool_ or mask.dtype[1] != np.float64:
         raise ValueError("Mask %s is not a valid static mask "
                          "with labelled channel axis "
                          "[dtype == (bool, float64)]" % filename)
diff --git a/tricolour/packing.py b/tricolour/packing.py
index 6b009b6..76fed22 100644
--- a/tricolour/packing.py
+++ b/tricolour/packing.py
@@ -90,7 +90,7 @@ def _create_window_dask(name, ntime, nchan, nbl, ncorr, token,
     graph = HighLevelGraph.from_collections(collection_name, layers, ())
     chunks = ((0,),)  # One chunk containing single zarr array object
 
-    return da.Array(graph, collection_name, chunks, dtype=np.object)
+    return da.Array(graph, collection_name, chunks, dtype=object)
 
 
 def create_vis_windows(ntime, nchan, nbl, ncorr, token,
@@ -343,7 +343,7 @@ def pack_data(time_inv, ubl, flags,
                                  ("row", "chan", "corr"),
                                  vis_win_obj, ("windim",),
                                  flag_win_obj, ("windim",),
-                                 dtype=np.bool)
+                                 dtype=np.bool_)
 
     # Expose visibility data at it's full resolution
     vis_windows = da.blockwise(_packed_windows, _WINDOW_SCHEMA,
diff --git a/tricolour/tests/test_flagging_additional.py b/tricolour/tests/test_flagging_additional.py
index 53303f9..0337ed9 100644
--- a/tricolour/tests/test_flagging_additional.py
+++ b/tricolour/tests/test_flagging_additional.py
@@ -131,7 +131,7 @@ def test_apply_static_mask(wsrt_ants, unique_baselines,
                            accumulation_mode="or")
 
     # Check that first mask's flags are applied
-    chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool)
+    chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool_)
     chan_sel[[2, 10]] = True
 
     assert np.all(new_flags[:, :, :, chan_sel] == 1)
@@ -144,7 +144,7 @@ def test_apply_static_mask(wsrt_ants, unique_baselines,
                            accumulation_mode="or")
 
     # Check that both mask's flags have been applied
-    chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool)
+    chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool_)
     chan_sel[[2, 10, 4, 11, 5]] = True
 
     assert np.all(new_flags[:, :, :, chan_sel] == 1)
@@ -157,7 +157,7 @@ def test_apply_static_mask(wsrt_ants, unique_baselines,
                            accumulation_mode="override")
 
     # Check that only last mask's flags applied
-    chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool)
+    chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool_)
     chan_sel[[4, 11, 5]] = True
 
     assert np.all(new_flags[:, :, :, chan_sel] == 1)
@@ -176,7 +176,7 @@ def test_apply_static_mask(wsrt_ants, unique_baselines,
                            uvrange=uvrange)
 
     # Check that both mask's flags have been applied
-    chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool)
+    chan_sel = np.zeros(chan_freqs.shape[0], dtype=np.bool_)
     chan_sel[[2, 10, 4, 11, 5]] = True
 
     # Select baselines based on the uvrange
diff --git a/tricolour/window_statistics.py b/tricolour/window_statistics.py
index 6057e70..20d1843 100644
--- a/tricolour/window_statistics.py
+++ b/tricolour/window_statistics.py
@@ -123,7 +123,7 @@ def window_stats(flag_window, ubls, chan_freqs,
                          field_name, None,
                          ddid, None,
                          nchanbins, None,
-                         meta=np.empty((0,), dtype=np.object))
+                         meta=np.empty((0,), dtype=object))
 
     # Create an empty stats object if the user hasn't supplied one
     if prev_stats is None:
@@ -131,13 +131,13 @@ def _window_stat_creator():
             return WindowStatistics(nchanbins)
 
         prev_stats = da.blockwise(_window_stat_creator, (),
-                                  meta=np.empty((), dtype=np.object))
+                                  meta=np.empty((), dtype=object))
 
     # Combine per-baseline stats into a single stats object
     return da.blockwise(_combine_baseline_window_stats, (),
                         stats, ("bl",),
                         prev_stats, (),
-                        meta=np.empty((), dtype=np.object))
+                        meta=np.empty((), dtype=object))
 
 
 def _combine_window_stats(*args):
@@ -167,7 +167,7 @@ def combine_window_stats(window_stats):
     args = (v for ws in window_stats for v in (ws, ()))
 
     return da.blockwise(_combine_window_stats, (),
-                        *args, dtype=np.object)
+                        *args, dtype=object)
 
 
 class WindowStatistics(object):
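
A minimal sketch of the read/modify/write path this patch moves to, assuming a
dask-ms release that ships the storage-agnostic functions pulled in by the
'dask-ms[xarray,zarr,s3]' requirement above. The store path, the row chunking
and the group/index columns below are illustrative only; a zarr or s3 store URL
should be usable in place of the CASA Measurement Set path.

import dask
import dask.array as da
import numpy as np
from daskms import xds_from_storage_ms, xds_to_storage_table

# Illustrative store; the storage functions inspect the store, so a
# zarr/s3 URL can stand in for a CASA Measurement Set path.
store = "scratch.ms"

# Read FLAG (and DATA) grouped and ordered roughly as app.py does.
xds = list(xds_from_storage_ms(store,
                               columns=("DATA", "FLAG"),
                               group_cols=["FIELD_ID", "DATA_DESC_ID",
                                           "SCAN_NUMBER"],
                               index_cols=["TIME", "ANTENNA1", "ANTENNA2"],
                               chunks={"row": 10000}))

# Stand-in for the flagging strategy: flag everything, then collect the
# updated datasets instead of writing them back one at a time.
writable_xds = [ds.assign(FLAG=(("row", "chan", "corr"),
                                da.ones_like(ds.FLAG.data, dtype=np.bool_)))
                for ds in xds]

# Single write call covering every dataset, as in the patched _main().
write_computes = xds_to_storage_table(writable_xds, store,
                                      columns=("FLAG",), rechunk=True)
dask.compute(write_computes)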