Skip to content

Commit 64a16bf

Browse files
committed
add label filter
1 parent b1288b5 commit 64a16bf

File tree

4 files changed

+198
-2
lines changed

4 files changed

+198
-2
lines changed

mimikit/extract/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .samplify import *
33
from .segment import *
44
from .from_neighbors import *
5+
from .label_filter import *
56

67

78
# Export every name currently bound in this namespace that does not start with
# an underscore (this includes whatever the star-imports above brought in).
__all__ = [_ for _ in dir() if not _.startswith("_")]

mimikit/extract/from_neighbors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
]
1111

1212

13-
def nearest_neighbor(X, Y, metric=None):
    """
    computes, for each element of X, its nearest neighbor in Y

    :param X: first argument handed to `metric`
    :param Y: second argument handed to `metric`
    :param metric: callable such that `metric(X, Y)` returns a tensor of
        distances; the minimum is taken over its last dimension
        (presumably the one indexing Y -- confirm against the metric used).
        When None, `AngularDistance()` is used.
    :return: tuple `(dists, nn)` as returned by `torch.min(..., dim=-1)`:
        the minimal distances and the indices at which they occur
    """
    if metric is None:
        # instantiate the default per call instead of once at definition
        # time, so no instance is shared between unrelated calls
        metric = AngularDistance()
    D_xy = metric(X, Y)
    dists, nn = torch.min(D_xy, dim=-1)
    return dists, nn
2020

mimikit/extract/label_filter.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import numpy as np
2+
import dataclasses as dtc
3+
from ..features.functionals import Functional
4+
5+
__all__ = [
6+
"label_filter",
7+
"LabelFilter"
8+
]
9+
10+
11+
def _get_counts(labels, R):
    """
    build per-position windows of labels and of local run lengths

    The window width is the smallest odd number >= R + 1, so that every
    window has a center element.

    :param labels: 1-d integer array of labels (assumed 1-d: it is padded
        and windowed along a single axis)
    :param R: minimal run length below which a position gets flagged
    :return: tuple `(l2d, c2d, flagged)` where
        - `l2d[i]` is the window of labels centered on position i
          (positions outside the array are filled with -1),
        - `c2d[i]` is the matching window of run lengths
          (positions outside the array are filled with 0),
        - `flagged[i]` is True when the run length at i is smaller than R.
        The run length at i counts i itself plus the equal labels that
        touch it contiguously on each side, looking no further than half
        a window away.
    """
    # R odd  -> R + 1 is even -> widen by one more; R even -> R + 1 is odd
    width = R + 2 if R % 2 else R + 1
    half = width // 2
    sliding = np.lib.stride_tricks.sliding_window_view

    l2d = sliding(np.pad(labels, (half, half), constant_values=-1), (width,))
    n_rows = l2d.shape[0]
    centers = l2d[:, half]

    # a side stays "alive" only while every label met so far equals the center
    same_left = np.ones(n_rows, dtype=bool)
    same_right = np.ones(n_rows, dtype=bool)
    counts = np.ones(n_rows, dtype=int)
    for step in range(1, half + 1):
        same_left &= centers == l2d[:, half - step]
        same_right &= centers == l2d[:, half + step]
        counts += same_left
        counts += same_right

    flagged = counts < R
    c2d = sliding(np.pad(counts, (half, half), constant_values=0), (width,))
    return l2d, c2d, flagged
27+
28+
29+
def _filter_window(w, elem_counts, glob_counts, label_undecidable):
    """
    pick the replacement label for the center element of one window

    :param w: window of labels; the element being decided is the middle one
    :param elem_counts: run length of every position in `w`
    :param glob_counts: array indexed by label giving a global count per label
    :param label_undecidable: if True, return -1 when no label in the window
        stands out, instead of falling back on the neighbors / the element
    :return: the label to write at the center position
    """
    mid = w.shape[0] // 2
    elem = w[mid]
    own_run = elem_counts[mid]
    own_global = glob_counts[elem]
    longest_at = elem_counts.argmax()
    longest_run = elem_counts[longest_at]

    if longest_run != 1:
        # some position of the window sits in a run longer than one:
        # keep the element if it is part of such a longest run, otherwise
        # adopt the label of the first longest run
        return elem if own_run == longest_run else w[longest_at]

    # only runs of length one in the window -> decide with the global counts
    singles = w[elem_counts == 1]
    singles_global = glob_counts[singles]
    top_at = singles_global.argmax()
    top = singles_global[top_at]
    is_top = singles_global == top

    if top == 1 or is_top.all():
        # no candidate dominates globally
        if label_undecidable:
            return -1
        left, right = w[mid - 1], w[mid + 1]
        # surrounded by the same label on both sides -> take it,
        # else leave the element untouched
        return left if left == right else elem

    if is_top.sum() > 1 and top == own_global:
        # several candidates tie for the maximum and the element is one
        # of them: keep it
        return elem

    # unique maximum, or a tie the element is not part of: first maximum wins
    return singles[top_at]
67+
68+
69+
def label_filter(
        labels: np.ndarray,
        min_repetition: int,
        label_undecidable: bool = True,
        relabel_output: bool = True
) -> np.ndarray:
    """
    iteratively replace labels whose local run is shorter than `min_repetition`

    Positions flagged by `_get_counts` get a new label from `_filter_window`;
    this is repeated until nothing is flagged anymore or until a pass leaves
    every flagged position unchanged.

    :param labels: array of non-negative integer labels (it is fed to
        `np.bincount`)
    :param min_repetition: minimal run length; when equal to 1 the input is
        returned as is (no relabeling either)
    :param label_undecidable: forwarded to `_filter_window`; if True,
        positions that can not be decided receive the label -1
    :param relabel_output: if True, the result is replaced by the inverse
        indices of `np.unique`, i.e. labels are renumbered from 0 in sorted
        order of their values (a -1 label, if present, becomes 0)
    :return: the filtered labels
    """
    if min_repetition == 1:
        return labels

    # global counts are taken once, on the input; the extra trailing 0 is
    # what gets read when indexing with the label -1
    glob_counts = np.r_[np.bincount(labels), 0]

    while True:
        windows, run_lengths, too_short = _get_counts(labels, min_repetition)
        if not np.any(too_short):
            break
        proposal = np.zeros_like(labels)
        for pos in np.flatnonzero(too_short):
            proposal[pos] = _filter_window(
                windows[pos], run_lengths[pos], glob_counts, label_undecidable
            )
        if np.all(labels[too_short] == proposal[too_short]):
            # fixed point: this pass would not change anything
            break
        # keep the untouched positions, take the proposal everywhere else
        labels = np.where(too_short, proposal, labels)

    if relabel_output:
        labels = np.unique(labels, return_inverse=True)[1]
    return labels
91+
92+
93+
@dtc.dataclass
class LabelFilter(Functional):
    """
    `Functional` wrapper around `label_filter`

    Only the numpy path does any work; the output is never relabeled
    (`relabel_output=False`).
    """
    # minimal run length handed to `label_filter`
    min_repetition: int = 1
    # handed to `label_filter`; note that the default here is False
    label_undecidable: bool = False

    def np_func(self, inputs):
        return label_filter(
            inputs,
            min_repetition=self.min_repetition,
            label_undecidable=self.label_undecidable,
            relabel_output=False,
        )

    def torch_func(self, inputs):
        # no torch implementation: nothing is computed, None is returned
        return None

    @property
    def inv(self) -> "Functional":
        # no inverse is provided
        return None

tests/test_label_filter.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from mimikit.extract.label_filter import label_filter
2+
import numpy as np
3+
from assertpy import assert_that
4+
import pytest
5+
6+
7+
def test_should_extend_repetition_on_the_edges():
    """runs of two at both edges are expected to absorb the singletons between them"""
    labels = np.array([0, 0, 1, 2, 3, 4, 4])
    min_rep = 2
    expected = np.array([0, 0, 0, 0, 4, 4, 4])

    filtered = label_filter(labels, min_rep, relabel_output=False)

    same = np.all(filtered == expected)
    assert_that(same).is_true()
15+
16+
17+
def test_should_extend_edges_and_replace_single_labels_with_undecidable_label():
    """with the default `label_undecidable`, the middle singletons are expected to become -1"""
    labels = np.array([0, 0, 1, 2, 3, 4, 5, 5])
    min_rep = 2
    # original author's note: if min_rep was 3, the -1 in the middle
    # would be absorbed by 0 and 5
    expected = np.array([0, 0, 0, -1, -1, 5, 5, 5])

    filtered = label_filter(labels, min_rep, relabel_output=False)

    same = np.all(filtered == expected)
    assert_that(same).is_true()
26+
27+
28+
def test_should_extend_edges_and_not_replace_single_labels_without_undecidable_label():
    """with `label_undecidable=False`, no -1 is expected: the edge runs take over all singletons"""
    labels = np.array([0, 0, 1, 2, 3, 4, 5, 5])
    min_rep = 2
    expected = np.array([0, 0, 0, 0, 5, 5, 5, 5])

    filtered = label_filter(
        labels, min_rep, label_undecidable=False, relabel_output=False
    )

    same = np.all(filtered == expected)
    assert_that(same).is_true()
37+
38+
39+
def test_should_replace_undecidable_with_minus_one():
    """every label appears twice globally but never twice in a row: all are expected to become -1"""
    labels = np.array([0, 1, 2, 1, 2, 0])
    min_rep = 2
    expected = np.array([-1, -1, -1, -1, -1, -1])

    filtered = label_filter(labels, min_rep, relabel_output=False)

    same = np.all(filtered == expected)
    assert_that(same).is_true()
47+
48+
49+
def test_should_replace_with_surrounding_elem_without_undecidable_label():
    """same input as the undecidable case, but with `label_undecidable=False`"""
    labels = np.array([0, 1, 2, 1, 2, 0])
    min_rep = 2
    expected = np.array([1, 1, 1, 2, 2, 2])

    filtered = label_filter(
        labels, min_rep, label_undecidable=False, relabel_output=False
    )

    same = np.all(filtered == expected)
    assert_that(same).is_true()
57+
58+
59+
def test_should_return_input_if_undecidable():
    """with `label_undecidable=False`, this input is expected to come back unchanged"""
    labels = np.array([0, 1, 2, 3, 1, 2, 3, 0])
    min_rep = 2
    # the expectation is the input array itself
    expected = labels

    filtered = label_filter(
        labels, min_rep, label_undecidable=False, relabel_output=False
    )

    same = np.all(filtered == expected)
    assert_that(same).is_true()
67+
68+
69+
def test_should_fallback_to_global_counts():
    """label 0 occurs once globally, 1 and 2 twice: the leading 0 is expected to turn into 1"""
    labels = np.array([0, 1, 2, 1, 2])
    min_rep = 2
    expected = np.array([1, 1, -1, -1, -1])

    filtered = label_filter(labels, min_rep, relabel_output=False)

    same = np.all(filtered == expected)
    assert_that(same).is_true()
77+
78+
79+
def test_should_handle_edges_correctly():
    """a leading run of two is too short for min_rep = 3 and is expected to be absorbed by the run of three"""
    labels = np.array([0, 0, 1, 1, 1])
    min_rep = 3
    expected = np.array([1, 1, 1, 1, 1])

    filtered = label_filter(labels, min_rep, relabel_output=False)

    same = np.all(filtered == expected)
    assert_that(same).is_true()

0 commit comments

Comments
 (0)