Skip to content

Commit

Permalink
Added "nearest", "wrap", and "truncate" modes to `ndfilters.generic_f…
Browse files Browse the repository at this point in the history
…ilter()`.
  • Loading branch information
byrdie committed Aug 23, 2024
1 parent c6185e0 commit 8a0d8b7
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 25 deletions.
92 changes: 69 additions & 23 deletions ndfilters/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def generic_filter(
size: int | tuple[int, ...],
axis: None | int | tuple[int, ...] = None,
where: bool | np.ndarray = True,
mode: Literal["mirror"] = "mirror",
mode: Literal["mirror", "nearest", "wrap", "truncate"] = "mirror",
args: tuple = (),
) -> np.ndarray:
"""
Expand Down Expand Up @@ -42,7 +42,8 @@ def generic_filter(
mode
The method used to extend the input array beyond its boundaries.
See :func:`scipy.ndimage.generic_filter` for the definitions.
Currently, only "reflect" mode is supported.
Currently, only "mirror", "nearest", "wrap", and "truncate" modes are
supported.
args
Extra arguments to pass to function.
Expand Down Expand Up @@ -98,9 +99,6 @@ def function(a: np.ndarray, args: tuple) -> float:
f"{size=} should have the same number of elements as {axis=}."
)

if mode != "mirror": # pragma: nocover
raise ValueError(f"Only mode='reflected' is supported, got {mode=}")

axis_numba = ~np.arange(len(axis))[::-1]

shape = array.shape
Expand Down Expand Up @@ -138,6 +136,30 @@ def function(a: np.ndarray, args: tuple) -> float:
return result


@numba.njit
def _rectify_index_lower(index: int, size: int, mode: str) -> int:
if mode == "mirror":
return -index
elif mode == "nearest":
return 0
elif mode == "wrap":
return index % size

Check warning on line 146 in ndfilters/_generic.py

View check run for this annotation

Codecov / codecov/patch

ndfilters/_generic.py#L143-L146

Added lines #L143 - L146 were not covered by tests
else:
raise ValueError

Check warning on line 148 in ndfilters/_generic.py

View check run for this annotation

Codecov / codecov/patch

ndfilters/_generic.py#L148

Added line #L148 was not covered by tests


@numba.njit
def _rectify_index_upper(index: int, size: int, mode: str) -> int:
if mode == "mirror":
return ~(index % size + 1)
elif mode == "nearest":
return size - 1
elif mode == "wrap":
return index % size

Check warning on line 158 in ndfilters/_generic.py

View check run for this annotation

Codecov / codecov/patch

ndfilters/_generic.py#L155-L158

Added lines #L155 - L158 were not covered by tests
else:
raise ValueError

Check warning on line 160 in ndfilters/_generic.py

View check run for this annotation

Codecov / codecov/patch

ndfilters/_generic.py#L160

Added line #L160 was not covered by tests


@numba.njit(parallel=True)
def _generic_filter_1d(
array: np.ndarray,
Expand All @@ -157,18 +179,22 @@ def _generic_filter_1d(

for ix in numba.prange(array_shape_x):

values = np.empty(shape=size)
mask = np.empty(shape=size, dtype=np.bool_)
values = np.zeros(shape=size)
mask = np.zeros(shape=size, dtype=np.bool_)

for kx in range(kernel_shape_x):

px = kx - kernel_shape_x // 2
jx = ix + px

if jx < 0:
jx = -jx
if mode == "truncate":
continue

Check warning on line 192 in ndfilters/_generic.py

View check run for this annotation

Codecov / codecov/patch

ndfilters/_generic.py#L192

Added line #L192 was not covered by tests
jx = _rectify_index_lower(jx, array_shape_x, mode)
elif jx >= array_shape_x:
jx = ~(jx % array_shape_x + 1)
if mode == "truncate":
continue

Check warning on line 196 in ndfilters/_generic.py

View check run for this annotation

Codecov / codecov/patch

ndfilters/_generic.py#L196

Added line #L196 was not covered by tests
jx = _rectify_index_upper(jx, array_shape_x, mode)

values[kx] = array[it, jx]
mask[kx] = where[it, jx]
Expand Down Expand Up @@ -198,28 +224,36 @@ def _generic_filter_2d(
for ix in numba.prange(array_shape_x):
for iy in numba.prange(array_shape_y):

values = np.empty(shape=size)
mask = np.empty(shape=size, dtype=np.bool_)
values = np.zeros(shape=size)
mask = np.zeros(shape=size, dtype=np.bool_)

for kx in range(kernel_shape_x):

px = kx - kernel_shape_x // 2
jx = ix + px

if jx < 0:
jx = -jx
if mode == "truncate":
continue

Check warning on line 237 in ndfilters/_generic.py

View check run for this annotation

Codecov / codecov/patch

ndfilters/_generic.py#L237

Added line #L237 was not covered by tests
jx = _rectify_index_lower(jx, array_shape_x, mode)
elif jx >= array_shape_x:
jx = ~(jx % array_shape_x + 1)
if mode == "truncate":
continue

Check warning on line 241 in ndfilters/_generic.py

View check run for this annotation

Codecov / codecov/patch

ndfilters/_generic.py#L241

Added line #L241 was not covered by tests
jx = _rectify_index_upper(jx, array_shape_x, mode)

for ky in range(kernel_shape_y):

py = ky - kernel_shape_y // 2
jy = iy + py

if jy < 0:
jy = -jy
if mode == "truncate":
continue

Check warning on line 251 in ndfilters/_generic.py

View check run for this annotation

Codecov / codecov/patch

ndfilters/_generic.py#L251

Added line #L251 was not covered by tests
jy = _rectify_index_lower(jy, array_shape_y, mode)
elif jy >= array_shape_y:
jy = ~(jy % array_shape_y + 1)
if mode == "truncate":
continue

Check warning on line 255 in ndfilters/_generic.py

View check run for this annotation

Codecov / codecov/patch

ndfilters/_generic.py#L255

Added line #L255 was not covered by tests
jy = _rectify_index_upper(jy, array_shape_y, mode)

values[kx, ky] = array[it, jx, jy]
mask[kx, ky] = where[it, jx, jy]
Expand Down Expand Up @@ -253,38 +287,50 @@ def _generic_filter_3d(
for iy in numba.prange(array_shape_y):
for iz in numba.prange(array_shape_z):

values = np.empty(shape=size)
mask = np.empty(shape=size, dtype=np.bool_)
values = np.zeros(shape=size)
mask = np.zeros(shape=size, dtype=np.bool_)

for kx in range(kernel_shape_x):

px = kx - kernel_shape_x // 2
jx = ix + px

if jx < 0:
jx = -jx
if mode == "truncate":
continue

Check warning on line 300 in ndfilters/_generic.py

View check run for this annotation

Codecov / codecov/patch

ndfilters/_generic.py#L300

Added line #L300 was not covered by tests
jx = _rectify_index_lower(jx, array_shape_x, mode)
elif jx >= array_shape_x:
jx = ~(jx % array_shape_x + 1)
if mode == "truncate":
continue

Check warning on line 304 in ndfilters/_generic.py

View check run for this annotation

Codecov / codecov/patch

ndfilters/_generic.py#L304

Added line #L304 was not covered by tests
jx = _rectify_index_upper(jx, array_shape_x, mode)

for ky in range(kernel_shape_y):

py = ky - kernel_shape_y // 2
jy = iy + py

if jy < 0:
jy = -jy
if mode == "truncate":
continue

Check warning on line 314 in ndfilters/_generic.py

View check run for this annotation

Codecov / codecov/patch

ndfilters/_generic.py#L314

Added line #L314 was not covered by tests
jy = _rectify_index_lower(jy, array_shape_y, mode)
elif jy >= array_shape_y:
jy = ~(jy % array_shape_y + 1)
if mode == "truncate":
continue

Check warning on line 318 in ndfilters/_generic.py

View check run for this annotation

Codecov / codecov/patch

ndfilters/_generic.py#L318

Added line #L318 was not covered by tests
jy = _rectify_index_upper(jy, array_shape_y, mode)

for kz in range(kernel_shape_z):

pz = kz - kernel_shape_z // 2
jz = iz + pz

if jz < 0:
jz = -jz
if mode == "truncate":
continue

Check warning on line 328 in ndfilters/_generic.py

View check run for this annotation

Codecov / codecov/patch

ndfilters/_generic.py#L328

Added line #L328 was not covered by tests
jz = _rectify_index_lower(jz, array_shape_z, mode)
elif jz >= array_shape_z:
jz = ~(jz % array_shape_z + 1)
if mode == "truncate":
continue

Check warning on line 332 in ndfilters/_generic.py

View check run for this annotation

Codecov / codecov/patch

ndfilters/_generic.py#L332

Added line #L332 was not covered by tests
jz = _rectify_index_upper(jz, array_shape_z, mode)

values[kx, ky, kz] = array[it, jx, jy, jz]
mask[kx, ky, kz] = where[it, jx, jy, jz]
Expand Down
7 changes: 5 additions & 2 deletions ndfilters/_trimmed_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def trimmed_mean_filter(
size: int | tuple[int, ...],
axis: None | int | tuple[int, ...] = None,
where: bool | np.ndarray = True,
mode: Literal["mirror"] = "mirror",
mode: Literal["mirror", "nearest", "wrap", "truncate"] = "mirror",
proportion: float = 0.25,
) -> np.ndarray:
"""
Expand All @@ -36,7 +36,8 @@ def trimmed_mean_filter(
mode
The method used to extend the input array beyond its boundaries.
See :func:`scipy.ndimage.generic_filter` for the definitions.
Currently, only "reflect" mode is supported.
Currently, only "mirror", "nearest", "wrap", and "truncate" modes are
supported.
proportion
The proportion to cut from the top and bottom of the distribution.
Expand Down Expand Up @@ -83,6 +84,8 @@ def _trimmed_mean(
(proportion,) = args

nobs = array.size
if nobs == 0:
return np.nan

Check warning on line 88 in ndfilters/_trimmed_mean.py

View check run for this annotation

Codecov / codecov/patch

ndfilters/_trimmed_mean.py#L88

Added line #L88 was not covered by tests
lowercut = int(proportion * nobs)
uppercut = nobs - lowercut
if lowercut > uppercut: # pragma: nocover
Expand Down

0 comments on commit 8a0d8b7

Please sign in to comment.