Skip to content

Commit

Permalink
fixes neuroinformatics-unit#454: Add a rolling mean filter
Browse files Browse the repository at this point in the history
  • Loading branch information
ArismitaM committed Mar 3, 2025
1 parent b0676b7 commit dccf34c
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 3 deletions.
66 changes: 66 additions & 0 deletions movement/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,72 @@ def median_filter(
return data_smoothed


@log_to_attrs
def mean_filter(
data: xr.DataArray,
window: int,
min_periods: int | None = None,
print_report: bool = True,
) -> xr.DataArray:
"""Smooth data by applying a mean filter over time.
Parameters
----------
data : xarray.DataArray
The input data to be smoothed.
window : int
The size of the smoothing window, representing the fixed number
of observations used for each window.
min_periods : int
Minimum number of observations in the window required to have
a value (otherwise result is NaN). The default, None, is
equivalent to setting ``min_periods`` equal to the size of the window.
This argument is directly passed to the ``min_periods`` parameter of
:meth:`xarray.DataArray.rolling`.
print_report : bool
Whether to print a report on the number of NaNs in the dataset
before and after smoothing. Default is ``True``.
Returns
-------
xarray.DataArray
The data smoothed using a mean filter with the provided parameters.
Notes
-----
By default, whenever one or more NaNs are present in the smoothing window,
a NaN is returned to the output array. As a result, any
stretch of NaNs present in the input data will be propagated
proportionally to the size of the window (specifically, by
``floor(window/2)``). To control this behaviour, the
``min_periods`` option can be used to specify the minimum number of
non-NaN values required in the window to compute a result. For example,
setting ``min_periods=1`` will result in the filter returning NaNs
only when all values in the window are NaN, since 1 non-NaN value
is sufficient to compute the mean.
"""
half_window = window // 2
data_smoothed = (
data.pad( # Pad the edges to avoid NaNs
time=half_window, mode="reflect"
)
.rolling( # Take rolling windows across time
time=window, center=True, min_periods=min_periods
)
.mean( # Compute the mean of each window
skipna=True
)
.isel( # Remove the padded edges
time=slice(half_window, -half_window)
)
)
if print_report:
print(report_nan_values(data, "input"))
print(report_nan_values(data_smoothed, "output"))
return data_smoothed


@log_to_attrs
def savgol_filter(
data: xr.DataArray,
Expand Down
12 changes: 9 additions & 3 deletions tests/test_unit/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from movement.filtering import (
filter_by_confidence,
interpolate_over_time,
mean_filter,
median_filter,
savgol_filter,
)
Expand All @@ -28,19 +29,22 @@
list_all_valid_datasets,
)
class TestFilteringValidDataset:
"""Test median and savgol filtering on valid datasets with/without NaNs."""
"""Test median, mean and savgol filtering on valid datasets
with/without NaNs.
"""

@pytest.mark.parametrize(
("filter_func, filter_kwargs"),
[
(median_filter, {"window": 3}),
(mean_filter, {"window": 3}),
(savgol_filter, {"window": 3, "polyorder": 2}),
],
)
def test_filter_with_nans_on_position(
self, filter_func, filter_kwargs, valid_dataset, helpers, request
):
"""Test NaN behaviour of the median and SG filters.
"""Test NaN behaviour of the median, mean and SG filters.
Both filters should set all values to NaN if one element of the
sliding window is NaN.
"""
Expand Down Expand Up @@ -155,7 +159,9 @@ def test_interpolate_over_time_on_position(
"window",
[3, 5, 6, 10], # input data has 10 frames
)
@pytest.mark.parametrize("filter_func", [median_filter, savgol_filter])
@pytest.mark.parametrize(
"filter_func", [median_filter, mean_filter, savgol_filter]
)
def test_filter_with_nans_on_position_varying_window(
self, valid_dataset_with_nan, window, filter_func, helpers, request
):
Expand Down

0 comments on commit dccf34c

Please sign in to comment.