Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a rolling mean filter #459

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions movement/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,72 @@ def median_filter(
return data_smoothed


@log_to_attrs
def mean_filter(
    data: xr.DataArray,
    window: int,
    min_periods: int | None = None,
    print_report: bool = True,
) -> xr.DataArray:
    """Smooth data by applying a mean filter over time.

    Parameters
    ----------
    data : xarray.DataArray
        The input data to be smoothed. It must have a ``time`` dimension.
    window : int
        The size of the smoothing window, representing the fixed number
        of observations used for each window.
    min_periods : int, optional
        Minimum number of observations in the window required to have
        a value (otherwise result is NaN). The default, None, is
        equivalent to setting ``min_periods`` equal to the size of the window.
        This argument is directly passed to the ``min_periods`` parameter of
        :meth:`xarray.DataArray.rolling`.
    print_report : bool
        Whether to print a report on the number of NaNs in the dataset
        before and after smoothing. Default is ``True``.

    Returns
    -------
    xarray.DataArray
        The data smoothed using a mean filter with the provided parameters.
        It has the same length along ``time`` as the input.

    Notes
    -----
    By default, whenever one or more NaNs are present in the smoothing window,
    a NaN is returned to the output array. As a result, any
    stretch of NaNs present in the input data will be propagated
    proportionally to the size of the window (specifically, by
    ``floor(window/2)``). To control this behaviour, the
    ``min_periods`` option can be used to specify the minimum number of
    non-NaN values required in the window to compute a result. For example,
    setting ``min_periods=1`` will result in the filter returning NaNs
    only when all values in the window are NaN, since 1 non-NaN value
    is sufficient to compute the mean.

    """
    half_window = window // 2
    n_frames = data.sizes["time"]
    data_smoothed = (
        data.pad(  # Pad the edges to avoid NaNs
            time=half_window, mode="reflect"
        )
        .rolling(  # Take rolling windows across time
            time=window, center=True, min_periods=min_periods
        )
        .mean(  # Compute the mean of each window
            skipna=True
        )
        .isel(  # Remove the padded edges
            # Use an explicit end index rather than ``-half_window``:
            # when ``window == 1`` the half window is 0 and
            # ``slice(0, -0)`` would select nothing.
            time=slice(half_window, half_window + n_frames)
        )
    )
    if print_report:
        print(report_nan_values(data, "input"))
        print(report_nan_values(data_smoothed, "output"))
    return data_smoothed


@log_to_attrs
def savgol_filter(
data: xr.DataArray,
Expand Down
14 changes: 10 additions & 4 deletions tests/test_unit/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from movement.filtering import (
filter_by_confidence,
interpolate_over_time,
mean_filter,
median_filter,
savgol_filter,
)
Expand All @@ -28,20 +29,23 @@
list_all_valid_datasets,
)
class TestFilteringValidDataset:
"""Test median and savgol filtering on valid datasets with/without NaNs."""
"""Test median, mean and savgol filtering on valid datasets
with/without NaNs.
"""

@pytest.mark.parametrize(
("filter_func, filter_kwargs"),
[
(median_filter, {"window": 3}),
(mean_filter, {"window": 3}),
(savgol_filter, {"window": 3, "polyorder": 2}),
],
)
def test_filter_with_nans_on_position(
self, filter_func, filter_kwargs, valid_dataset, helpers, request
):
"""Test NaN behaviour of the median and SG filters.
Both filters should set all values to NaN if one element of the
"""Test NaN behaviour of the median, mean and SG filters.
All filters should set all values to NaN if one element of the
sliding window is NaN.
"""
# Expected number of nans in the position array per individual
Expand Down Expand Up @@ -155,7 +159,9 @@ def test_interpolate_over_time_on_position(
"window",
[3, 5, 6, 10], # input data has 10 frames
)
@pytest.mark.parametrize("filter_func", [median_filter, savgol_filter])
@pytest.mark.parametrize(
"filter_func", [median_filter, mean_filter, savgol_filter]
)
def test_filter_with_nans_on_position_varying_window(
self, valid_dataset_with_nan, window, filter_func, helpers, request
):
Expand Down
Loading