diff --git a/.flake8 b/.flake8 index b538fe0c..a0f419a9 100644 --- a/.flake8 +++ b/.flake8 @@ -3,6 +3,6 @@ ignore = E501, W605, W503, E203, F401, E722 #,E402, E203,E722 #F841, E402, E722 #ignore = E203, E266, E501, W503, F403, F401 -max-line-length = 88 +max-line-length = 80 max-complexity = 18 select = B,C,E,F,W,T4,B9 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b12d023e..f0fe8a47 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -18,7 +18,8 @@ jobs: fail-fast: false matrix: os: ["ubuntu-latest"] - python-version: [3.8, 3.9, "3.10", "3.11"] + python-version: [3.8, "3.11"] +# python-version: [3.8, 3.9, "3.10", "3.11"] # python-version: ["3.10", ] steps: @@ -33,17 +34,19 @@ jobs: - name: Install Env run: | python --version + echo $CONDA_PREFIX conda install -c conda-forge pytest pytest-cov certifi">=2017.4.17" pandoc pip install -r requirements-dev.txt - # pip install git+https://github.com/kujaku11/mt_metadata.git@main - # pip install git+https://github.com/kujaku11/mth5.git@master - # pip install git+https://github.com/MTgeophysics/mtpy-v2.git@main + pip install git+https://github.com/kujaku11/mt_metadata.git@features + pip install git+https://github.com/kujaku11/mth5.git@features pip install git+https://github.com/MTgeophysics/mtpy-v2.git@main pip uninstall aurora -y - name: Install Our Package run: | + echo $CONDA_PREFIX pip install -e . + echo "Install complete" conda list pip freeze @@ -63,7 +66,6 @@ jobs: jupyter nbconvert --to notebook --execute docs/tutorials/processing_configuration.ipynb jupyter nbconvert --to notebook --execute docs/tutorials/process_cas04_multiple_station.ipynb jupyter nbconvert --to notebook --execute docs/tutorials/synthetic_data_processing.ipynb - jupyter nbconvert --to notebook --execute tests/test_run_on_commit.ipynb # Replace "notebook.ipynb" with your notebook's filename # - name: Commit changes (if any) diff --git a/aurora/config/config_creator.py b/aurora/config/config_creator.py index 269e2974..6f340184 100644 --- a/aurora/config/config_creator.py +++ b/aurora/config/config_creator.py @@ -13,7 +13,7 @@ from aurora.config.metadata import Processing from aurora.sandbox.io_helpers.emtf_band_setup import EMTFBandSetupFile -from mt_metadata.transfer_functions.processing.aurora.window import Window +from mt_metadata.transfer_functions.processing.window import Window SUPPORTED_BAND_SPECIFICATION_STYLES = ["EMTF", "band_edges"] @@ -108,9 +108,7 @@ def determine_band_specification_style(self) -> None: logger.info(msg) self._emtf_band_file = BANDS_DEFAULT_FILE self._band_specification_style = "EMTF" - elif (self._emtf_band_file is not None) & ( - self._band_edges is not None - ): + elif (self._emtf_band_file is not None) & (self._band_edges is not None): msg = "Bands defined twice, and possibly inconsistently" logger.error(msg) raise ValueError(msg) @@ -152,6 +150,8 @@ def create_from_kernel_dataset( Theoretically, you could also use the number of decimations implied by bands_dict but this is sloppy, because it would assume the decimation factor. + 3. 2024-12-29 Added setting of decimation_obj.stft.per_window_detrend_type = "linear" + This makes tests pass following a refactoring of mt_metadata. Could use more testing. Parameters ---------- @@ -178,9 +178,7 @@ def create_from_kernel_dataset( Object storing the processing parameters. 
""" - processing_obj = Processing( - id=kernel_dataset.processing_id - ) # , **kwargs) + processing_obj = Processing(id=kernel_dataset.processing_id) # , **kwargs) # pack station and run info into processing object processing_obj.stations.from_dataset_dataframe(kernel_dataset.df) @@ -209,9 +207,7 @@ def create_from_kernel_dataset( decimation_factors[0] = 1 if num_samples_window is None: default_window = Window() - num_samples_window = num_decimations * [ - default_window.num_samples - ] + num_samples_window = num_decimations * [default_window.num_samples] elif isinstance(num_samples_window, int): num_samples_window = num_decimations * [num_samples_window] # now you can define the frequency bands @@ -241,11 +237,13 @@ def create_from_kernel_dataset( decimation_obj.output_channels = output_channels if num_samples_window is not None: - decimation_obj.window.num_samples = num_samples_window[key] + decimation_obj.stft.window.num_samples = num_samples_window[key] # set estimator if provided as kwarg if estimator: try: decimation_obj.estimator.engine = estimator["engine"] except KeyError: pass + decimation_obj.stft.per_window_detrend_type = "linear" + return processing_obj diff --git a/aurora/pipelines/fourier_coefficients.py b/aurora/pipelines/fourier_coefficients.py deleted file mode 100644 index 3b1c93c8..00000000 --- a/aurora/pipelines/fourier_coefficients.py +++ /dev/null @@ -1,323 +0,0 @@ -""" -Supporting codes for building the FC level of the mth5 - -Here are the parameters that are defined via the mt_metadata fourier coefficients structures: -"anti_alias_filter": "default", -"bands", -"decimation.factor": 4.0, -"decimation.level": 2, -"decimation.method": "default", -"decimation.sample_rate": 0.0625, -"extra_pre_fft_detrend_type": "linear", -"prewhitening_type": "first difference", -"window.clock_zero_type": "ignore", -"window.num_samples": 128, -"window.overlap": 32, -"window.type": "boxcar" - -Creating the decimations config requires a decision about decimation factors and the number of levels. -We have been getting this from the EMTF band setup file by default. It is desirable to continue supporting this, -however, note that the EMTF band setup is really about processing, and not about making STFTs. - -For the record, here is the legacy decimation config from EMTF, a.k.a. decset.cfg: -``` -4 0 # of decimation level, & decimation offset -128 32. 1 0 0 7 4 32 1 -1.0 -128 32. 4 0 0 7 4 32 4 -.2154 .1911 .1307 .0705 -128 32. 4 0 0 7 4 32 4 -.2154 .1911 .1307 .0705 -128 32. 4 0 0 7 4 32 4 -.2154 .1911 .1307 .0705 -``` - -This essentially corresponds to a "Decimations Group" which is a list of decimations. -Related to the generation of FCs is the ARMA prewhitening (Issue #60) which was controlled in -EMTF with pwset.cfg -4 5 # of decimation level, # of channels -3 3 3 3 3 -3 3 3 3 3 -3 3 3 3 3 -3 3 3 3 3 - -Note 1: Assumes application of cascading decimation, and that the -decimated data will be accessed from the previous decimation level. - -Note 2: We can encounter cases where some runs can be decimated and others can not. -We need a way to handle this. For example, a short run may not yield any data from a -later decimation level. An attempt to handle this has been made in TF Kernel by -adding a is_valid_dataset column, associated with each run-decimation level pair. - -Note 3: This point in the loop marks the interface between _generation_ of the FCs and - their _usage_. 
In future the code above this comment would be pushed into - create_fourier_coefficients() and the code below this would access those FCs and - execute compute_transfer_function() - - -""" - -# ============================================================================= -# Imports -# ============================================================================= - -import mt_metadata.timeseries.time_period -import mth5.mth5 -import pathlib - -from aurora.pipelines.time_series_helpers import calibrate_stft_obj -from aurora.pipelines.time_series_helpers import prototype_decimate -from aurora.pipelines.time_series_helpers import run_ts_to_stft_scipy -from loguru import logger -from mth5.mth5 import MTH5 -from mth5.utils.helpers import path_or_mth5_object -from mt_metadata.transfer_functions.processing.fourier_coefficients import ( - Decimation as FCDecimation, -) -from typing import List, Optional, Union - -# ============================================================================= -GROUPBY_COLUMNS = ["survey", "station", "sample_rate"] - - -def fc_decimations_creator( - initial_sample_rate: float, - decimation_factors: Optional[Union[list, None]] = None, - max_levels: Optional[int] = 6, - time_period: mt_metadata.timeseries.TimePeriod = None, -) -> list: - """ - - Creates mt_metadata FCDecimation objects that parameterize Fourier coefficient decimation levels. - - Note 1: This does not yet work through the assignment of which bands to keep. Refer to - mt_metadata.transfer_functions.processing.Processing.assign_bands() to see how this was done in the past - - Parameters - ---------- - initial_sample_rate: float - Sample rate of the "level0" data -- usually the sample rate during field acquisition. - decimation_factors: list (or other iterable) - The decimation factors that will be applied at each FC decimation level - max_levels: int - The maximum number of decimation levels to allow - time_period: - - Returns - ------- - fc_decimations: list - Each element of the list is an object of type - mt_metadata.transfer_functions.processing.fourier_coefficients.Decimation, - (a.k.a. FCDecimation). - - The order of the list corresponds the order of the cascading decimation - - No decimation levels are omitted. - - This could be changed in future by using a dict instead of a list, - - e.g. 
decimation_factors = dict(zip(np.arange(max_levels), decimation_factors)) - - """ - if not decimation_factors: - # msg = "No decimation factors given, set default values to EMTF default values [1, 4, 4, 4, ..., 4]") - # logger.info(msg) - default_decimation_factor = 4 - decimation_factors = max_levels * [default_decimation_factor] - decimation_factors[0] = 1 - - # See Note 1 - fc_decimations = [] - for i_dec_level, decimation_factor in enumerate(decimation_factors): - fc_dec = FCDecimation() - fc_dec.decimation_level = i_dec_level - fc_dec.id = f"{i_dec_level}" - fc_dec.decimation_factor = decimation_factor - if i_dec_level == 0: - current_sample_rate = 1.0 * initial_sample_rate - else: - current_sample_rate /= decimation_factor - fc_dec.sample_rate_decimation = current_sample_rate - - if time_period: - if isinstance(time_period, mt_metadata.timeseries.time_period.TimePeriod): - fc_dec.time_period = time_period - else: - msg = ( - f"Not sure how to assign time_period with type {type(time_period)}" - ) - logger.info(msg) - raise NotImplementedError(msg) - - fc_decimations.append(fc_dec) - - return fc_decimations - - -@path_or_mth5_object -def add_fcs_to_mth5( - m: MTH5, fc_decimations: Optional[Union[list, None]] = None -) -> None: - """ - Add Fourier Coefficient Levels ot an existing MTH5. - - **Notes:** - - - This module computes the FCs differently than the legacy aurora pipeline. It uses scipy.signal.spectrogram. There is a test in Aurora to confirm that there are equivalent if we are not using fancy pre-whitening. - - - Nomenclature: "usssr_grouper" is the output of a group-by on unique {survey, station, sample_rate} tuples. - - Parameters - ---------- - m: MTH5 object - The mth5 file, open in append mode. - fc_decimations: Union[str, None, List] - This specifies the scheme to use for decimating the time series when building the FC layer. - None: Just use default (something like four decimation levels, decimated by 4 each time say. - String: Controlled Vocabulary, values are a work in progress, that will allow custom definition of the fc_decimations for some common cases. For example, say you have stored already decimated time - series, then you want simply the zeroth decimation for each run, because the decimated time series live - under another run container, and that will get its own FCs. This is experimental. - List: (**UNTESTED**) -- This means that the user thought about the decimations that they want to create and is - passing them explicitly. -- probably will need to be a dictionary actually, since this - would get redefined at each sample rate. 
- - """ - # Group the channel summary by survey, station, sample_rate - channel_summary_df = m.channel_summary.to_dataframe() - usssr_grouper = channel_summary_df.groupby(GROUPBY_COLUMNS) - logger.debug(f"Detected {len(usssr_grouper)} unique station-sample_rate instances") - - # loop over groups - for (survey, station, sample_rate), usssr_group in usssr_grouper: - msg = f"\n\n\nsurvey: {survey}, station: {station}, sample_rate {sample_rate}" - logger.info(msg) - station_obj = m.get_station(station, survey) - run_summary = station_obj.run_summary - - # Get the FC decimation schemes if not provided - if not fc_decimations: - msg = "FC Decimations not supplied, creating defaults on the fly" - logger.info(f"{msg}") - fc_decimations = fc_decimations_creator( - initial_sample_rate=sample_rate, time_period=None - ) - elif isinstance(fc_decimations, str): - if fc_decimations == "degenerate": - fc_decimations = get_degenerate_fc_decimation(sample_rate) - - # TODO: Make this a function that can be done using df.apply() - for i_run_row, run_row in run_summary.iterrows(): - logger.info( - f"survey: {survey}, station: {station}, sample_rate {sample_rate}, i_run_row {i_run_row}" - ) - # Access Run - run_obj = m.from_reference(run_row.hdf5_reference) - - # Set the time period: - # TODO: Should this be over-writing time period if it is already there? - for fc_decimation in fc_decimations: - fc_decimation.time_period = run_obj.metadata.time_period - - # Access the data to Fourier transform - runts = run_obj.to_runts( - start=fc_decimation.time_period.start, - end=fc_decimation.time_period.end, - ) - run_xrds = runts.dataset - - # access container for FCs - fc_group = station_obj.fourier_coefficients_group.add_fc_group( - run_obj.metadata.id - ) - - # If timing corrections were needed they could go here, right before STFT - - for i_dec_level, fc_decimation in enumerate(fc_decimations): - if i_dec_level != 0: - # Apply decimation - run_xrds = prototype_decimate(fc_decimation, run_xrds) - - # check if this decimation level yields a valid spectrogram - if not fc_decimation.is_valid_for_time_series_length( - run_xrds.time.shape[0] - ): - logger.info( - f"Decimation Level {i_dec_level} invalid, TS of {run_xrds.time.shape[0]} samples too short" - ) - continue - - stft_obj = run_ts_to_stft_scipy(fc_decimation, run_xrds) - stft_obj = calibrate_stft_obj(stft_obj, run_obj) - - # Pack FCs into h5 and update metadata - decimation_level = fc_group.add_decimation_level( - f"{i_dec_level}", decimation_level_metadata=fc_decimation - ) - decimation_level.from_xarray( - stft_obj, decimation_level.metadata.sample_rate_decimation - ) - decimation_level.update_metadata() - fc_group.update_metadata() - return - - -def get_degenerate_fc_decimation(sample_rate: float) -> list: - """ - - Makes a default fc_decimation list. WIP - This "degnerate" config will only operate on the first decimation level. - This is useful for testing but could be used in future if an MTH5 stored time series in decimation - levels already as separate runs. - - Parameters - ---------- - sample_rate: float - The sample rate assocaiated with the time-series to convert to Spectrogram - - Returns - ------- - output: list - List has only one element which is of type mt_metadata.transfer_functions.processing.fourier_coefficients.Decimation. 
- """ - output = fc_decimations_creator( - sample_rate, - decimation_factors=[ - 1, - ], - max_levels=1, - ) - return output - - -@path_or_mth5_object -def read_back_fcs(m: Union[MTH5, pathlib.Path, str], mode="r"): - """ - This is mostly a helper function for tests. It was used as a sanity check while debugging the FC files, and - also is a good example for how to access the data at each level for each channel. - - The Time axis of the FC array will change from level to level, but the frequency axis will stay the same shape - (for now -- storing all fcs by default) - - Args: - m: pathlib.Path, str or an MTH5 object - The path to an h5 file that we will scan the fcs from - - - """ - channel_summary_df = m.channel_summary.to_dataframe() - logger.debug(channel_summary_df) - usssr_grouper = channel_summary_df.groupby(GROUPBY_COLUMNS) - for (survey, station, sample_rate), usssr_group in usssr_grouper: - logger.info(f"survey: {survey}, station: {station}, sample_rate {sample_rate}") - station_obj = m.get_station(station, survey) - fc_groups = station_obj.fourier_coefficients_group.groups_list - logger.info(f"FC Groups: {fc_groups}") - for run_id in fc_groups: - fc_group = station_obj.fourier_coefficients_group.get_fc_group(run_id) - dec_level_ids = fc_group.groups_list - for dec_level_id in dec_level_ids: - dec_level = fc_group.get_decimation_level(dec_level_id) - xrds = dec_level.to_xarray(["hx", "hy"]) - msg = f"dec_level {dec_level_id}" - msg = f"{msg} \n Time axis shape {xrds.time.data.shape}" - msg = f"{msg} \n Freq axis shape {xrds.frequency.data.shape}" - logger.debug(msg) - - return diff --git a/aurora/pipelines/process_mth5.py b/aurora/pipelines/process_mth5.py index 94250a3a..0dfc0ba0 100644 --- a/aurora/pipelines/process_mth5.py +++ b/aurora/pipelines/process_mth5.py @@ -6,7 +6,6 @@ can be repurposed for other TF estimation schemes. The "legacy" version corresponds to aurora default processing. - Notes on process_mth5_legacy: Note 1: process_mth5 assumes application of cascading decimation, and that the decimated data will be accessed from the previous decimation level. This should be @@ -22,7 +21,7 @@ Note 3: This point in the loop marks the interface between _generation_ of the FCs and their _usage_. In future the code above this comment would be pushed into - create_fourier_coefficients() and the code below this would access those FCs and + the creation of the spectrograms and the code below this would access those FCs and execute compute_transfer_function(). This would also be an appropriate place to place a feature extraction layer, and compute weights for the FCs. 
@@ -34,7 +33,6 @@ # ============================================================================= # Imports # ============================================================================= - from aurora.pipelines.time_series_helpers import calibrate_stft_obj from aurora.pipelines.time_series_helpers import run_ts_to_stft from aurora.pipelines.transfer_function_helpers import ( @@ -47,8 +45,12 @@ TransferFunctionCollection, ) from aurora.transfer_function.TTFZ import TTFZ +from mt_metadata.transfer_functions.processing.aurora.decimation_level import ( + DecimationLevel as AuroraDecimationLevel, +) from loguru import logger from mth5.helpers import close_open_files +from mth5.timeseries.spectre import Spectrogram from typing import Optional, Union import aurora.config.metadata.processing @@ -63,9 +65,7 @@ # ============================================================================= -def make_stft_objects( - processing_config, i_dec_level, run_obj, run_xrds, units="MT" -): +def make_stft_objects(processing_config, i_dec_level, run_obj, run_xrds, units="MT"): """ Operates on a "per-run" basis. Applies STFT to all time series in the input run. @@ -93,9 +93,52 @@ def make_stft_objects( ------- stft_obj: xarray.core.dataset.Dataset Time series of calibrated Fourier coefficients per each channel in the run + + Development Notes: + Here are the parameters that are defined via the mt_metadata fourier coefficients structures: + + "bands", + "decimation.anti_alias_filter": "default", + "decimation.factor": 4.0, + "decimation.level": 2, + "decimation.method": "default", + "decimation.sample_rate": 0.0625, + "stft.per_window_detrend_type": "linear", + "stft.prewhitening_type": "first difference", + "stft.window.clock_zero_type": "ignore", + "stft.window.num_samples": 128, + "stft.window.overlap": 32, + "stft.window.type": "boxcar" + + Creating the decimations config requires a decision about decimation factors and the number of levels. + We have been getting this from the EMTF band setup file by default. It is desirable to continue supporting this, + however, note that the EMTF band setup is really about a time series operation, and not about making STFTs. + + For the record, here is the legacy decimation config from EMTF, a.k.a. decset.cfg: + ``` + 4 0 # of decimation level, & decimation offset + 128 32. 1 0 0 7 4 32 1 + 1.0 + 128 32. 4 0 0 7 4 32 4 + .2154 .1911 .1307 .0705 + 128 32. 4 0 0 7 4 32 4 + .2154 .1911 .1307 .0705 + 128 32. 4 0 0 7 4 32 4 + .2154 .1911 .1307 .0705 + ``` + + This essentially corresponds to a "Decimations Group" which is a list of decimations. 
+ Related to the generation of FCs is the ARMA prewhitening (Issue #60) which was controlled in + EMTF with pwset.cfg + 4 5 # of decimation level, # of channels + 3 3 3 3 3 + 3 3 3 3 3 + 3 3 3 3 3 + 3 3 3 3 3 + """ stft_config = processing_config.get_decimation_level(i_dec_level) - stft_obj = run_ts_to_stft(stft_config, run_xrds) + spectrogram = run_ts_to_stft(stft_config, run_xrds) run_id = run_obj.metadata.id if run_obj.station_metadata.id == processing_config.stations.local.id: scale_factors = processing_config.stations.local.run_dict[ @@ -103,17 +146,16 @@ def make_stft_objects( ].channel_scale_factors elif run_obj.station_metadata.id == processing_config.stations.remote[0].id: scale_factors = ( - processing_config.stations.remote[0] - .run_dict[run_id] - .channel_scale_factors + processing_config.stations.remote[0].run_dict[run_id].channel_scale_factors ) stft_obj = calibrate_stft_obj( - stft_obj, + spectrogram.dataset, run_obj, units=units, channel_scale_factors=scale_factors, ) + return stft_obj @@ -134,7 +176,7 @@ def process_tf_decimation_level( Parameters ---------- - config: mt_metadata.transfer_functions.processing.aurora.decimation_level.DecimationLevel + config: aurora.config.metadata.processing.Processing, Config for a single decimation level i_dec_level: int decimation level_id @@ -152,9 +194,7 @@ def process_tf_decimation_level( The transfer function values packed into an object """ frequency_bands = config.decimations[i_dec_level].frequency_bands_obj() - transfer_function_obj = TTFZ( - i_dec_level, frequency_bands, processing_config=config - ) + transfer_function_obj = TTFZ(i_dec_level, frequency_bands, processing_config=config) dec_level_config = config.decimations[i_dec_level] # segment_weights = coherence_weights(dec_level_config, local_stft_obj, remote_stft_obj) transfer_function_obj = process_transfer_functions( @@ -183,9 +223,7 @@ def triage_issue_289(local_stfts: list, remote_stfts: list): for i_chunk in range(n_chunks): ok = local_stfts[i_chunk].time.shape == remote_stfts[i_chunk].time.shape if not ok: - logger.warning( - "Mismatch in FC array lengths detected -- Issue #289" - ) + logger.warning("Mismatch in FC array lengths detected -- Issue #289") glb = max( local_stfts[i_chunk].time.min(), remote_stfts[i_chunk].time.min(), @@ -196,18 +234,13 @@ def triage_issue_289(local_stfts: list, remote_stfts: list): ) cond1 = local_stfts[i_chunk].time >= glb cond2 = local_stfts[i_chunk].time <= lub - local_stfts[i_chunk] = local_stfts[i_chunk].where( - cond1 & cond2, drop=True - ) + local_stfts[i_chunk] = local_stfts[i_chunk].where(cond1 & cond2, drop=True) cond1 = remote_stfts[i_chunk].time >= glb cond2 = remote_stfts[i_chunk].time <= lub remote_stfts[i_chunk] = remote_stfts[i_chunk].where( cond1 & cond2, drop=True ) - assert ( - local_stfts[i_chunk].time.shape - == remote_stfts[i_chunk].time.shape - ) + assert local_stfts[i_chunk].time.shape == remote_stfts[i_chunk].time.shape return local_stfts, remote_stfts @@ -289,7 +322,7 @@ def load_stft_obj_from_mth5( """ Load stft_obj from mth5 (instead of compute) - Note #1: See note #1 in time_series.frequency_band_helpers.extract_band + Note #1: See note #1 in mth5.timeseries.spectre.spectrogram.py in extract_band function. Parameters ---------- @@ -306,9 +339,7 @@ def load_stft_obj_from_mth5( An STFT from mth5. 
""" station_obj = station_obj_from_row(row) - fc_group = station_obj.fourier_coefficients_group.get_fc_group( - run_obj.metadata.id - ) + fc_group = station_obj.fourier_coefficients_group.get_fc_group(run_obj.metadata.id) fc_decimation_level = fc_group.get_decimation_level(f"{i_dec_level}") stft_obj = fc_decimation_level.to_xarray(channels=channels) @@ -323,7 +354,9 @@ def load_stft_obj_from_mth5( return stft_chunk -def save_fourier_coefficients(dec_level_config, row, run_obj, stft_obj) -> None: +def save_fourier_coefficients( + dec_level_config: AuroraDecimationLevel, row: pd.Series, run_obj, stft_obj +) -> None: """ Optionally saves the stft object into the MTH5. Note that the dec_level_config must have its save_fcs attr set to True to actually save the data. @@ -369,10 +402,7 @@ def save_fourier_coefficients(dec_level_config, row, run_obj, stft_obj) -> None: raise NotImplementedError(msg) # Get FC group (create if needed) - if ( - run_obj.metadata.id - in station_obj.fourier_coefficients_group.groups_list - ): + if run_obj.metadata.id in station_obj.fourier_coefficients_group.groups_list: fc_group = station_obj.fourier_coefficients_group.get_fc_group( run_obj.metadata.id ) @@ -394,7 +424,7 @@ def save_fourier_coefficients(dec_level_config, row, run_obj, stft_obj) -> None: decimation_level_metadata=decimation_level_metadata, ) fc_decimation_level.from_xarray( - stft_obj, decimation_level_metadata.sample_rate + stft_obj, decimation_level_metadata.decimation.sample_rate ) fc_decimation_level.update_metadata() fc_group.update_metadata() @@ -405,10 +435,11 @@ def save_fourier_coefficients(dec_level_config, row, run_obj, stft_obj) -> None: return -def get_spectrogams(tfk, i_dec_level, units="MT"): +def get_spectrograms(tfk: TransferFunctionKernel, i_dec_level, units="MT"): """ Given a decimation level id, loads a dictianary of all spectragrams from information in tfk. TODO: Make this a method of TFK + TODO: Modify this to be able to yield Spectrogram objects. 
Parameters ---------- @@ -442,6 +473,7 @@ def get_spectrogams(tfk, i_dec_level, units="MT"): run_obj = row.mth5_obj.from_reference(row.run_hdf5_reference) if row.fc: stft_obj = load_stft_obj_from_mth5(i_dec_level, row, run_obj) + # TODO: Cast stft_obj to a Spectrogram here stfts = append_chunk_to_stfts(stfts, stft_obj, row.remote) continue @@ -457,10 +489,13 @@ def get_spectrogams(tfk, i_dec_level, units="MT"): run_xrds, units, ) + # TODO: Cast stft_obj to a Spectrogram here or in make_stft_objects # Pack FCs into h5 dec_level_config = tfk.config.decimations[i_dec_level] save_fourier_coefficients(dec_level_config, row, run_obj, stft_obj) + # TODO: 1st pass, cast stft_obj to a Spectrogram here + stfts = append_chunk_to_stfts(stfts, stft_obj, row.remote) return stfts @@ -522,7 +557,7 @@ def process_mth5_legacy( tfk.update_dataset_df(i_dec_level) tfk.apply_clock_zero(dec_level_config) - stfts = get_spectrogams(tfk, i_dec_level, units=units) + stfts = get_spectrograms(tfk, i_dec_level, units=units) local_merged_stft_obj, remote_merged_stft_obj = merge_stfts(stfts, tfk) @@ -535,9 +570,7 @@ def process_mth5_legacy( local_merged_stft_obj, remote_merged_stft_obj, ) - ttfz_obj.apparent_resistivity( - tfk.config.channel_nomenclature, units=units - ) + ttfz_obj.apparent_resistivity(tfk.config.channel_nomenclature, units=units) tf_dict[i_dec_level] = ttfz_obj if show_plot: @@ -549,10 +582,20 @@ def process_mth5_legacy( tf_dict=tf_dict, processing_config=tfk.config ) - tf_cls = tfk.export_tf_collection(tf_collection) - - if z_file_path: - tf_cls.write(z_file_path) + try: + tf_cls = tfk.export_tf_collection(tf_collection) + if z_file_path: + tf_cls.write(z_file_path) + except Exception as e: + msg = "TF collection could not export to mt_metadata TransferFunction\n" + msg += f"Failed with exception {e}\n" + msg += "Perhaps an unconventional mixture of input/output channels was used\n" + msg += f"Input channels were {tfk.config.decimations[0].input_channels}\n" + msg += f"Output channels were {tfk.config.decimations[0].output_channels}\n" + msg += "No z-file will be written in this case\n" + msg += "Will return a legacy TransferFunctionCollection object, not mt_metadata object." + logger.error(msg) + return_collection = True tfk.dataset.close_mth5s() if return_collection: @@ -602,9 +645,7 @@ def process_mth5( The transfer function object """ if processing_type not in SUPPORTED_PROCESSINGS: - raise NotImplementedError( - f"Processing type {processing_type} not supported" - ) + raise NotImplementedError(f"Processing type {processing_type} not supported") if processing_type == "legacy": try: diff --git a/aurora/pipelines/time_series_helpers.py b/aurora/pipelines/time_series_helpers.py index 5102d7c2..8d9ea54c 100644 --- a/aurora/pipelines/time_series_helpers.py +++ b/aurora/pipelines/time_series_helpers.py @@ -11,169 +11,24 @@ from aurora.time_series.windowed_time_series import WindowedTimeSeries from aurora.time_series.windowing_scheme import window_scheme_from_decimation - - -def validate_sample_rate(run_ts, expected_sample_rate, tol=1e-4): - """ - Check that the sample rate of a run_ts is the expected value, and warn if not. - - Parameters - ---------- - run_ts: mth5.timeseries.run_ts.RunTS - Time series object with data and metadata. - expected_sample_rate: float - The sample rate the time series is expected to have. 
Normally taken from - the processing config - - """ - if run_ts.sample_rate != expected_sample_rate: - msg = ( - f"sample rate in run time series {run_ts.sample_rate} and " - f"processing decimation_obj {expected_sample_rate} do not match" - ) - logger.warning(msg) - delta = run_ts.sample_rate - expected_sample_rate - if np.abs(delta) > tol: - msg = f"Delta sample rate {delta} > {tol} tolerance" - msg += "TOL should be a percentage" - raise Exception(msg) - - -def apply_prewhitening(decimation_obj, run_xrds_input): - """ - Applies pre-whitening to time series to avoid spectral leakage when FFT is applied. - - If "first difference", may want to consider clipping first and last sample from - the differentiated time series. - - Parameters - ---------- - decimation_obj : mt_metadata.transfer_functions.processing.aurora.DecimationLevel - Information about how the decimation level is to be processed - run_xrds_input : xarray.core.dataset.Dataset - Time series to be pre-whitened - - Returns - ------- - run_xrds : xarray.core.dataset.Dataset - pre-whitened time series - - """ - if not decimation_obj.prewhitening_type: - return run_xrds_input - - if decimation_obj.prewhitening_type == "first difference": - run_xrds = run_xrds_input.differentiate("time") - - else: - msg = f"{decimation_obj.prewhitening_type} pre-whitening not implemented" - logger.exception(msg) - raise NotImplementedError(msg) - return run_xrds - - -def apply_recoloring(decimation_obj, stft_obj): - """ - Inverts the pre-whitening operation in frequency domain. - - Parameters - ---------- - decimation_obj : mt_metadata.transfer_functions.processing.fourier_coefficients.decimation.Decimation - Information about how the decimation level is to be processed - stft_obj : xarray.core.dataset.Dataset - Time series of Fourier coefficients to be recoloured - - - Returns - ------- - stft_obj : xarray.core.dataset.Dataset - Recolored time series of Fourier coefficients - """ - # No recoloring needed if prewhitening not appiled, or recoloring set to False - if not decimation_obj.prewhitening_type: - return stft_obj - if not decimation_obj.recoloring: - return stft_obj - - if decimation_obj.prewhitening_type == "first difference": - freqs = decimation_obj.fft_frequencies - prewhitening_correction = 1.0j * 2 * np.pi * freqs # jw - - stft_obj /= prewhitening_correction - - # suppress nan and inf to mute later warnings - if prewhitening_correction[0] == 0.0: - cond = stft_obj.frequency != 0.0 - stft_obj = stft_obj.where(cond, complex(0.0)) - # elif decimation_obj.prewhitening_type == "ARMA": - # from statsmodels.tsa.arima.model import ARIMA - # AR = 3 # add this to processing config - # MA = 4 # add this to processing config - - else: - msg = f"{decimation_obj.prewhitening_type} recoloring not yet implemented" - logger.error(msg) - raise NotImplementedError(msg) - - return stft_obj - - -def run_ts_to_stft_scipy(decimation_obj, run_xrds_orig): - """ - Converts a runts object into a time series of Fourier coefficients. - This method uses scipy.signal.spectrogram. 
- - Parameters - ---------- - decimation_obj : mt_metadata.transfer_functions.processing.aurora.DecimationLevel - Information about how the decimation level is to be processed - run_xrds_orig : : xarray.core.dataset.Dataset - Time series to be processed - - Returns - ------- - stft_obj : xarray.core.dataset.Dataset - Time series of Fourier coefficients - """ - run_xrds = apply_prewhitening(decimation_obj, run_xrds_orig) - windowing_scheme = window_scheme_from_decimation(decimation_obj) - - stft_obj = xr.Dataset() - for channel_id in run_xrds.data_vars: - ff, tt, specgm = ssig.spectrogram( - run_xrds[channel_id].data, - fs=decimation_obj.sample_rate_decimation, - window=windowing_scheme.taper, - nperseg=decimation_obj.window.num_samples, - noverlap=decimation_obj.window.overlap, - detrend="linear", - scaling="density", - mode="complex", - ) - - # drop Nyquist> - ff = ff[:-1] - specgm = specgm[:-1, :] - specgm *= np.sqrt(2) - - # make time_axis - tt = tt - tt[0] - tt *= decimation_obj.sample_rate_decimation - time_axis = run_xrds.time.data[tt.astype(int)] - - xrd = xr.DataArray( - specgm.T, - dims=["time", "frequency"], - coords={"frequency": ff, "time": time_axis}, - ) - stft_obj.update({channel_id: xrd}) - - stft_obj = apply_recoloring(decimation_obj, stft_obj) - - return stft_obj - - -def truncate_to_clock_zero(decimation_obj, run_xrds): +from mt_metadata.transfer_functions.processing import TimeSeriesDecimation +from mt_metadata.transfer_functions.processing.aurora.decimation_level import ( + DecimationLevel as AuroraDecimationLevel, +) +from mt_metadata.transfer_functions.processing.fourier_coefficients import ( + Decimation as FCDecimation, +) +from mth5.groups import RunGroup +from mth5.timeseries.spectre.prewhitening import apply_prewhitening +from mth5.timeseries.spectre.prewhitening import apply_recoloring +import mth5.timeseries.spectre as spectre +from typing import Literal, Optional, Union + + +def truncate_to_clock_zero( + decimation_obj: Union[AuroraDecimationLevel, FCDecimation], + run_xrds: RunGroup, +): """ Compute the time interval between the first data sample and the clock zero Identify the first sample in the xarray time series that corresponds to a @@ -181,7 +36,7 @@ def truncate_to_clock_zero(decimation_obj, run_xrds): Parameters ---------- - decimation_obj: mt_metadata.transfer_functions.processing.aurora.DecimationLevel + decimation_obj: Union[AuroraDecimationLevel, FCDecimation] Information about how the decimation level is to be processed run_xrds : xarray.core.dataset.Dataset normally extracted from mth5.RunTS @@ -192,10 +47,10 @@ def truncate_to_clock_zero(decimation_obj, run_xrds): run_xrds : xarray.core.dataset.Dataset same as the input time series, but possibly slightly shortened """ - if decimation_obj.window.clock_zero_type == "ignore": + if decimation_obj.stft.window.clock_zero_type == "ignore": pass else: - clock_zero = pd.Timestamp(decimation_obj.window.clock_zero) + clock_zero = pd.Timestamp(decimation_obj.stft.window.clock_zero) clock_zero = clock_zero.to_datetime64() delta_t = clock_zero - run_xrds.time[0] assert delta_t.dtype == "<m8[ns]" # expected in nanoseconds @@ -212,7 +67,7 @@ def truncate_to_clock_zero(decimation_obj, run_xrds): cond1 = run_xrds.time >= t_clip msg = ( f"dropping {n_clip} samples to agree with " - f"{decimation_obj.window.clock_zero_type} clock zero {clock_zero}" + f"{decimation_obj.stft.window.clock_zero_type} clock zero {clock_zero}" ) logger.info(msg) run_xrds = run_xrds.where(cond1, drop=True) @@ -242,18 +97,21 @@ def 
nan_to_mean(xrds: xr.Dataset) -> xr.Dataset: return xrds -def run_ts_to_stft(decimation_obj, run_xrds_orig): +def run_ts_to_stft( + decimation_obj: AuroraDecimationLevel, run_xrds_orig: xr.Dataset +) -> spectre.Spectrogram: """ Converts a runts object into a time series of Fourier coefficients. Similar to run_ts_to_stft_scipy, but in this implementation operations on individual windows are possible (for example pre-whitening per time window via ARMA filtering). + TODO: Make the output of this function a Spectrogram object Parameters ---------- - decimation_obj : mt_metadata.transfer_functions.processing.aurora.DecimationLevel + decimation_obj : AuroraDecimationLevel Information about how the decimation level is to be processed - run_ts : xarray.core.dataset.Dataset + run_xrds_orig: xarray.core.dataset.Dataset normally extracted from mth5.RunTS Returns @@ -266,11 +124,11 @@ def run_ts_to_stft(decimation_obj, run_xrds_orig): # need to remove any nans before windowing, or else if there is a single # nan then the whole channel becomes nan. run_xrds = nan_to_mean(run_xrds_orig) - run_xrds = apply_prewhitening(decimation_obj, run_xrds) + run_xrds = apply_prewhitening(decimation_obj.stft.prewhitening_type, run_xrds) run_xrds = truncate_to_clock_zero(decimation_obj, run_xrds) windowing_scheme = window_scheme_from_decimation(decimation_obj) windowed_obj = windowing_scheme.apply_sliding_window( - run_xrds, dt=1.0 / decimation_obj.sample_rate_decimation + run_xrds, dt=1.0 / decimation_obj.decimation.sample_rate ) if not np.prod(windowed_obj.to_array().data.shape): raise ValueError @@ -283,15 +141,22 @@ def run_ts_to_stft(decimation_obj, run_xrds_orig): data=tapered_obj, sample_rate=windowing_scheme.sample_rate, spectral_density_correction=windowing_scheme.linear_spectral_density_calibration_factor, - detrend_type=decimation_obj.extra_pre_fft_detrend_type, + detrend_type=decimation_obj.stft.per_window_detrend_type, ) - stft_obj = apply_recoloring(decimation_obj, stft_obj) + if decimation_obj.stft.recoloring: + stft_obj = apply_recoloring(decimation_obj.stft.prewhitening_type, stft_obj) - return stft_obj + spectrogram = spectre.Spectrogram(dataset=stft_obj) + return spectrogram -def calibrate_stft_obj(stft_obj, run_obj, units="MT", channel_scale_factors=None): +def calibrate_stft_obj( + stft_obj: xr.Dataset, + run_obj: RunGroup, + units: Literal["MT", "SI"] = "MT", + channel_scale_factors: Optional[dict] = None, +) -> xr.Dataset: """ Calibrates frequency domain data into MT units. 
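
As the docstring says, the calibration applied by `calibrate_stft_obj` is at its core an element-wise division of each channel's Fourier coefficients by the instrument's complex frequency response evaluated at the STFT frequencies (see `calibration_response` in the hunk below). A toy sketch of that operation, with an invented `toy_response` standing in for `channel_response.complex_response(stft_obj.frequency.data)`:

```python
import numpy as np
import xarray as xr

def calibrate_channel(stft: xr.DataArray, complex_response: xr.DataArray) -> xr.DataArray:
    """Divide a (time, frequency) spectrogram by a per-frequency complex response."""
    return stft / complex_response

freqs = np.linspace(0.01, 0.5, 8)
stft = xr.DataArray(
    np.random.randn(4, 8) + 1j * np.random.randn(4, 8),
    dims=["time", "frequency"],
    coords={"frequency": freqs},
)
# Invented response for illustration; a real one comes from the channel's filter stages.
toy_response = xr.DataArray(1.0 + 0.25j * freqs, dims=["frequency"], coords={"frequency": freqs})
calibrated = calibrate_channel(stft, toy_response)
```

Note the DC-term caveat flagged in the hunk below: a response that is zero at 0 Hz makes this division emit a runtime warning.
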
@@ -326,7 +191,7 @@ def calibrate_stft_obj(stft_obj, run_obj, units="MT", channel_scale_factors=None logger.warning(msg) if channel_id == "hy": msg = "Channel hy has no filters, try using filters from hx" - logger.warning("Channel HY has no filters, try using filters from HX") + logger.warning(msg) channel_response = run_obj.get_channel("hx").channel_response indices_to_flip = channel_response.get_indices_of_filters_to_remove( @@ -338,29 +203,33 @@ def calibrate_stft_obj(stft_obj, run_obj, units="MT", channel_scale_factors=None filters_to_remove = [channel_response.filters_list[i] for i in indices_to_flip] if not filters_to_remove: logger.warning("No filters to remove") + calibration_response = channel_response.complex_response( stft_obj.frequency.data, filters_list=filters_to_remove ) + if channel_scale_factors: try: channel_scale_factor = channel_scale_factors[channel_id] except KeyError: channel_scale_factor = 1.0 calibration_response /= channel_scale_factor + if units == "SI": logger.warning("Warning: SI Units are not robustly supported issue #36") - # TODO: This often raises a runtime warning due to DC term in calibration response=0 + + # TODO: FIXME Sometimes raises a runtime warning due to DC term in calibration response = 0 stft_obj[channel_id].data /= calibration_response return stft_obj def prototype_decimate( - config: mt_metadata.transfer_functions.processing.aurora.decimation.Decimation, + ts_decimation: TimeSeriesDecimation, run_xrds: xr.Dataset, ) -> xr.Dataset: """ Basically a wrapper for scipy.signal.decimate. Takes input timeseries (as xarray - Dataset) and a Decimation config object and returns a decimated version of the + Dataset) and a TimeSeriesDecimation object and returns a decimated version of the input time series. TODO: Consider moving this function into time_series/decimate.py @@ -370,7 +239,7 @@ def prototype_decimate( Parameters ---------- - config : mt_metadata.transfer_functions.processing.aurora.Decimation + ts_decimation : AuroraDecimationLevel run_xrds: xr.Dataset Originally from mth5.timeseries.run_ts.RunTS.dataset, but possibly decimated multiple times @@ -381,7 +250,7 @@ def prototype_decimate( Decimated version of the input run_xrds """ # downsample the time axis - slicer = slice(None, None, int(config.factor)) # decimation.factor + slicer = slice(None, None, int(ts_decimation.factor)) # decimation.factor downsampled_time_axis = run_xrds.time.data[slicer] # decimate the time series @@ -390,7 +259,8 @@ def prototype_decimate( num_channels = len(channel_labels) new_data = np.full((num_observations, num_channels), np.nan) for i_ch, ch_label in enumerate(channel_labels): - new_data[:, i_ch] = ssig.decimate(run_xrds[ch_label], int(config.factor)) + # TODO: add check here for ts_decimation.anti_alias_filter + new_data[:, i_ch] = ssig.decimate(run_xrds[ch_label], int(ts_decimation.factor)) xr_da = xr.DataArray( new_data, @@ -398,7 +268,7 @@ def prototype_decimate( coords={"time": downsampled_time_axis, "channel": channel_labels}, ) attr_dict = run_xrds.attrs - attr_dict["sample_rate"] = config.sample_rate + attr_dict["sample_rate"] = ts_decimation.sample_rate xr_da.attrs = attr_dict xr_ds = xr_da.to_dataset("channel") return xr_ds @@ -412,7 +282,7 @@ def prototype_decimate( # Method is fast. Might be non-linear. Seems to give similar performance to # prototype_decimate for synthetic data. # -# N.B. config.factor must be integer valued +# N.B. 
config.decimation.factor must be integer valued # # Parameters # ---------- @@ -426,7 +296,7 @@ def prototype_decimate( # xr_ds: xr.Dataset # Decimated version of the input run_xrds # """ -# new_xr_ds = run_xrds.coarsen(time=int(config.factor), boundary="trim").mean() +# new_xr_ds = run_xrds.coarsen(time=int(config.decimation.factor), boundary="trim").mean() # attr_dict = run_xrds.attrs # attr_dict["sample_rate"] = config.sample_rate # new_xr_ds.attrs = attr_dict @@ -451,7 +321,7 @@ def prototype_decimate( # Decimated version of the input run_xrds # """ # dt = run_xrds.time.diff(dim="time").median().values -# dt_new = config.factor * dt +# dt_new = config.decimation.factor * dt # dt_new = dt_new.__str__().replace("nanoseconds", "ns") # new_xr_ds = run_xrds.resample(time=dt_new).mean(dim="time") # attr_dict = run_xrds.attrs diff --git a/aurora/pipelines/transfer_function_helpers.py b/aurora/pipelines/transfer_function_helpers.py index 489c8965..41588a97 100644 --- a/aurora/pipelines/transfer_function_helpers.py +++ b/aurora/pipelines/transfer_function_helpers.py @@ -2,9 +2,9 @@ This module contains helper methods that are used during transfer function processing. Development Notes: -Note #1: repeatedly applying edf_weights seems to have no effect at all. -tested 20240118 and found that test_compare in synthetic passed whether this was commented -or not. TODO confirm this is a one-and-done add doc about why this is so. + Note #1: repeatedly applying edf_weights seems to have no effect at all. + tested 20240118 and found that test_compare in synthetic passed whether this was commented + or not. TODO confirm this is a one-and-done add doc about why this is so. """ @@ -19,20 +19,25 @@ from aurora.transfer_function.weights.edf_weights import ( effective_degrees_of_freedom_weights, ) +from mt_metadata.transfer_functions.processing.aurora.decimation_level import ( + DecimationLevel as AuroraDecimationLevel, +) from loguru import logger -from typing import Union +from typing import Literal, Union import numpy as np import xarray as xr ESTIMATOR_LIBRARY = {"OLS": RegressionEstimator, "RME": RME, "RME_RR": RME_RR} -def get_estimator_class(estimation_engine: str) -> RegressionEstimator: +def get_estimator_class( + estimation_engine: Literal["OLS", "RME", "RME_RR"] +) -> RegressionEstimator: """ Parameters ---------- - estimation_engine: str + estimation_engine: Literal["OLS", "RME", "RME_RR"] One of the keys in the ESTIMATOR_LIBRARY, designates the method that will be used to estimate the transfer function @@ -52,7 +57,7 @@ def get_estimator_class(estimation_engine: str) -> RegressionEstimator: return estimator_class -def set_up_iter_control(config): +def set_up_iter_control(config: AuroraDecimationLevel): """ Initializes an IterControl object based on values in the processing config. @@ -62,7 +67,8 @@ def set_up_iter_control(config): Parameters ---------- - config: mt_metadata.transfer_functions.processing.aurora.decimation_level.DecimationLevel + config: AuroraDecimationLevel + metadata about the decimation level processing. Returns ------- @@ -81,32 +87,8 @@ def set_up_iter_control(config): "OLS", ]: iter_control = None - return iter_control - -# def select_channel(xrda: xr.DataArray, channel_label): -# """ -# Returns the channel specified by input channel_label as xarray. -# -# - Extra helper function to make process_transfer_functions more readable without -# black (the uncompromising formatter) forcing multiple lines. 
-# -# Parameters -# ---------- -# xrda: -# channel_label -# -# Returns -# ------- -# ch: xr.Dataset -# The channel specified by input channel_label as an xarray. -# """ -# ch = xrda.sel( -# channel=[ -# channel_label, -# ] -# ) -# return ch + return iter_control def drop_nans(X: xr.Dataset, Y: xr.Dataset, RR: Union[xr.Dataset, None]) -> tuple: @@ -173,7 +155,14 @@ def stack_fcs(X, Y, RR): return X, Y, RR -def apply_weights(X, Y, RR, W, segment=False, dropna=False): +def apply_weights( + X: xr.Dataset, + Y: xr.Dataset, + RR: xr.Dataset, + W, + segment: bool = False, + dropna: bool = False, +) -> tuple: """ Applies data weights (W) to each of X, Y, RR. If weight is zero, we set to nan and optionally dropna. @@ -210,7 +199,7 @@ def apply_weights(X, Y, RR, W, segment=False, dropna=False): def process_transfer_functions( - dec_level_config, + dec_level_config: AuroraDecimationLevel, local_stft_obj, remote_stft_obj, transfer_function_obj, @@ -219,10 +208,10 @@ def process_transfer_functions( channel_weights=None, ): """ - This is the main tf_processing method. It is based on TTFestBand.m + This is the main tf_processing method. It is based on the Matlab legacy code TTFestBand.m. Note #1: Although it is advantageous to execute the regression channel-by-channel - vs. all-at-once, we need to keep the all-at-once to get residual covariances (see issue #87) + vs. all-at-once, we need to keep the all-at-once to get residual covariances (see aurora issue #87) Note #2: Consider placing the segment weight logic in its own module with the various functions in a dictionary. @@ -241,7 +230,8 @@ def process_transfer_functions( Parameters ---------- - dec_level_config: mt_metadata.transfer_functions.processing.aurora.decimation_level.DecimationLevel + dec_level_config: AuroraDecimationLevel + Metadata about the decimation level processing. local_stft_obj: xarray.core.dataset.Dataset remote_stft_obj: xarray.core.dataset.Dataset or None transfer_function_obj: aurora.transfer_function.TTFZ.TTFZ @@ -259,7 +249,9 @@ def process_transfer_functions( ------- transfer_function_obj: aurora.transfer_function.TTFZ.TTFZ """ - estimator_class = get_estimator_class(dec_level_config.estimator.engine) + estimator_class: RegressionEstimator = get_estimator_class( + dec_level_config.estimator.engine + ) iter_control = set_up_iter_control(dec_level_config) for band in transfer_function_obj.frequency_bands.bands(): @@ -328,3 +320,28 @@ def process_transfer_functions( transfer_function_obj.set_tf(regression_estimator, band.center_period) return transfer_function_obj + + +# def select_channel(xrda: xr.DataArray, channel_label): +# """ +# Returns the channel specified by input channel_label as xarray. +# +# - Extra helper function to make process_transfer_functions more readable without +# black (the uncompromising formatter) forcing multiple lines. +# +# Parameters +# ---------- +# xrda: +# channel_label +# +# Returns +# ------- +# ch: xr.Dataset +# The channel specified by input channel_label as an xarray. 
+# """ +# ch = xrda.sel( +# channel=[ +# channel_label, +# ] +# ) +# return ch diff --git a/aurora/pipelines/transfer_function_kernel.py b/aurora/pipelines/transfer_function_kernel.py index 77d070d5..d3b1dccb 100644 --- a/aurora/pipelines/transfer_function_kernel.py +++ b/aurora/pipelines/transfer_function_kernel.py @@ -4,20 +4,32 @@ """ +from aurora.config.metadata.processing import Processing from aurora.pipelines.helpers import initialize_config from aurora.pipelines.time_series_helpers import prototype_decimate + +# from aurora.transfer_function.transfer_function_collection import TransferFunctionCollection from loguru import logger from mth5.utils.exceptions import MTH5Error from mth5.utils.helpers import path_or_mth5_object from mt_metadata.transfer_functions.core import TF +from mt_metadata.transfer_functions.processing.aurora.decimation_level import ( + DecimationLevel as AuroraDecimationLevel, +) + +# from mtpy.processing.kernel_dataset import KernelDataset # TODO FIXME: causes circular import. +from typing import Union import numpy as np import pandas as pd +import pathlib import psutil class TransferFunctionKernel(object): - def __init__(self, dataset=None, config=None): + def __init__( + self, dataset, config: Union[Processing, str, pathlib.Path] # : KernelDataset, + ): """ Constructor @@ -33,12 +45,12 @@ def __init__(self, dataset=None, config=None): self._memory_warning = False @property - def dataset(self): + def dataset(self): # -> KernelDataset: """returns the KernelDataset object""" return self._dataset @property - def kernel_dataset(self): + def kernel_dataset(self): # -> KernelDataset: """returns the KernelDataset object""" return self._dataset @@ -48,12 +60,12 @@ def dataset_df(self) -> pd.DataFrame: return self._dataset.df @property - def processing_config(self): + def processing_config(self) -> Processing: """Returns the processing config object""" return self._config @property - def config(self): + def config(self) -> Processing: """Returns the processing config object""" return self._config @@ -103,8 +115,6 @@ def update_dataset_df(self, i_dec_level): ---------- i_dec_level: int decimation level id, indexed from zero - config: mt_metadata.transfer_functions.processing.aurora.decimation_level.DecimationLevel - decimation level config Returns ------- @@ -145,21 +155,21 @@ def update_dataset_df(self, i_dec_level): ) return - def apply_clock_zero(self, dec_level_config): + def apply_clock_zero(self, dec_level_config: AuroraDecimationLevel): """ get clock-zero from data if needed Parameters ---------- - dec_level_config: mt_metadata.transfer_functions.processing.aurora.decimation_level.DecimationLevel - + dec_level_config: AuroraDecimationLevel + metadata about the decimation level processing. Returns ------- - dec_level_config: mt_metadata.transfer_functions.processing.aurora.decimation_level.DecimationLevel + dec_level_config: AuroraDecimationLevel The modified DecimationLevel with clock-zero information set. 
""" - if dec_level_config.window.clock_zero_type == "data start": - dec_level_config.window.clock_zero = str(self.dataset_df.start.min()) + if dec_level_config.stft.window.clock_zero_type == "data start": + dec_level_config.stft.window.clock_zero = str(self.dataset_df.start.min()) return dec_level_config @property @@ -378,7 +388,7 @@ def validate_decimation_scheme_and_dataset_compatability( """ if min_num_stft_windows is None: min_stft_window_info = { - x.decimation.level: x.min_num_stft_windows + x.decimation.level: x.stft.min_num_stft_windows for x in self.processing_config.decimations } min_stft_window_list = [ @@ -546,10 +556,13 @@ def export_tf_collection(self, tf_collection): Transfer function container """ - def make_decimation_dict_for_tf(tf_collection, processing_config): + def make_decimation_dict_for_tf( + tf_collection, # : TransferFunctionCollection, + processing_config: Processing, + ) -> dict: """ - Decimation dict is used by mt_metadata's TF class when it is writing z-files. - If no z-files will be written this is not needed + Helper function to create a dictionary used by mt_metadata's TF class when + writing z-files. If no z-files will be written this is not needed sample element of decimation_dict: '1514.70134': {'level': 4, 'bands': (5, 6), 'npts': 386, 'df': 0.015625}} @@ -562,25 +575,30 @@ def make_decimation_dict_for_tf(tf_collection, processing_config): Parameters ---------- - tfc + tf_collection: TransferFunctionCollection + Collection of transfer funtion estimates from aurora. + processing_config: Processing + Instructions for processing with aurora Returns ------- - + decimation_dict: dict + Keyed by a string representing the period + Values are a custom dictionary. """ from mt_metadata.transfer_functions.io.zfiles.zmm import ( PERIOD_FORMAT, ) decimation_dict = {} - + # dec_level_cfg is an AuroraDecimationLevel for i_dec, dec_level_cfg in enumerate(processing_config.decimations): for i_band, band in enumerate(dec_level_cfg.bands): period_key = f"{band.center_period:{PERIOD_FORMAT}}" period_value = {} period_value["level"] = i_dec + 1 # +1 to match EMTF standard period_value["bands"] = tuple(band.harmonic_indices[np.r_[0, -1]]) - period_value["sample_rate"] = dec_level_cfg.sample_rate_decimation + period_value["sample_rate"] = dec_level_cfg.decimation.sample_rate try: period_value["npts"] = tf_collection.tf_dict[ i_dec diff --git a/aurora/sandbox/io_helpers/garys_matlab_zfiles/matlab_z_file_reader.py b/aurora/sandbox/io_helpers/garys_matlab_zfiles/matlab_z_file_reader.py index 4ee80a2c..0e4689cb 100644 --- a/aurora/sandbox/io_helpers/garys_matlab_zfiles/matlab_z_file_reader.py +++ b/aurora/sandbox/io_helpers/garys_matlab_zfiles/matlab_z_file_reader.py @@ -99,9 +99,7 @@ def test_matlab_zfile_reader(case_id="IAK34ss", make_plot=False): ] reference_channels = [] matlab_z_file = test_dir_path.joinpath("IAK34_struct_zss.mat") - archived_z_file_path = test_dir_path.joinpath( - "archived_from_matlab.zss" - ) + archived_z_file_path = test_dir_path.joinpath("archived_from_matlab.zss") z_file_path = test_dir_path.joinpath("from_matlab.zss") # 2. 
Create an aurora processing config
 
@@ -128,7 +126,7 @@ def test_matlab_zfile_reader(case_id="IAK34ss", make_plot=False):
     sample_rate = field_data_sample_rate
     for i_dec in range(4):
         p.decimations[i_dec].decimation.sample_rate = sample_rate
-        p.decimations[i_dec].window.num_samples = num_samples_window
+        p.decimations[i_dec].stft.window.num_samples = num_samples_window
         p.decimations[i_dec].estimator.engine = estimator_engine
         p.decimations[i_dec].input_channels = input_channels
         p.decimations[i_dec].output_channels = output_channels
diff --git a/aurora/test_utils/parkfield/make_parkfield_mth5.py b/aurora/test_utils/parkfield/make_parkfield_mth5.py
index 843e1d98..4d4e95cd 100644
--- a/aurora/test_utils/parkfield/make_parkfield_mth5.py
+++ b/aurora/test_utils/parkfield/make_parkfield_mth5.py
@@ -6,21 +6,17 @@
 import pathlib
 
 from aurora.test_utils.dataset_definitions import TEST_DATA_SET_CONFIGS
-from mth5.utils.helpers import read_back_data
-from mth5.helpers import close_open_files
-from aurora.sandbox.io_helpers.fdsn_dataset import FDSNDataset
 from aurora.sandbox.io_helpers.make_mth5_helpers import create_from_server_multistation
 from aurora.test_utils.parkfield.path_helpers import PARKFIELD_PATHS
 from loguru import logger
-from typing import Union
-
+from mth5.utils.helpers import read_back_data
+from mth5.helpers import close_open_files
+from typing import Optional, Union
 
 DATA_SOURCES = ["NCEDC", "https://service.ncedc.org/"]
 DATASET_ID = "pkd_sao_test_00"
 FDSN_DATASET = TEST_DATA_SET_CONFIGS[DATASET_ID]
 
-#
-
 
 def select_data_source() -> None:
     """
@@ -44,16 +40,17 @@ def select_data_source() -> None:
         except:
             logger.warning(f"Data source {data_source} not initializing")
     if not ok:
-        logger.error("No data sources for Parkfield / Hollister initializing")
-        logger.error("NCEDC probably down")
-        raise ValueError
+        msg = "No data sources for Parkfield / Hollister initializing\n"
+        msg += "NCEDC probably down"
+        logger.error(msg)
+        raise ValueError(msg)
     else:
         return data_source
 
 
 def make_pkdsao_mth5(
     fdsn_dataset: FDSN_DATASET,
-    target_folder: Union[str, pathlib.Path, None] = PARKFIELD_PATHS["data"],
+    target_folder: Optional[Union[str, pathlib.Path]] = PARKFIELD_PATHS["data"],
 ) -> pathlib.Path:
     """
     Makes MTH5 file with data from Parkfield and Hollister stations to use for testing.
@@ -83,12 +80,13 @@ def ensure_h5_exists(
     Parameters
     ----------
     h5_path: Union[pathlib.Path, None]
-
+        The target path at which to build the mth5.
     Returns
     -------
     h5_path: pathlib.Path
         The path to the PKD SAO mth5 file to be used for testing.
+ """ h5_path = target_folder.joinpath(FDSN_DATASET.h5_filebase) if h5_path.exists(): diff --git a/aurora/test_utils/synthetic/make_processing_configs.py b/aurora/test_utils/synthetic/make_processing_configs.py index a2714663..951a0800 100644 --- a/aurora/test_utils/synthetic/make_processing_configs.py +++ b/aurora/test_utils/synthetic/make_processing_configs.py @@ -130,12 +130,12 @@ def create_test_run_config( for decimation in p.decimations: decimation.estimator.engine = estimation_engine - decimation.window.type = "hamming" - decimation.window.num_samples = num_samples_window - decimation.window.overlap = num_samples_overlap + decimation.stft.window.type = "hamming" + decimation.stft.window.num_samples = num_samples_window + decimation.stft.window.overlap = num_samples_overlap decimation.regression.max_redescending_iterations = 2 if test_case_id == "test2": - decimation.window.type = "boxcar" + decimation.stft.window.type = "boxcar" if save == "json": filename = CONFIG_PATH.joinpath(p.json_fn()) diff --git a/aurora/time_series/frequency_band_helpers.py b/aurora/time_series/frequency_band_helpers.py index 10dc7588..6c0f085b 100644 --- a/aurora/time_series/frequency_band_helpers.py +++ b/aurora/time_series/frequency_band_helpers.py @@ -2,20 +2,31 @@ This module contains functions that are associated with time series of Fourier coefficients """ -# import numpy as np from loguru import logger - - -def get_band_for_tf_estimate(band, dec_level_config, local_stft_obj, remote_stft_obj): +from mt_metadata.transfer_functions.processing.aurora import ( + DecimationLevel as AuroraDecimationLevel, +) +from mt_metadata.transfer_functions.processing.aurora import Band +from mth5.timeseries.spectre.spectrogram import extract_band +from typing import Optional, Tuple +import xarray as xr + + +def get_band_for_tf_estimate( + band: Band, + dec_level_config: AuroraDecimationLevel, + local_stft_obj: xr.Dataset, + remote_stft_obj: Optional[xr.Dataset], +) -> Tuple[xr.Dataset, xr.Dataset, Optional[xr.Dataset]]: """ Returns spectrograms X, Y, RR for harmonics within the given band Parameters ---------- - band : mt_metadata.transfer_functions.processing.aurora.FrequencyBands + band : mt_metadata.transfer_functions.processing.aurora.Band object with lower_bound and upper_bound to tell stft object which subarray to return - config : mt_metadata.transfer_functions.processing.aurora.decimation_level.DecimationLevel + config : AuroraDecimationLevel information about the input and output channels needed for TF estimation problem setup local_stft_obj : xarray.core.dataset.Dataset or None @@ -49,49 +60,6 @@ def get_band_for_tf_estimate(band, dec_level_config, local_stft_obj, remote_stft return X, Y, RR -def extract_band(frequency_band, fft_obj, channels=[], epsilon=1e-7): - """ - Extracts a frequency band from xr.DataArray representing a spectrogram. - - Stand alone version of the method that is used by WIP Spectrogram class. - - Development Notes: - #1: 20230902 - TODO: Decide if base dataset object should be a xr.DataArray (not xr.Dataset) - - drop=True does not play nice with h5py and Dataset, results in a type error. 
- File "stringsource", line 2, in h5py.h5r.Reference.__reduce_cython__ - TypeError: no default __reduce__ due to non-trivial __cinit__ - However, it works OK with DataArray, so maybe use data array in general - - Parameters - ---------- - frequency_band: mt_metadata.transfer_functions.processing.aurora.band.Band - Specifies interval corresponding to a frequency band - fft_obj: xarray.core.dataset.Dataset - To be replaced with an fft_obj() class in future - epsilon: float - Use this when you are worried about missing a frequency due to - round off error. This is in general not needed if we use a df/2 pad - around true harmonics. - - Returns - ------- - band: xr.DataArray - The frequencies within the band passed into this function - """ - cond1 = fft_obj.frequency >= frequency_band.lower_bound - epsilon - cond2 = fft_obj.frequency <= frequency_band.upper_bound + epsilon - try: - band = fft_obj.where(cond1 & cond2, drop=True) - except TypeError: # see Note #1 - tmp = fft_obj.to_array() - band = tmp.where(cond1 & cond2, drop=True) - band = band.to_dataset("variable") - if channels: - band = band[channels] - return band - - def check_time_axes_synched(X, Y): """ Utility function for checking that time axes agree. @@ -202,7 +170,7 @@ def adjust_band_for_coherence_sorting(frequency_band, spectrogram, rule="min3"): # def get_band_for_coherence_sorting( # frequency_band, -# dec_level_config, +# dec_level_config: AuroraDecimationLevel, # local_stft_obj, # remote_stft_obj, # widening_rule="min3", @@ -217,7 +185,7 @@ def adjust_band_for_coherence_sorting(frequency_band, spectrogram, rule="min3"): # band : mt_metadata.transfer_functions.processing.aurora.FrequencyBands # object with lower_bound and upper_bound to tell stft object which # subarray to return -# config : mt_metadata.transfer_functions.processing.aurora.decimation_level.DecimationLevel +# config : AuroraDecimationLevel # information about the input and output channels needed for TF # estimation problem setup # local_stft_obj : xarray.core.dataset.Dataset or None diff --git a/aurora/time_series/spectrogram.py b/aurora/time_series/spectrogram.py deleted file mode 100644 index ad3684d3..00000000 --- a/aurora/time_series/spectrogram.py +++ /dev/null @@ -1,160 +0,0 @@ -""" - WORK IN PROGRESS (WIP): This module contains a class that represents a spectrogram, - i.e. A 2D time series of Fourier coefficients with axes time and frequency. - -""" -from aurora.time_series.frequency_band_helpers import extract_band -from typing import Optional -import xarray - - -class Spectrogram(object): - """ - Class to contain methods for STFT objects. - TODO: Add support for cross powers - TODO: Add OLS Z-estimates - TODO: Add Sims/Vozoff Z-estimates - - """ - - def __init__(self, dataset=None): - """Constructor""" - self._dataset = dataset - self._frequency_increment = None - - def _lowest_frequency(self): - pass - - def _higest_frequency(self): - pass - - def __str__(self) -> str: - """Returns a Description of frequency coverage""" - intro = "Spectrogram:" - frequency_coverage = ( - f"{self.dataset.dims['frequency']} harmonics, {self.frequency_increment}Hz spaced \n" - f" from {self.dataset.frequency.data[0]} to {self.dataset.frequency.data[-1]} Hz." 
- ) - time_coverage = f"\n{self.dataset.dims['time']} Time observations" - time_coverage = f"{time_coverage} \nStart: {self.dataset.time.data[0]}" - time_coverage = f"{time_coverage} \nEnd: {self.dataset.time.data[-1]}" - - channel_coverage = list(self.dataset.data_vars.keys()) - channel_coverage = "\n".join(channel_coverage) - channel_coverage = f"\nChannels present: \n{channel_coverage}" - return ( - intro - + "\n" - + frequency_coverage - + "\n" - + time_coverage - + "\n" - + channel_coverage - ) - - def __repr__(self): - return self.__str__() - - @property - def dataset(self): - """returns the underlying xarray data""" - return self._dataset - - @property - def time_axis(self): - """returns the time axis of the underlying xarray""" - return self.dataset.time - - @property - def frequency_increment(self): - """ - returns the "delta f" of the frequency axis - - assumes uniformly sampled in frequency domain - """ - if self._frequency_increment is None: - frequency_axis = self.dataset.frequency - self._frequency_increment = frequency_axis.data[1] - frequency_axis.data[0] - return self._frequency_increment - - def num_harmonics_in_band(self, frequency_band, epsilon=1e-7): - """ - - Returns the number of harmonics within the frequency band in the underlying dataset - - Parameters - ---------- - band - stft_obj - - Returns - ------- - - """ - cond1 = self._dataset.frequency >= frequency_band.lower_bound - epsilon - cond2 = self._dataset.frequency <= frequency_band.upper_bound + epsilon - num_harmonics = (cond1 & cond2).data.sum() - return num_harmonics - - def extract_band(self, frequency_band, channels=[]): - """ - Returns another instance of Spectrogram, with the frequency axis reduced to the input band. - - TODO: Consider returning a copy of the data... - - Parameters - ---------- - frequency_band - channels - - Returns - ------- - spectrogram: aurora.time_series.spectrogram.Spectrogram - Returns a Spectrogram object with only the extracted band for a dataset - - """ - extracted_band_dataset = extract_band( - frequency_band, - self.dataset, - channels=channels, - epsilon=self.frequency_increment / 2.0, - ) - spectrogram = Spectrogram(dataset=extracted_band_dataset) - return spectrogram - - # TODO: Add cross power method - # def cross_powers(self, ch1, ch2, band=None): - # pass - - def flatten(self, chunk_by: Optional[str] = "time") -> xarray.Dataset: - """ - - Returns the flattened xarray (time-chunked by default). - - Parameters - ---------- - chunk_by: str - Controlled vocabulary ["time", "frequency"]. Reshaping the 2D spectrogram can be done two ways - (basically "row-major", or column-major). In xarray, but we either keep frequency constant and - iterate over time, or keep time constant and iterate over frequency (in the inner loop). - - - Returns - ------- - xarray.Dataset : The dataset from the band spectrogram, stacked. - - Development Notes: - The flattening used in tf calculation by default is opposite to here - dataset.stack(observation=("frequency", "time")) - However, for feature extraction, it may make sense to swap the order: - xrds = band_spectrogram.dataset.stack(observation=("time", "frequency")) - This is like chunking into time windows and allows individual features to be computed on each time window -- if desired. - Still need to split the time series though--Splitting to time would be a reshape by (last_freq_index-first_freq_index). - Using pure xarray this may not matter but if we drop down into numpy it could be useful. 
- - - """ - if chunk_by == "time": - observation = ("time", "frequency") - elif chunk_by == "frequency": - observation = ("frequency", "time") - return self.dataset.stack(observation=observation) diff --git a/aurora/time_series/windowed_time_series.py b/aurora/time_series/windowed_time_series.py index 383fdf94..72f7b82b 100644 --- a/aurora/time_series/windowed_time_series.py +++ b/aurora/time_series/windowed_time_series.py @@ -11,9 +11,8 @@ """ from aurora.time_series.decorators import can_use_xr_dataarray -from mt_metadata.transfer_functions.processing.aurora.decimation_level import ( - get_fft_harmonics, -) +from mt_metadata.transfer_functions.processing.window import get_fft_harmonics + from typing import Optional, Union from loguru import logger import numpy as np diff --git a/aurora/time_series/windowing_scheme.py b/aurora/time_series/windowing_scheme.py index 2923cd31..07764fa7 100644 --- a/aurora/time_series/windowing_scheme.py +++ b/aurora/time_series/windowing_scheme.py @@ -74,10 +74,11 @@ from aurora.time_series.windowed_time_series import WindowedTimeSeries from aurora.time_series.window_helpers import available_number_of_windows_in_array from aurora.time_series.window_helpers import SLIDING_WINDOW_FUNCTIONS - from mt_metadata.transfer_functions.processing.aurora.decimation_level import ( - get_fft_harmonics, + DecimationLevel as AuroraDecimationLevel, ) +from mt_metadata.transfer_functions.processing.window import get_fft_harmonics + from loguru import logger from typing import Optional, Union @@ -448,14 +449,14 @@ def linear_spectral_density_calibration_factor(self) -> float: return np.sqrt(2 / (self.sample_rate * self.S2)) -def window_scheme_from_decimation(decimation): +def window_scheme_from_decimation(decimation: AuroraDecimationLevel): """ Helper function to workaround mt_metadata to not import form aurora Parameters ---------- - decimation: mt_metadata.transfer_function.processing.aurora.decimation_level - .DecimationLevel + decimation: AuroraDecimationLevel + Decimation level metadata object Returns ------- @@ -464,10 +465,10 @@ def window_scheme_from_decimation(decimation): from aurora.time_series.windowing_scheme import WindowingScheme windowing_scheme = WindowingScheme( - taper_family=decimation.window.type, - num_samples_window=decimation.window.num_samples, - num_samples_overlap=decimation.window.overlap, - taper_additional_args=decimation.window.additional_args, - sample_rate=decimation.sample_rate_decimation, + taper_family=decimation.stft.window.type, + num_samples_window=decimation.stft.window.num_samples, + num_samples_overlap=decimation.stft.window.overlap, + taper_additional_args=decimation.stft.window.additional_args, + sample_rate=decimation.decimation.sample_rate, ) return windowing_scheme diff --git a/aurora/time_series/xarray_helpers.py b/aurora/time_series/xarray_helpers.py index db30493f..f497ab4c 100644 --- a/aurora/time_series/xarray_helpers.py +++ b/aurora/time_series/xarray_helpers.py @@ -2,13 +2,17 @@ Placeholder module for methods manipulating xarray time series """ -import numpy as np import xarray as xr from loguru import logger -from typing import Optional, Union +from typing import Optional -def handle_nan(X, Y, RR, drop_dim=""): +def handle_nan( + X: xr.Dataset, + Y: Optional[xr.Dataset], + RR: Optional[xr.Dataset], + drop_dim: Optional[str] = "", +) -> tuple: """ Drops Nan from multiple channel series'. 
@@ -87,118 +91,3 @@ def handle_nan(X, Y, RR, drop_dim=""): RR = RR.rename(data_var_rm_label_mapper) return X, Y, RR - - -def covariance_xr( - X: xr.DataArray, aweights: Optional[Union[np.ndarray, None]] = None -) -> xr.DataArray: - """ - Compute the covariance matrix with numpy.cov. - - Parameters - ---------- - X: xarray.core.dataarray.DataArray - Multivariate time series as an xarray - aweights: array_like, optional - Doc taken from numpy cov follows: - 1-D array of observation vector weights. These relative weights are - typically large for observations considered "important" and smaller for - observations considered less "important". If ``ddof=0`` the array of - weights can be used to assign probabilities to observation vectors. - - Returns - ------- - S: xarray.DataArray - The covariance matrix of the data in xarray form. - """ - - channels = list(X.coords["variable"].values) - - S = xr.DataArray( - np.cov(X, aweights=aweights), - dims=["channel_1", "channel_2"], - coords={"channel_1": channels, "channel_2": channels}, - ) - return S - - -def initialize_xrda_1d( - channels: list, - dtype=Optional[type], - value: Optional[Union[complex, float, bool]] = 0, -) -> xr.DataArray: - """ - - Returns a 1D xr.DataArray with variable "channel", having values channels named by the input list. - - Parameters - ---------- - channels: list - The channels in the multivariate array - dtype: type - The datatype to initialize the array. - Common cases are complex, float, and bool - value: Union[complex, float, bool] - The default value to assign the array - - Returns - ------- - xrda: xarray.core.dataarray.DataArray - An xarray container for the channels, initialized to zeros. - """ - k = len(channels) - logger.debug(f"Initializing xarray with values {value}") - xrda = xr.DataArray( - np.zeros(k, dtype=dtype), - dims=[ - "variable", - ], - coords={ - "variable": channels, - }, - ) - if value != 0: - data = value * np.ones(k, dtype=dtype) - xrda.data = data - return xrda - - -def initialize_xrda_2d( - channels, dtype=complex, value: Optional[Union[complex, float, bool]] = 0, dims=None -): - - """ - TODO: consider merging with initialize_xrda_1d - TODO: consider changing nomenclature from dims=["channel_1", "channel_2"], - to dims=["variable_1", "variable_2"], to be consistent with initialize_xrda_1d - - Parameters - ---------- - channels: list - The channels in the multivariate array - dtype: type - The datatype to initialize the array. - Common cases are complex, float, and bool - value: Union[complex, float, bool] - The default value to assign the array - - Returns - ------- - xrda: xarray.core.dataarray.DataArray - An xarray container for the channel variances etc., initialized to zeros. 
- """ - if dims is None: - dims = [channels, channels] - - K = len(channels) - logger.debug(f"Initializing 2D xarray to {value}") - xrda = xr.DataArray( - np.zeros((K, K), dtype=dtype), - dims=["channel_1", "channel_2"], - coords={"channel_1": dims[0], "channel_2": dims[1]}, - ) - if value != 0: - data = value * np.ones(xrda.shape, dtype=dtype) - xrda.data = data - - return xrda diff --git a/aurora/transfer_function/base.py b/aurora/transfer_function/base.py index 6083073f..f26ac2e7 100644 --- a/aurora/transfer_function/base.py +++ b/aurora/transfer_function/base.py @@ -12,7 +12,7 @@ import xarray as xr from aurora.config.metadata.processing import Processing from loguru import logger -from mt_metadata.transfer_functions.processing.aurora.band import FrequencyBands +from mt_metadata.transfer_functions.processing.aurora import FrequencyBands from typing import Optional, Union @@ -50,10 +50,6 @@ def __init__( """ Constructor. - Development Notes: - change 2021-07-23 to require a frequency_bands object. We may want - to just pass the band_edges. - Parameters ---------- _emtf_header : legacy header information used by Egbert's matlab class. Header contains @@ -61,8 +57,8 @@ def __init__( decimation_level_id: int Identifies the relevant decimation level. Used for accessing the appropriate info in self.processing config. - frequency_bands: aurora.time_series.frequency_band.FrequencyBands - frequency bands object + frequency_bands: FrequencyBands + frequency bands object defining the tf estimation bands. """ self._emtf_tf_header = None self.decimation_level_id = decimation_level_id diff --git a/aurora/transfer_function/plot/comparison_plots.py b/aurora/transfer_function/plot/comparison_plots.py index d2120a16..d5732524 100644 --- a/aurora/transfer_function/plot/comparison_plots.py +++ b/aurora/transfer_function/plot/comparison_plots.py @@ -83,7 +83,7 @@ def compare_two_z_files( zfile1 = read_z_file(z_path1, angle=angle1) zfile2 = read_z_file(z_path2, angle=angle2) - logger.info(f"Sacling TF scale_factor1: {scale_factor1}") + logger.info(f"Scaling TF scale_factor1: {scale_factor1}") fig, axs = plt.subplots(nrows=2, dpi=300, sharex=True) # figsize=(8, 6.), # Make LaTeX symbol strings diff --git a/aurora/transfer_function/regression/iter_control.py b/aurora/transfer_function/regression/iter_control.py index fff4fdcd..dea77021 100644 --- a/aurora/transfer_function/regression/iter_control.py +++ b/aurora/transfer_function/regression/iter_control.py @@ -21,8 +21,8 @@ class IterControl(object): the abstract base class solved Y = X*b + epsilon for b, complex-valued. Perhaps this was intended as an intrinsic tolerated noise level. The value of epsilon was set to 1000. - - TODO The return covariance boolean just initializes arrays of zeros. Needs to be - made functional or removed + - TODO The return covariance boolean just initializes arrays of zeros. Needs to be + made functional or removed """ def __init__( @@ -35,7 +35,7 @@ def __init__( verbosity: int = 0, ) -> None: """ - Constructor + Constructor. Parameters ---------- @@ -43,7 +43,7 @@ def __init__( Set to zero for OLS, otherwise, this is how many times the RME will refine the estimate. max_number_of_redescending_iterations : int - 1 or 2 is fine at most. If set to zero we ignore the redescend code block. + 1 or 2 is fine at most. If set to zero, the redescend code block is ignored. r0: float Effectively infinty for OLS, this controls the point at which residuals transition from being penalized by a squared vs a linear function. 
The @@ -126,17 +126,15 @@ def converged(self, b, b0): the most recent regression estimate b0 : complex-valued numpy array The previous regression estimate - verbose: bool - Set to True for debugging Returns ------- converged: bool True of the regression has terminated, False otherwise - Notes: + Development Notes: The variable maximum_change finds the maximum amplitude component of the vector - 1-b/b0. Looking at the formula, one might want to cast this instead as + 1 - b/b0. Looking at the formula, one might want to cast this instead as 1 - abs(b/b0), however, that will be insensitive to phase changes in b, which is complex valued. The way it is coded np.max(np.abs(1 - b / b0)) is correct as it stands. @@ -165,7 +163,16 @@ def converged(self, b, b0): return converged @property - def continue_redescending(self): + def continue_redescending(self) -> bool: + """ + Checks whether the redescending loop should continue. + + Returns + ------- + bool + True if max_number_of_redescending_iterations has not yet been reached, otherwise False. + + """ maxxed_out = ( self.number_of_redescending_iterations >= self.max_number_of_redescending_iterations @@ -176,7 +183,7 @@ def continue_redescending(self): return True @property - def correction_factor(self): + def correction_factor(self) -> float: """ Returns correction factor for residual variances. diff --git a/aurora/transfer_function/weights/edf_weights.py b/aurora/transfer_function/weights/edf_weights.py index 590cfa12..4c8e104c 100644 --- a/aurora/transfer_function/weights/edf_weights.py +++ b/aurora/transfer_function/weights/edf_weights.py @@ -198,18 +198,20 @@ def effective_degrees_of_freedom_weights( Weights for reducing leverage points. """ + # Initialize the weights + n_observations_initial = len(X.observation) + weights = np.ones(n_observations_initial) + # validate num channels num_channels = len(X.data_vars) if num_channels != 2: - logger.error("edfwts only works for 2 input channels") - raise Exception + logger.error(f"edfwts only works for 2 input channels, not {num_channels}") + return weights + X = X.to_array(dim="channel") if R is not None: R = R.to_array(dim="channel") - n_observations_initial = len(X.observation) - weights = np.ones(n_observations_initial) - # reduce the data to only valid (non-nan) observations if R is not None: keep_x_indices = ~np.isnan(X.data).any(axis=0) diff --git a/docs/tutorials/process_cas04_multiple_station.ipynb b/docs/tutorials/process_cas04_multiple_station.ipynb index 0f5fc9be..e7191b6e 100644 --- a/docs/tutorials/process_cas04_multiple_station.ipynb +++ b/docs/tutorials/process_cas04_multiple_station.ipynb @@ -2772,7 +2772,7 @@ "outputs": [], "source": [ "for dec_level in config.decimations:\n", - " dec_level.window.type = \"hamming\"" + " dec_level.stft.window.type = \"hamming\"" ] }, { @@ -3534,8 +3534,8 @@ "cc = ConfigCreator()\n", "config = cc.create_from_kernel_dataset(kernel_dataset,) \n", "for dec_level in config.decimations:\n", - " dec_level.window.type = \"hamming\"\n", - "# dec_level.window.overlap = int(dec_level.window.num_samples/4)\n", + " dec_level.stft.window.type = \"hamming\"\n", + "# dec_level.stft.window.overlap = int(dec_level.stft.window.num_samples/4)\n", " dec_level.save_fcs = True\n", " dec_level.save_fcs_type = \"h5\"" ] diff --git a/docs/tutorials/synthetic_data_processing.ipynb b/docs/tutorials/synthetic_data_processing.ipynb index b505a313..d62fbc47 100644 --- a/docs/tutorials/synthetic_data_processing.ipynb +++ 
b/docs/tutorials/synthetic_data_processing.ipynb @@ -1908,9 +1908,9 @@ "metadata": {}, "outputs": [], "source": [ - "from aurora.pipelines.fourier_coefficients import add_fcs_to_mth5\n", - "from aurora.pipelines.fourier_coefficients import fc_decimations_creator\n", - "from aurora.pipelines.fourier_coefficients import read_back_fcs" + "from mth5.timeseries.spectre.helpers import add_fcs_to_mth5\n", + "from mth5.timeseries.spectre.helpers import fc_decimations_creator\n", + "from mth5.timeseries.spectre.helpers import read_back_fcs" ] }, { diff --git a/setup.py b/setup.py index 66dcfb83..b5c4d5ac 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ requirements = [ # "mt_metadata", # "mth5", - "mtpy-v2", + # "mtpy-v2", "numba", "psutil", ] diff --git a/tests/cas04/02b_process_cas04_mth5.py b/tests/cas04/02b_process_cas04_mth5.py index 9d5d55d7..5b0dfa13 100644 --- a/tests/cas04/02b_process_cas04_mth5.py +++ b/tests/cas04/02b_process_cas04_mth5.py @@ -125,9 +125,7 @@ def z_file_name(self, target_dir): return out_file -def process_station_runs( - local_station_id, remote_station_id="", station_runs={} -): +def process_station_runs(local_station_id, remote_station_id="", station_runs={}): """ Parameters @@ -154,9 +152,7 @@ def process_station_runs( # Pass the run_summary to a Dataset class kernel_dataset = KernelDataset() - kernel_dataset.from_run_summary( - run_summary, local_station_id, remote_station_id - ) + kernel_dataset.from_run_summary(run_summary, local_station_id, remote_station_id) # reduce station_runs_dict to only relevant stations @@ -297,7 +293,7 @@ def process_with_remote( kernel_dataset, emtf_band_file=band_setup_file ) for decimation in config.decimations: - decimation.window.type = "hamming" + decimation.stft.window.type = "hamming" show_plot = False if remote: z_file_base = f"{local}_RR{remote}.zrr" diff --git a/tests/config/test_config_creator.py b/tests/config/test_config_creator.py index 61fe51f9..f7e2f4f4 100644 --- a/tests/config/test_config_creator.py +++ b/tests/config/test_config_creator.py @@ -1,4 +1,5 @@ # import logging +import pandas as pd import unittest from aurora.config.config_creator import ConfigCreator @@ -47,6 +48,35 @@ def test_exception_for_non_unique_band_specification(self): ) cc2.determine_band_specification_style() + def test_frequency_bands(self): + """ + Tests the frequency_bands method of AuroraDecimationLevel + TODO: Move this into mt_metadata. 
+ - Requires a test in mt_metadata that creates a fully populated AuroraDecimationLevel + + """ + import numpy as np + + kernel_dataset = get_example_kernel_dataset() + cc = ConfigCreator() + cfg1 = cc.create_from_kernel_dataset( + kernel_dataset, estimator={"engine": "RME"} + ) + dec_level_0 = cfg1.decimations[0] + band_edges_a = dec_level_0.frequency_bands_obj().band_edges + + # compare with another way to get band edges + delta_f = dec_level_0.frequency_sample_interval + lower_edges = (dec_level_0.lower_bounds * delta_f) - delta_f / 2.0 + upper_edges = (dec_level_0.upper_bounds * delta_f) + delta_f / 2.0 + band_edges_b = pd.DataFrame( + data={ + "lower_bound": lower_edges, + "upper_bound": upper_edges, + } + ) + assert (band_edges_b - band_edges_a == 0).all().all() + def main(): # tmp = TestConfigCreator() diff --git a/tests/parkfield/test_process_parkfield_run.py b/tests/parkfield/test_process_parkfield_run.py index 54dfb5c5..cac020ea 100644 --- a/tests/parkfield/test_process_parkfield_run.py +++ b/tests/parkfield/test_process_parkfield_run.py @@ -43,9 +43,9 @@ def test_processing(z_file_path=None, test_clock_zero=False): if test_clock_zero: for dec_lvl_cfg in config.decimations: - dec_lvl_cfg.window.clock_zero_type = test_clock_zero + dec_lvl_cfg.stft.window.clock_zero_type = test_clock_zero if test_clock_zero == "user specified": - dec_lvl_cfg.window.clock_zero = "2004-09-28 00:00:10+00:00" + dec_lvl_cfg.stft.window.clock_zero = "2004-09-28 00:00:10+00:00" show_plot = False tf_cls = process_mth5( diff --git a/tests/synthetic/test_decimation_methods.py b/tests/synthetic/test_decimation_methods.py new file mode 100644 index 00000000..919e7e84 --- /dev/null +++ b/tests/synthetic/test_decimation_methods.py @@ -0,0 +1,80 @@ +""" + This is a test to confirm that mth5's decimation method returns the same default values as aurora's prototype decimate. + + TODO: add tests from aurora issue #363 in this module +""" + +from aurora.pipelines.time_series_helpers import prototype_decimate +from aurora.test_utils.synthetic.make_processing_configs import ( + create_test_run_config, +) +from loguru import logger +from mth5.data.make_mth5_from_asc import create_test1_h5 +from mth5.mth5 import MTH5 +from mth5.helpers import close_open_files +from mtpy.processing import RunSummary, KernelDataset # mtpy-v2 + +import numpy as np + + +def test_decimation_methods_agree(): + """ + Get some synthetic time series and check that the decimation results are + equal to calling the mth5 built-in run_xrts.sps_filters.decimate. + + TODO: More testing could be added for downsamplings that are not integer factors. 
+ + """ + close_open_files() + mth5_path = create_test1_h5() + mth5_paths = [ + mth5_path, + ] + + run_summary = RunSummary() + run_summary.from_mth5s(mth5_paths) + tfk_dataset = KernelDataset() + station_id = "test1" + run_id = "001" + tfk_dataset.from_run_summary(run_summary, station_id) + + processing_config = create_test_run_config(station_id, tfk_dataset) + + mth5_obj = MTH5(file_version="0.1.0") + mth5_obj.open_mth5(mth5_path, mode="a") + decimated_ts = {} + + for dec_level_id, dec_config in enumerate(processing_config.decimations): + decimated_ts[dec_level_id] = {} + if dec_level_id == 0: + run_obj = mth5_obj.get_run(station_id, run_id, survey=None) + run_ts = run_obj.to_runts(start=None, end=None) + run_xrds = run_ts.dataset + decimated_ts[dec_level_id]["run_xrds"] = run_xrds + current_sample_rate = run_obj.metadata.sample_rate + + if dec_level_id > 0: + run_xrds = decimated_ts[dec_level_id - 1]["run_xrds"] + target_sample_rate = current_sample_rate / (dec_config.decimation.factor) + + decimated_1 = prototype_decimate(dec_config.decimation, run_xrds) + decimated_2 = run_xrds.sps_filters.decimate( + target_sample_rate=target_sample_rate + ) + + difference = decimated_2 - decimated_1 + logger.info(len(difference.time)) + assert np.isclose(difference.to_array(), 0).all() + + logger.info("prototype decimate aurora method agrees with mth5 decimate") + decimated_ts[dec_level_id]["run_xrds"] = decimated_1 + current_sample_rate = target_sample_rate + return + + +def main(): + test_decimation_methods_agree() + + +if __name__ == "__main__": + main() diff --git a/tests/synthetic/test_fourier_coefficients.py b/tests/synthetic/test_fourier_coefficients.py index 0a454b33..2981faaf 100644 --- a/tests/synthetic/test_fourier_coefficients.py +++ b/tests/synthetic/test_fourier_coefficients.py @@ -1,24 +1,21 @@ import unittest from aurora.config.config_creator import ConfigCreator -from aurora.pipelines.fourier_coefficients import add_fcs_to_mth5 -from aurora.pipelines.fourier_coefficients import fc_decimations_creator -from aurora.pipelines.fourier_coefficients import read_back_fcs from aurora.pipelines.process_mth5 import process_mth5 from aurora.test_utils.synthetic.make_processing_configs import ( create_test_run_config, ) from aurora.test_utils.synthetic.paths import SyntheticTestPaths +from loguru import logger from mth5.data.make_mth5_from_asc import create_test1_h5 from mth5.data.make_mth5_from_asc import create_test2_h5 from mth5.data.make_mth5_from_asc import create_test3_h5 from mth5.data.make_mth5_from_asc import create_test12rr_h5 - -# from mtpy-v2 -from mtpy.processing import RunSummary, KernelDataset - -from loguru import logger from mth5.helpers import close_open_files +from mth5.timeseries.spectre.helpers import add_fcs_to_mth5 +from mth5.timeseries.spectre.helpers import fc_decimations_creator +from mth5.timeseries.spectre.helpers import read_back_fcs +from mtpy.processing import RunSummary, KernelDataset # from mtpy-v2 synthetic_test_paths = SyntheticTestPaths() synthetic_test_paths.mkdirs() @@ -30,9 +27,8 @@ class TestAddFourierCoefficientsToSyntheticData(unittest.TestCase): Runs several synthetic processing tests from config creation to tf_cls. 
There are two ways to prepare the FC-schema - a) use the mt_metadata.FCDecimation class explictly - b) mt_metadata.transfer_functions.processing.aurora.decimation_level.DecimationLevel has - a to_fc_decimation() method that returns mt_metadata.FCDecimation + a) use the mt_metadata.FCDecimation class + b) use AuroraDecimationLevel's to_fc_decimation() method that returns mt_metadata.FCDecimation Flow is to make some mth5 files from synthetic data, then loop over those files adding fcs. Finally, process the mth5s to make TFs. @@ -73,8 +69,7 @@ def test_123(self): - This could probably be shortened, it isn't clear that all the h5 files need to have fc added and be processed too. - uses the to_fc_decimation() method of - mt_metadata.transfer_functions.processing.aurora.decimation_level.DecimationLevel. + uses the to_fc_decimation() method of AuroraDecimationLevel. Returns ------- @@ -115,6 +110,7 @@ def test_123(self): x.to_fc_decimation() for x in processing_config.decimations ] # For code coverage, have a case where fc_decimations is None + # This also (indirectly) tests a different FCDecimation object. if mth5_path.stem == "test1": fc_decimations = None @@ -127,8 +123,13 @@ def test_123(self): return tfc def test_fc_decimations_creator(self): - """""" - cfgs = fc_decimations_creator(1.0) + """ + Tests fc_decimations_creator with its default arguments. + TODO: Move this test into mt_metadata. + """ + cfgs = fc_decimations_creator(initial_sample_rate=1.0) # test time period must of of type with self.assertRaises(NotImplementedError): @@ -136,6 +137,15 @@ def test_fc_decimations_creator(self): fc_decimations_creator(1.0, time_period=time_period) return cfgs + def test_spectrogram(self): + """ + Placeholder method. TODO: Move this into MTH5 + + Development Notes: + Currently mth5 does not have any STFT methods.
Once those are added, this test can be implemented here. + """ + def test_create_then_use_stored_fcs_for_processing(self): """""" from test_processing import process_synthetic_2 @@ -160,7 +170,7 @@ def test_create_then_use_stored_fcs_for_processing(self): ) # Intialize a TF kernel to check for FCs - original_window = processing_config.decimations[0].window.type + original_window = processing_config.decimations[0].stft.window.type tfk = TransferFunctionKernel(dataset=tfk_dataset, config=processing_config) tfk.make_processing_summary() @@ -171,7 +181,7 @@ # now change the window type and show that FCs are not detected for decimation in processing_config.decimations: - decimation.window.type = "hamming" + decimation.stft.window.type = "hamming" tfk = TransferFunctionKernel(dataset=tfk_dataset, config=processing_config) tfk.make_processing_summary() tfk.check_if_fcs_already_exist() @@ -181,7 +191,7 @@ # Now reprocess with the FCs for decimation in processing_config.decimations: - decimation.window.type = original_window + decimation.stft.window.type = original_window tfk = TransferFunctionKernel(dataset=tfk_dataset, config=processing_config) tfk.make_processing_summary() tfk.check_if_fcs_already_exist() diff --git a/tests/synthetic/test_stft_methods_agree.py b/tests/synthetic/test_stft_methods_agree.py index e920e03d..da2ea54a 100644 --- a/tests/synthetic/test_stft_methods_agree.py +++ b/tests/synthetic/test_stft_methods_agree.py @@ -7,18 +7,15 @@ from aurora.pipelines.time_series_helpers import prototype_decimate from aurora.pipelines.time_series_helpers import run_ts_to_stft -from aurora.pipelines.time_series_helpers import run_ts_to_stft_scipy from aurora.test_utils.synthetic.make_processing_configs import ( create_test_run_config, ) - -# from mtpy-v2 -from mtpy.processing import RunSummary, KernelDataset - from loguru import logger from mth5.data.make_mth5_from_asc import create_test1_h5 from mth5.mth5 import MTH5 from mth5.helpers import close_open_files +from mth5.timeseries.spectre.stft import run_ts_to_stft_scipy +from mtpy.processing import RunSummary, KernelDataset # from mtpy-v2 def test_stft_methods_agree(): @@ -68,14 +65,14 @@ def test_stft_methods_agree(): run_ts = run_obj.to_runts(start=None, end=None) local_run_xrts = run_ts.dataset else: - local_run_xrts = prototype_decimate( - dec_config.decimation, local_run_xrts - ) - - dec_config.extra_pre_fft_detrend_type = "constant" - local_stft_obj = run_ts_to_stft(dec_config, local_run_xrts) - local_stft_obj2 = run_ts_to_stft_scipy(dec_config, local_run_xrts) - stft_difference = local_stft_obj - local_stft_obj2 + local_run_xrts = prototype_decimate(dec_config.decimation, local_run_xrts) + + dec_config.stft.per_window_detrend_type = "constant" + local_spectrogram = run_ts_to_stft(dec_config, local_run_xrts) + local_spectrogram2 = run_ts_to_stft_scipy(dec_config, local_run_xrts) + stft_difference = ( + local_spectrogram.dataset - local_spectrogram2.dataset + ) # TODO: add a "-" method to spectrogram that subtracts the datasets stft_difference = stft_difference.to_array() # drop dc term diff --git a/tests/test_run_on_commit.ipynb b/tests/test_run_on_commit.ipynb deleted file mode 100644 index 6ac8e9bc..00000000 --- a/tests/test_run_on_commit.ipynb +++ /dev/null @@ -1,35 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "99fe1b05-7951-4666-8c24-7bb28426611d", - "metadata": {}, - "outputs": [], - "source": [ - "assert(True)" - ]
- } - ], - "metadata": { - "kernelspec": { - "display_name": "aurora-test", - "language": "python", - "name": "aurora-test" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/tests/time_series/test_spectrogram.py b/tests/time_series/test_spectrogram.py deleted file mode 100644 index 22658981..00000000 --- a/tests/time_series/test_spectrogram.py +++ /dev/null @@ -1,37 +0,0 @@ -# -*- coding: utf-8 -*- -""" -""" - -import unittest - -from aurora.time_series.spectrogram import Spectrogram - - -class TestSpectrogram(unittest.TestCase): - """ - Test Spectrogram class - """ - - @classmethod - def setUpClass(self): - pass - - def setUp(self): - pass - - def test_initialize(self): - spectrogram = Spectrogram() - assert isinstance(spectrogram, Spectrogram) - - def test_slice_band(self): - """ - Place holder - TODO: Once FCs are added to an mth5, load a spectrogram and extract a Band - """ - pass - - -if __name__ == "__main__": - # tmp = TestSpectrogram() - # tmp.test_initialize() - unittest.main() diff --git a/tests/time_series/test_xarray_helpers.py b/tests/time_series/test_xarray_helpers.py index 247e77da..9f57df15 100644 --- a/tests/time_series/test_xarray_helpers.py +++ b/tests/time_series/test_xarray_helpers.py @@ -4,74 +4,83 @@ """ import numpy as np -import unittest - import xarray as xr +import pytest + +from aurora.time_series.xarray_helpers import handle_nan + + +def test_handle_nan_basic(): + """Test basic functionality of handle_nan with NaN values.""" + # Create sample data with NaN values + times = np.array([0, 1, 2, 3, 4]) + data_x = np.array([1.0, np.nan, 3.0, 4.0, 5.0]) + data_y = np.array([1.0, 2.0, np.nan, 4.0, 5.0]) + + X = xr.Dataset({"hx": ("time", data_x)}, coords={"time": times}) + Y = xr.Dataset({"ex": ("time", data_y)}, coords={"time": times}) + + # Test with X and Y only + X_clean, Y_clean, _ = handle_nan(X, Y, None, drop_dim="time") + + # Check that NaN values were dropped + assert len(X_clean.time) == 3 + assert len(Y_clean.time) == 3 + assert not np.any(np.isnan(X_clean.hx.values)) + assert not np.any(np.isnan(Y_clean.ex.values)) + + +def test_handle_nan_with_remote_reference(): + """Test handle_nan with remote reference data.""" + # Create sample data + times = np.array([0, 1, 2, 3]) + data_x = np.array([1.0, np.nan, 3.0, 4.0]) + data_y = np.array([1.0, 2.0, 3.0, 4.0]) + data_rr = np.array([1.0, 2.0, np.nan, 4.0]) + + X = xr.Dataset({"hx": ("time", data_x)}, coords={"time": times}) + Y = xr.Dataset({"ex": ("time", data_y)}, coords={"time": times}) + RR = xr.Dataset({"hx": ("time", data_rr)}, coords={"time": times}) + + # Test with all datasets + X_clean, Y_clean, RR_clean = handle_nan(X, Y, RR, drop_dim="time") + + # Check that NaN values were dropped + assert len(X_clean.time) == 2 + assert len(Y_clean.time) == 2 + assert len(RR_clean.time) == 2 + assert not np.any(np.isnan(X_clean.hx.values)) + assert not np.any(np.isnan(Y_clean.ex.values)) + assert not np.any(np.isnan(RR_clean.hx.values)) + + # Check that the values are correct + expected_times = np.array([0, 3]) + assert np.allclose(X_clean.time.values, expected_times) + assert np.allclose(Y_clean.time.values, expected_times) + assert np.allclose(RR_clean.time.values, expected_times) + assert np.allclose(X_clean.hx.values, np.array([1.0, 4.0])) + 
assert np.allclose(Y_clean.ex.values, np.array([1.0, 4.0])) + assert np.allclose(RR_clean.hx.values, np.array([1.0, 4.0])) + + +def test_handle_nan_time_mismatch(): + """Test handle_nan with time coordinate mismatches.""" + # Create sample data with slightly different timestamps + times_x = np.array([0, 1, 2, 3]) + times_rr = times_x + 0.1 # Small offset + data_x = np.array([1.0, 2.0, 3.0, 4.0]) + data_rr = np.array([1.0, 2.0, 3.0, 4.0]) + + X = xr.Dataset({"hx": ("time", data_x)}, coords={"time": times_x}) + RR = xr.Dataset({"hx": ("time", data_rr)}, coords={"time": times_rr}) + + # Test handling of time mismatch + X_clean, _, RR_clean = handle_nan(X, None, RR, drop_dim="time") + + # Check that data was preserved despite time mismatch + assert len(X_clean.time) == 4 + assert "hx" in RR_clean.data_vars + assert np.allclose(RR_clean.hx.values, data_rr) -from aurora.time_series.xarray_helpers import covariance_xr -from aurora.time_series.xarray_helpers import initialize_xrda_1d -from aurora.time_series.xarray_helpers import initialize_xrda_2d - - -class TestXarrayHelpers(unittest.TestCase): - """ - Test methods in xarray helpers - - may get broken into separate tests if this module grows - """ - - @classmethod - def setUpClass(self): - self.standard_channel_names = ["ex", "ey", "hx", "hy", "hz"] - - def setUp(self): - pass - - def test_initialize_xrda_1d(self): - dtype = float - value = -1 - tmp = initialize_xrda_1d(self.standard_channel_names, dtype=dtype, value=value) - self.assertTrue((tmp.data == value).all()) - - def test_initialize_xrda_2d(self): - dtype = float - value = -1 - tmp = initialize_xrda_2d(self.standard_channel_names, dtype=dtype, value=value) - self.assertTrue((tmp.data == value).all()) - - def test_covariance_xr(self): - np.random.seed(0) - n_observations = 100 - xrds = xr.Dataset( - { - "hx": ( - [ - "time", - ], - np.abs(np.random.randn(n_observations)), - ), - "hy": ( - [ - "time", - ], - np.abs(np.random.randn(n_observations)), - ), - }, - coords={ - "time": np.arange(n_observations), - }, - ) - - X = xrds.to_array() - cov = covariance_xr(X) - self.assertTrue((cov.data == cov.data.transpose().conj()).all()) - - def test_sometehing_else(self): - """ - Place holder - - """ - pass - - -if __name__ == "__main__": - unittest.main() + # Check that the time values match X's time values + assert np.allclose(RR_clean.time.values, X_clean.time.values) diff --git a/tests/transfer_function/test_cross_power.py b/tests/transfer_function/test_cross_power.py index b312e5e7..6c708f6f 100644 --- a/tests/transfer_function/test_cross_power.py +++ b/tests/transfer_function/test_cross_power.py @@ -1,4 +1,4 @@ -from aurora.time_series.xarray_helpers import initialize_xrda_2d +from mth5.timeseries.xarray_helpers import initialize_xrda_2d_cov from aurora.transfer_function.cross_power import tf_from_cross_powers from aurora.transfer_function.cross_power import _channel_names from aurora.transfer_function.cross_power import ( @@ -32,7 +32,7 @@ def setUpClass(self): station_1_channels = [f"{self.station_ids[0]}_{x}" for x in components] station_2_channels = [f"{self.station_ids[1]}_{x}" for x in components] channels = station_1_channels + station_2_channels - sdm = initialize_xrda_2d( + sdm = initialize_xrda_2d_cov( channels=channels, dtype=complex, )
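Editor's note (added for clarity; not part of the patch set): the recurring change throughout this changeset moves STFT window metadata from `decimation.window.*` to `decimation.stft.window.*`, and the decimated sample rate is now read from the nested `decimation.decimation.sample_rate`. The sketch below is illustrative only; it assumes a `kernel_dataset` already built from a RunSummary, as in the tutorial notebooks touched above, and simply restates the new access pattern rather than adding behavior.

```python
from aurora.config.config_creator import ConfigCreator

# kernel_dataset is assumed to have been prepared upstream
# (RunSummary -> KernelDataset), as in the tutorial notebooks.
cc = ConfigCreator()
config = cc.create_from_kernel_dataset(kernel_dataset)

for dec_level in config.decimations:
    # STFT window metadata now lives under the stft namespace.
    dec_level.stft.window.type = "hamming"
    dec_level.stft.window.num_samples = 128  # illustrative values only
    dec_level.stft.window.overlap = 32
    # The decimated sample rate is accessed via the nested decimation object.
    sample_rate = dec_level.decimation.sample_rate
```

On the band-edge arithmetic exercised by the new test_frequency_bands test: edges are reconstructed from harmonic indices by scaling with the frequency step delta_f and padding by half a step on each side. As a worked example, with a 128-sample window at a 1 Hz sample rate, delta_f = 1/128 Hz (about 0.0078 Hz), so a band spanning harmonics 5 through 8 has edges 5*delta_f - delta_f/2 (about 0.035 Hz) and 8*delta_f + delta_f/2 (about 0.066 Hz).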