Skip to content

[WIP] ENH: Resample additional arrays apart from X and y #463

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/over-sampling/plot_comparison_over_sampling.py
Original file line number Diff line number Diff line change
@@ -134,7 +134,7 @@ class FakeSampler(BaseSampler):

_sampling_type = 'bypass'

def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
return X, y


42 changes: 35 additions & 7 deletions imblearn/base.py
Original file line number Diff line number Diff line change
@@ -16,6 +16,8 @@
from sklearn.externals import six
from sklearn.preprocessing import label_binarize
from sklearn.utils import check_X_y
from sklearn.utils import check_consistent_length
from sklearn.utils import check_array

from .utils import check_sampling_strategy, check_target_type
from .utils.deprecation import deprecate_parameter
@@ -55,7 +57,7 @@ def fit(self, X, y):
self.sampling_strategy, y, self._sampling_type)
return self

def fit_resample(self, X, y):
def fit_resample(self, X, y, sample_weight=None):
"""Resample the dataset.
Parameters
@@ -66,24 +68,39 @@ def fit_resample(self, X, y):
y : array-like, shape (n_samples,)
Corresponding label for each sample in X.
sample_weight : array-like, shape (n_samples,) or None
Sample weights.
Returns
-------
X_resampled : {array-like, sparse matrix}, shape \
X_resampled : {ndarray, sparse matrix}, shape \
(n_samples_new, n_features)
The array containing the resampled data.
y_resampled : array-like, shape (n_samples_new,)
y_resampled : ndarray, shape (n_samples_new,)
The corresponding label of `X_resampled`.
sample_weight_resampled : ndarray, shape (n_samples_new,)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would rather have the arrays other than X and y returned in a dict. (Optionally? In scikit-learn I would rather make this mandatory so that we don't need to handle both cases.)

Resampled sample weights. This output is returned only if
``sample_weight`` was not ``None``.
idx_resampled : ndarray, shape (n_samples_new,)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why this should be returned from fit_resample, rather than stored as an attribute?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that it came from the original design (before it was in scikit-learn). But it would actually be better to store it as an attribute and keep a single fit_resample.

Indices of the selected features. This output is optional and only
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean the selected samples?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

available for some sampler if ``return_indices=True``.
"""
self._deprecate_ratio()

X, y, binarize_y = self._check_X_y(X, y)
if sample_weight is not None:
sample_weight = check_array(sample_weight, ensure_2d=False)
check_consistent_length(X, y, sample_weight)

self.sampling_strategy_ = check_sampling_strategy(
self.sampling_strategy, y, self._sampling_type)

output = self._fit_resample(X, y)
output = self._fit_resample(X, y, sample_weight)

if binarize_y:
y_sampled = label_binarize(output[1], np.unique(y))
@@ -96,7 +113,7 @@ def fit_resample(self, X, y):
fit_sample = fit_resample

@abstractmethod
def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
"""Base method defined in each sampler to define the sampling
strategy.
@@ -108,14 +125,25 @@ def _fit_resample(self, X, y):
y : array-like, shape (n_samples,)
Corresponding label for each sample in X.
sample_weight : array-like, shape (n_samples,) or None
Sample weights.
Returns
-------
X_resampled : {ndarray, sparse matrix}, shape \
(n_samples_new, n_features)
The array containing the resampled data.
y_resampled : ndarray, shape (n_samples_new,)
The corresponding label of `X_resampled`
The corresponding label of `X_resampled`.
sample_weight_resampled : ndarray, shape (n_samples_new,)
Resampled sample weights. This output is returned only if
``sample_weight`` was not ``None``.
idx_resampled : ndarray, shape (n_samples_new,)
Indices of the selected samples. This output is optional and only
available for some sampler if ``return_indices=True``.
"""
pass
@@ -243,7 +271,7 @@ def __init__(self, func=None, accept_sparse=True, kw_args=None):
self.kw_args = kw_args
self.logger = logging.getLogger(__name__)

def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc']
if self.accept_sparse else False)
func = _identity if self.func is None else self.func
6 changes: 3 additions & 3 deletions imblearn/combine/_smote_enn.py
Original file line number Diff line number Diff line change
@@ -125,11 +125,11 @@ def _validate_estimator(self):
else:
self.enn_ = EditedNearestNeighbours(sampling_strategy='all')

def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
self._validate_estimator()
y = check_target_type(y)
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
self.sampling_strategy_ = self.sampling_strategy

X_res, y_res = self.smote_.fit_resample(X, y)
return self.enn_.fit_resample(X_res, y_res)
resampled_arrays = self.smote_.fit_resample(X, y, sample_weight)
return self.enn_.fit_resample(*resampled_arrays)
6 changes: 3 additions & 3 deletions imblearn/combine/_smote_tomek.py
Original file line number Diff line number Diff line change
@@ -134,11 +134,11 @@ def _validate_estimator(self):
else:
self.tomek_ = TomekLinks(sampling_strategy='all')

def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
self._validate_estimator()
y = check_target_type(y)
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
self.sampling_strategy_ = self.sampling_strategy

X_res, y_res = self.smote_.fit_resample(X, y)
return self.tomek_.fit_resample(X_res, y_res)
resampled_arrays = self.smote_.fit_resample(X, y, sample_weight)
return self.tomek_.fit_resample(*resampled_arrays)
17 changes: 9 additions & 8 deletions imblearn/ensemble/_bagging.py
Original file line number Diff line number Diff line change
@@ -210,7 +210,7 @@ def __init__(self,
self.ratio = ratio
self.replacement = replacement

def _validate_estimator(self, default=DecisionTreeClassifier()):
def _validate_estimator(self):
"""Check the estimator and the n_estimator attribute, set the
`base_estimator_` attribute."""
if not isinstance(self.n_estimators, (numbers.Integral, np.integer)):
@@ -224,12 +224,14 @@ def _validate_estimator(self, default=DecisionTreeClassifier()):
if self.base_estimator is not None:
base_estimator = clone(self.base_estimator)
else:
base_estimator = clone(default)
base_estimator = clone(DecisionTreeClassifier())

self.base_estimator_ = Pipeline([('sampler', RandomUnderSampler(
sampling_strategy=self.sampling_strategy,
replacement=self.replacement,
ratio=self.ratio)), ('classifier', base_estimator)])
self.base_estimator_ = Pipeline([
('sampler', RandomUnderSampler(
sampling_strategy=self.sampling_strategy,
replacement=self.replacement,
ratio=self.ratio)),
('classifier', base_estimator)])

def fit(self, X, y):
"""Build a Bagging ensemble of estimators from the training
@@ -248,6 +250,5 @@ def fit(self, X, y):
self : object
Returns self.
"""
# RandomUnderSampler is not supporting sample_weight. We need to pass
# None.
# Pipeline does not support sample_weight
return self._fit(X, y, self.max_samples, sample_weight=None)
2 changes: 1 addition & 1 deletion imblearn/ensemble/_balance_cascade.py
Original file line number Diff line number Diff line change
@@ -128,7 +128,7 @@ def _validate_estimator(self):

self.logger.debug(self.estimator_)

def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
self._validate_estimator()

self.sampling_strategy_ = check_sampling_strategy(
2 changes: 1 addition & 1 deletion imblearn/ensemble/_easy_ensemble.py
Original file line number Diff line number Diff line change
@@ -114,7 +114,7 @@ def __init__(self,
self.replacement = replacement
self.n_subsets = n_subsets

def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
random_state = check_random_state(self.random_state)

X_resampled = []
17 changes: 12 additions & 5 deletions imblearn/over_sampling/_adasyn.py
Original file line number Diff line number Diff line change
@@ -106,12 +106,14 @@ def _validate_estimator(self):
'n_neighbors', self.n_neighbors, additional_neighbor=1)
self.nn_.set_params(**{'n_jobs': self.n_jobs})

def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
self._validate_estimator()
random_state = check_random_state(self.random_state)

X_resampled = X.copy()
y_resampled = y.copy()
if sample_weight is not None:
sample_weight_resampled = sample_weight.copy()

for class_sample, n_samples in self.sampling_strategy_.items():
if n_samples == 0:
@@ -165,8 +167,6 @@ def _fit_resample(self, X, y):
X_new = (sparse.csr_matrix(
(samples, (row_indices, col_indices)),
[np.sum(n_samples_generate), X.shape[1]], dtype=X.dtype))
y_new = np.array([class_sample] * np.sum(n_samples_generate),
dtype=y.dtype)
else:
x_class_gen = []
for x_i, x_i_nn, num_sample_i in zip(X_class, nn_index,
@@ -182,13 +182,20 @@ def _fit_resample(self, X, y):
])

X_new = np.concatenate(x_class_gen).astype(X.dtype)
y_new = np.array([class_sample] * np.sum(n_samples_generate),
dtype=y.dtype)

y_new = np.array([class_sample] * np.sum(n_samples_generate),
dtype=y.dtype)
if sample_weight is not None:
sample_weight_resampled = np.hstack(
(sample_weight_resampled,
np.ones_like(y_new, dtype=sample_weight.dtype)))

if sparse.issparse(X_new):
X_resampled = sparse.vstack([X_resampled, X_new])
else:
X_resampled = np.vstack((X_resampled, X_new))
y_resampled = np.hstack((y_resampled, y_new))

if sample_weight is not None:
return X_resampled, y_resampled, sample_weight_resampled
return X_resampled, y_resampled
13 changes: 7 additions & 6 deletions imblearn/over_sampling/_random_over_sampler.py
Original file line number Diff line number Diff line change
@@ -88,7 +88,7 @@ def _check_X_y(X, y):
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'], dtype=None)
return X, y, binarize_y

def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
random_state = check_random_state(self.random_state)
target_stats = Counter(y)

@@ -102,9 +102,10 @@ def _fit_resample(self, X, y):
sample_indices = np.append(sample_indices,
target_class_indices[indices])

resampled_arrays = [safe_indexing(arr, sample_indices)
for arr in (X, y, sample_weight)
if arr is not None]

if self.return_indices:
return (safe_indexing(X, sample_indices), safe_indexing(
y, sample_indices), sample_indices)
else:
return (safe_indexing(X, sample_indices), safe_indexing(
y, sample_indices))
return tuple(resampled_arrays + [sample_indices])
return tuple(resampled_arrays)
61 changes: 49 additions & 12 deletions imblearn/over_sampling/_smote.py
Original file line number Diff line number Diff line change
@@ -123,8 +123,7 @@ def _make_samples(self,
[len(samples_indices), X.shape[1]],
dtype=X.dtype),
y_new)
else:
return X_new, y_new
return X_new, y_new

def _in_danger_noise(self, nn_estimator, samples, target_class, y,
kind='danger'):
@@ -280,14 +279,16 @@ def _validate_estimator(self):
'Got {} instead.'.format(self.kind))

# FIXME: rename _sample -> _fit_resample in 0.6
def _fit_resample(self, X, y):
return self._sample(X, y)
def _fit_resample(self, X, y, sample_weight=None):
return self._sample(X, y, sample_weight)

def _sample(self, X, y):
def _sample(self, X, y, sample_weight=None):
self._validate_estimator()

X_resampled = X.copy()
y_resampled = y.copy()
if sample_weight is not None:
sample_weight_resampled = sample_weight.copy()

for class_sample, n_samples in self.sampling_strategy_.items():
if n_samples == 0:
@@ -317,6 +318,10 @@ def _sample(self, X, y):
else:
X_resampled = np.vstack((X_resampled, X_new))
y_resampled = np.hstack((y_resampled, y_new))
if sample_weight is not None:
sample_weight_resampled = np.hstack(
(sample_weight_resampled,
np.ones_like(y_new, dtype=sample_weight.dtype)))

elif self.kind == 'borderline-2':
random_state = check_random_state(self.random_state)
@@ -350,7 +355,14 @@ def _sample(self, X, y):
else:
X_resampled = np.vstack((X_resampled, X_new_1, X_new_2))
y_resampled = np.hstack((y_resampled, y_new_1, y_new_2))

if sample_weight is not None:
sample_weight_resampled = np.hstack(
(sample_weight_resampled,
np.ones_like(y_new_1, dtype=sample_weight.dtype),
np.ones_like(y_new_2, dtype=sample_weight.dtype)))

if sample_weight is not None:
return X_resampled, y_resampled, sample_weight_resampled
return X_resampled, y_resampled


@@ -466,14 +478,16 @@ def _validate_estimator(self):
self.svm_estimator)

# FIXME: rename _sample -> _fit_resample in 0.6
def _fit_resample(self, X, y):
return self._sample(X, y)
def _fit_resample(self, X, y, sample_weight=None):
return self._sample(X, y, sample_weight)

def _sample(self, X, y):
def _sample(self, X, y, sample_weight=None):
self._validate_estimator()
random_state = check_random_state(self.random_state)
X_resampled = X.copy()
y_resampled = y.copy()
if sample_weight is not None:
sample_weight_resampled = sample_weight.copy()

for class_sample, n_samples in self.sampling_strategy_.items():
if n_samples == 0:
@@ -535,19 +549,34 @@ def _sample(self, X, y):
X_resampled = np.vstack((X_resampled, X_new_1, X_new_2))
y_resampled = np.concatenate(
(y_resampled, y_new_1, y_new_2), axis=0)
if sample_weight is not None:
sample_weight_resampled = np.hstack(
(sample_weight_resampled,
np.ones_like(y_new_1, dtype=sample_weight.dtype),
np.ones_like(y_new_2, dtype=sample_weight.dtype)))
elif np.count_nonzero(danger_bool) == 0:
if sparse.issparse(X_resampled):
X_resampled = sparse.vstack([X_resampled, X_new_2])
else:
X_resampled = np.vstack((X_resampled, X_new_2))
y_resampled = np.concatenate((y_resampled, y_new_2), axis=0)
if sample_weight is not None:
sample_weight_resampled = np.hstack(
(sample_weight_resampled,
np.ones_like(y_new_2, dtype=sample_weight.dtype)))
elif np.count_nonzero(safety_bool) == 0:
if sparse.issparse(X_resampled):
X_resampled = sparse.vstack([X_resampled, X_new_1])
else:
X_resampled = np.vstack((X_resampled, X_new_1))
y_resampled = np.concatenate((y_resampled, y_new_1), axis=0)
if sample_weight is not None:
sample_weight_resampled = np.hstack(
(sample_weight_resampled,
np.ones_like(y_new_1, dtype=sample_weight.dtype)))

if sample_weight is not None:
return X_resampled, y_resampled, sample_weight_resampled
return X_resampled, y_resampled


@@ -735,16 +764,18 @@ def _validate_estimator(self):
self.nn_m_.set_params(**{'n_jobs': self.n_jobs})

# FIXME: to be removed in 0.6
def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
self._validate_estimator()
return self._sample(X, y)
return self._sample(X, y, sample_weight)

def _sample(self, X, y):
def _sample(self, X, y, sample_weight=None):
# FIXME: uncomment in version 0.6
# self._validate_estimator()

X_resampled = X.copy()
y_resampled = y.copy()
if sample_weight is not None:
sample_weight_resampled = sample_weight.copy()

for class_sample, n_samples in self.sampling_strategy_.items():
if n_samples == 0:
@@ -762,5 +793,11 @@ def _sample(self, X, y):
else:
X_resampled = np.vstack((X_resampled, X_new))
y_resampled = np.hstack((y_resampled, y_new))
if sample_weight is not None:
sample_weight_resampled = np.hstack(
(sample_weight_resampled,
np.ones_like(y_new, dtype=sample_weight.dtype)))

if sample_weight is not None:
return X_resampled, y_resampled, sample_weight_resampled
return X_resampled, y_resampled
1 change: 0 additions & 1 deletion imblearn/pipeline.py
Original file line number Diff line number Diff line change
@@ -562,7 +562,6 @@ def _fit_transform_one(transformer, weight, X, y, **fit_params):

def _fit_resample_one(sampler, X, y, **fit_params):
    """Resample ``X`` and ``y`` with ``sampler`` and hand the sampler back.

    Parameters
    ----------
    sampler : object
        Object exposing ``fit_resample(X, y, **fit_params)``.
    X : array-like
        Data forwarded to ``sampler.fit_resample``.
    y : array-like
        Targets forwarded to ``sampler.fit_resample``.
    **fit_params : dict
        Extra keyword arguments forwarded to ``sampler.fit_resample``.

    Returns
    -------
    X_res, y_res
        The first two outputs of ``sampler.fit_resample``.
    sampler : object
        The sampler that was passed in.
    """
    resampled = sampler.fit_resample(X, y, **fit_params)
    # ``fit_resample`` may return more than two outputs (resampled
    # ``sample_weight`` and/or indices). Unpacking into exactly two names
    # would raise ``ValueError`` in that case, so only the first two outputs
    # are kept; any extra output is discarded here.
    X_res, y_res = resampled[0], resampled[1]
    return X_res, y_res, sampler


52 changes: 34 additions & 18 deletions imblearn/under_sampling/_prototype_generation/_cluster_centroids.py
Original file line number Diff line number Diff line change
@@ -119,30 +119,35 @@ def _validate_estimator(self):
raise ValueError('`estimator` has to be a KMeans clustering.'
' Got {} instead.'.format(type(self.estimator)))

def _generate_sample(self, X, y, centroids, target_class):
def _generate_sample(self, X, y, sample_weight, centroids, target_class):
    """Build the samples that stand for ``target_class`` from ``centroids``.

    With ``self.voting_ == 'hard'`` each centroid is replaced by its nearest
    sample in ``X`` (and that sample's weight). Otherwise the centroids
    themselves are used (converted to CSR when ``X`` is sparse) and every
    generated sample gets a weight of one.

    Parameters
    ----------
    X : {ndarray, sparse matrix}
        Data the centroids were computed from.
    y : ndarray
        Targets; only its dtype is used to build ``y_new``.
    sample_weight : ndarray or None
        Weights aligned with ``X``; ``None`` disables the weight output.
    centroids : ndarray, shape (n_centroids, n_features)
        Cluster centers for ``target_class``.
    target_class : object
        Label assigned to every generated sample.

    Returns
    -------
    X_new, y_new[, sample_weight_new]
        ``sample_weight_new`` is returned only if ``sample_weight`` was not
        ``None``.
    """
    n_centroids = centroids.shape[0]
    sample_weight_new = None

    if self.voting_ == 'hard':
        nearest_neighbors = NearestNeighbors(n_neighbors=1)
        nearest_neighbors.fit(X, y)
        # One neighbour per centroid. ``ravel`` is used instead of
        # ``np.squeeze`` because squeezing the result for a single centroid
        # yields a 0-d index and the selected arrays would lose their
        # sample axis.
        indices = nearest_neighbors.kneighbors(
            centroids, return_distance=False).ravel()
        X_new = safe_indexing(X, indices)
        if sample_weight is not None:
            sample_weight_new = safe_indexing(sample_weight, indices)
    else:
        if sparse.issparse(X):
            X_new = sparse.csr_matrix(centroids, dtype=X.dtype)
        else:
            X_new = centroids
        if sample_weight is not None:
            sample_weight_new = np.ones(n_centroids,
                                        dtype=sample_weight.dtype)

    y_new = np.array([target_class] * n_centroids, dtype=y.dtype)

    if sample_weight is not None:
        return X_new, y_new, sample_weight_new
    return X_new, y_new

def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
self._validate_estimator()

if self.voting == 'auto':
if sparse.issparse(X):
self.voting_ = 'hard'
else:
self.voting_ = 'soft'
self.voting_ = 'hard' if sparse.issparse(X) else 'soft'
else:
if self.voting in VOTING_KIND:
self.voting_ = self.voting
@@ -151,24 +156,35 @@ def _fit_resample(self, X, y):
" instead.".format(VOTING_KIND, self.voting))

X_resampled, y_resampled = [], []
if sample_weight is not None:
sample_weight_resampled = []
for target_class in np.unique(y):
if target_class in self.sampling_strategy_.keys():
n_samples = self.sampling_strategy_[target_class]
self.estimator_.set_params(**{'n_clusters': n_samples})
self.estimator_.fit(X[y == target_class])
X_new, y_new = self._generate_sample(
X, y, self.estimator_.cluster_centers_, target_class)
X_resampled.append(X_new)
y_resampled.append(y_new)
new_arrays = self._generate_sample(
X, y, sample_weight, self.estimator_.cluster_centers_,
target_class)
X_resampled.append(new_arrays[0])
y_resampled.append(new_arrays[1])
if sample_weight is not None:
sample_weight_resampled.append(new_arrays[2])
else:
target_class_indices = np.flatnonzero(y == target_class)
X_resampled.append(safe_indexing(X, target_class_indices))
y_resampled.append(safe_indexing(y, target_class_indices))

if sparse.issparse(X):
X_resampled = sparse.vstack(X_resampled)
else:
X_resampled = np.vstack(X_resampled)
y_resampled = np.hstack(y_resampled)

return X_resampled, np.array(y_resampled, dtype=y.dtype)
if sample_weight is not None:
sample_weight_resampled.append(
safe_indexing(sample_weight, target_class_indices))

X_resampled = (sparse.vstack(X_resampled)
if sparse.issparse(X) else np.vstack(X_resampled))
y_resampled = np.array(np.hstack(y_resampled), dtype=y.dtype)
if sample_weight is not None:
sample_weight_resampled = np.array(
np.hstack(sample_weight_resampled), dtype=sample_weight.dtype)

if sample_weight is not None:
return X_resampled, y_resampled, sample_weight_resampled
return X_resampled, y_resampled
Original file line number Diff line number Diff line change
@@ -128,7 +128,7 @@ def _validate_estimator(self):
' inhereited from KNeighborsClassifier.'
' Got {} instead.'.format(type(self.n_neighbors)))

def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
self._validate_estimator()

random_state = check_random_state(self.random_state)
@@ -201,8 +201,10 @@ def _fit_resample(self, X, y):
idx_under = np.concatenate(
(idx_under, np.flatnonzero(y == target_class)), axis=0)

resampled_arrays = [safe_indexing(arr, idx_under)
for arr in (X, y, sample_weight)
if arr is not None]

if self.return_indices:
return (safe_indexing(X, idx_under), safe_indexing(y, idx_under),
idx_under)
else:
return safe_indexing(X, idx_under), safe_indexing(y, idx_under)
return tuple(resampled_arrays + [idx_under])
return tuple(resampled_arrays)
Original file line number Diff line number Diff line change
@@ -138,7 +138,7 @@ def _validate_estimator(self):
if self.kind_sel not in SEL_KIND:
raise NotImplementedError

def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
self._validate_estimator()

idx_under = np.empty((0, ), dtype=int)
@@ -168,11 +168,13 @@ def _fit_resample(self, X, y):
np.flatnonzero(y == target_class)[index_target_class]),
axis=0)

resampled_arrays = [safe_indexing(arr, idx_under)
for arr in (X, y, sample_weight)
if arr is not None]

if self.return_indices:
return (safe_indexing(X, idx_under), safe_indexing(y, idx_under),
idx_under)
else:
return safe_indexing(X, idx_under), safe_indexing(y, idx_under)
return tuple(resampled_arrays + [idx_under])
return tuple(resampled_arrays)


@Substitution(
@@ -303,22 +305,35 @@ def _validate_estimator(self):
n_jobs=self.n_jobs,
ratio=self.ratio)

def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
self._validate_estimator()

X_, y_ = X, y
X_, y_, sample_weight_ = X, y, sample_weight
if self.return_indices:
idx_under = np.arange(X.shape[0], dtype=int)
target_stats = Counter(y)
class_minority = min(target_stats, key=target_stats.get)

for n_iter in range(self.max_iter):
for _ in range(self.max_iter):

prev_len = y_.shape[0]
if self.return_indices:
X_enn, y_enn, idx_enn = self.enn_.fit_resample(X_, y_)
resampled_data = self.enn_.fit_resample(X_, y_, sample_weight_)
else:
resampled_data = self.enn_.fit_resample(X_, y_, sample_weight_)

# unpacking data
if len(resampled_data) == 2:
X_enn, y_enn = resampled_data
sample_weight_enn = None
elif len(resampled_data) == 3:
if sample_weight_ is not None:
X_enn, y_enn, sample_weight_enn = resampled_data
else:
X_enn, y_enn, idx_enn = resampled_data
sample_weight_enn = None
else:
X_enn, y_enn = self.enn_.fit_resample(X_, y_)
X_enn, y_enn, sample_weight_enn, idx_enn = resampled_data

# Check the stopping criterion
# 1. If there is no changes for the vector y
@@ -341,25 +356,24 @@ def _fit_resample(self, X, y):
# Case 3
b_remove_maj_class = (len(stats_enn) < len(target_stats))

X_, y_, = X_enn, y_enn
X_, y_, sample_weight_ = X_enn, y_enn, sample_weight_enn

if self.return_indices:
idx_under = idx_under[idx_enn]

if b_conv or b_min_bec_maj or b_remove_maj_class:
if b_conv:
X_, y_, sample_weight_ = X_enn, y_enn, sample_weight_enn
if self.return_indices:
X_, y_, = X_enn, y_enn
idx_under = idx_under[idx_enn]
else:
X_, y_, = X_enn, y_enn
break

X_resampled, y_resampled = X_, y_
resampled_arrays = [arr for arr in (X_, y_, sample_weight_)
if arr is not None]

if self.return_indices:
return X_resampled, y_resampled, idx_under
else:
return X_resampled, y_resampled
return tuple(resampled_arrays + [idx_under])
return tuple(resampled_arrays)


@Substitution(
@@ -489,10 +503,10 @@ def _validate_estimator(self):
n_jobs=self.n_jobs,
ratio=self.ratio)

def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
self._validate_estimator()

X_, y_ = X, y
X_, y_, sample_weight_ = X, y, sample_weight
target_stats = Counter(y)
class_minority = min(target_stats, key=target_stats.get)

@@ -503,9 +517,22 @@ def _fit_resample(self, X, y):
self.enn_.n_neighbors = curr_size_ngh

if self.return_indices:
X_enn, y_enn, idx_enn = self.enn_.fit_resample(X_, y_)
resampled_data = self.enn_.fit_resample(X_, y_, sample_weight_)
else:
resampled_data = self.enn_.fit_resample(X_, y_, sample_weight_)

# unpacking data
if len(resampled_data) == 2:
X_enn, y_enn = resampled_data
sample_weight_enn = None
elif len(resampled_data) == 3:
if sample_weight_ is not None:
X_enn, y_enn, sample_weight_enn = resampled_data
else:
X_enn, y_enn, idx_enn = resampled_data
sample_weight_enn = None
else:
X_enn, y_enn = self.enn_.fit_resample(X_, y_)
X_enn, y_enn, sample_weight_enn, idx_enn = resampled_data

# Check the stopping criterion
# 1. If the number of samples in the other class become inferior to
@@ -526,16 +553,16 @@ def _fit_resample(self, X, y):
# Case 2
b_remove_maj_class = (len(stats_enn) < len(target_stats))

X_, y_, = X_enn, y_enn
X_, y_, sample_weight_ = X_enn, y_enn, sample_weight_enn
if self.return_indices:
idx_under = idx_under[idx_enn]

if b_min_bec_maj or b_remove_maj_class:
break

X_resampled, y_resampled = X_, y_
resampled_arrays = [arr for arr in (X_, y_, sample_weight_)
if arr is not None]

if self.return_indices:
return X_resampled, y_resampled, idx_under
else:
return X_resampled, y_resampled
return tuple(resampled_arrays + [idx_under])
return tuple(resampled_arrays)
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.utils import safe_indexing
from sklearn.utils.fixes import signature

from ..base import BaseCleaningSampler
from ...utils import Substitution
@@ -125,7 +126,7 @@ def _validate_estimator(self):
raise ValueError('Invalid parameter `estimator`. Got {}.'.format(
type(self.estimator)))

def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
self._validate_estimator()

target_stats = Counter(y)
@@ -134,13 +135,23 @@ def _fit_resample(self, X, y):
random_state=self.random_state).split(X, y)
probabilities = np.zeros(y.shape[0], dtype=float)

support_sample_weight = "sample_weight" in signature(
self.estimator_.fit).parameters

for train_index, test_index in skf:
X_train = safe_indexing(X, train_index)
X_test = safe_indexing(X, test_index)
y_train = safe_indexing(y, train_index)
y_test = safe_indexing(y, test_index)
if sample_weight is not None:
sample_weight_train = safe_indexing(sample_weight, train_index)
else:
sample_weight_train = None

self.estimator_.fit(X_train, y_train)
if support_sample_weight:
self.estimator_.fit(X_train, y_train, sample_weight_train)
else:
self.estimator_.fit(X_train, y_train)

probs = self.estimator_.predict_proba(X_test)
classes = self.estimator_.classes_
@@ -167,8 +178,10 @@ def _fit_resample(self, X, y):
np.flatnonzero(y == target_class)[index_target_class]),
axis=0)

resampled_arrays = [safe_indexing(arr, idx_under)
for arr in (X, y, sample_weight)
if arr is not None]

if self.return_indices:
return (safe_indexing(X, idx_under), safe_indexing(y, idx_under),
idx_under)
else:
return safe_indexing(X, idx_under), safe_indexing(y, idx_under)
return tuple(resampled_arrays + [idx_under])
return tuple(resampled_arrays)
12 changes: 7 additions & 5 deletions imblearn/under_sampling/_prototype_selection/_nearmiss.py
Original file line number Diff line number Diff line change
@@ -211,7 +211,7 @@ def _validate_estimator(self):
raise ValueError('Parameter `version` must be 1, 2 or 3, got'
' {}'.format(self.version))

def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
self._validate_estimator()

idx_under = np.empty((0, ), dtype=int)
@@ -277,8 +277,10 @@ def _fit_resample(self, X, y):
np.flatnonzero(y == target_class)[index_target_class]),
axis=0)

resampled_arrays = [safe_indexing(arr, idx_under)
for arr in (X, y, sample_weight)
if arr is not None]

if self.return_indices:
return (safe_indexing(X, idx_under), safe_indexing(y, idx_under),
idx_under)
else:
return safe_indexing(X, idx_under), safe_indexing(y, idx_under)
return resampled_arrays + [idx_under]
return resampled_arrays
Original file line number Diff line number Diff line change
@@ -139,7 +139,7 @@ def _validate_estimator(self):
"'threshold_cleaning' is a value between 0 and 1."
" Got {} instead.".format(self.threshold_cleaning))

def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
self._validate_estimator()
enn = EditedNearestNeighbours(
sampling_strategy=self.sampling_strategy,
@@ -186,9 +186,10 @@ def _fit_resample(self, X, y):
selected_samples[union_a1_a2] = False
index_target_class = np.flatnonzero(selected_samples)

resampled_arrays = [safe_indexing(arr, index_target_class)
for arr in (X, y, sample_weight)
if arr is not None]

if self.return_indices:
return (safe_indexing(X, index_target_class), safe_indexing(
y, index_target_class), index_target_class)
else:
return (safe_indexing(X, index_target_class), safe_indexing(
y, index_target_class))
return resampled_arrays + [index_target_class]
return resampled_arrays
Original file line number Diff line number Diff line change
@@ -122,7 +122,7 @@ def _validate_estimator(self):
' inhereited from KNeighborsClassifier.'
' Got {} instead.'.format(type(self.n_neighbors)))

def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
self._validate_estimator()

random_state = check_random_state(self.random_state)
@@ -164,17 +164,27 @@ def _fit_resample(self, X, y):
idx_under = np.concatenate(
(idx_under, np.flatnonzero(y == target_class)), axis=0)

X_resampled = safe_indexing(X, idx_under)
y_resampled = safe_indexing(y, idx_under)
X_res = safe_indexing(X, idx_under)
y_res = safe_indexing(y, idx_under)
sample_weight_res = (safe_indexing(sample_weight, idx_under)
if sample_weight is not None else None)

# apply Tomek cleaning
tl = TomekLinks(
sampling_strategy=self.sampling_strategy_, return_indices=True)
X_cleaned, y_cleaned, idx_cleaned = tl.fit_resample(
X_resampled, y_resampled)
resampled_arrays = tl.fit_resample(X_res, y_res, sample_weight_res)
if sample_weight_res is not None:
X_res, y_res, sample_weight_res, idx_cleaned = resampled_arrays
else:
X_res, y_res, idx_cleaned = resampled_arrays

idx_under = safe_indexing(idx_under, idx_cleaned)
""" sample_weight_res = (safe_indexing(sample_weight_res, idx_cleaned)
if sample_weight_res is not None else None) """

resampled_arrays = [arr for arr in (X_res, y_res, sample_weight_res)
if arr is not None]

if self.return_indices:
return (X_cleaned, y_cleaned, idx_under)
else:
return X_cleaned, y_cleaned
return tuple(resampled_arrays + [idx_under])
return tuple(resampled_arrays)
Original file line number Diff line number Diff line change
@@ -6,6 +6,8 @@

from __future__ import division

from itertools import chain

import numpy as np

from sklearn.utils import check_X_y, check_random_state, safe_indexing
@@ -92,7 +94,7 @@ def _check_X_y(X, y):
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'], dtype=None)
return X, y, binarize_y

def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
random_state = check_random_state(self.random_state)

idx_under = np.empty((0, ), dtype=int)
@@ -112,8 +114,10 @@ def _fit_resample(self, X, y):
np.flatnonzero(y == target_class)[index_target_class]),
axis=0)

resampled_arrays = [safe_indexing(arr, idx_under)
for arr in (X, y, sample_weight)
if arr is not None]

if self.return_indices:
return (safe_indexing(X, idx_under), safe_indexing(y, idx_under),
idx_under)
else:
return safe_indexing(X, idx_under), safe_indexing(y, idx_under)
return tuple(resampled_arrays + [idx_under])
return tuple(resampled_arrays)
12 changes: 7 additions & 5 deletions imblearn/under_sampling/_prototype_selection/_tomek_links.py
Original file line number Diff line number Diff line change
@@ -134,7 +134,7 @@ def is_tomek(y, nn_index, class_type):

return links

def _fit_resample(self, X, y):
def _fit_resample(self, X, y, sample_weight=None):
# check for deprecated random_state
if self.random_state is not None:
deprecate_parameter(self, '0.4', 'random_state')
@@ -147,8 +147,10 @@ def _fit_resample(self, X, y):
links = self.is_tomek(y, nns, self.sampling_strategy_)
idx_under = np.flatnonzero(np.logical_not(links))

resampled_arrays = [safe_indexing(arr, idx_under)
for arr in (X, y, sample_weight)
if arr is not None]

if self.return_indices:
return (safe_indexing(X, idx_under), safe_indexing(y, idx_under),
idx_under)
else:
return (safe_indexing(X, idx_under), safe_indexing(y, idx_under))
return tuple(resampled_arrays + [idx_under])
return tuple(resampled_arrays)
24 changes: 22 additions & 2 deletions imblearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@
from sklearn.datasets import make_classification
from sklearn.cluster import KMeans
from sklearn.preprocessing import label_binarize
from sklearn.utils import check_consistent_length
from sklearn.utils.estimator_checks import check_estimator \
as sklearn_check_estimator, check_parameters_default_constructible
from sklearn.utils.testing import assert_allclose
@@ -36,6 +37,8 @@
from imblearn.utils.testing import warns

DONT_SUPPORT_RATIO = ['SVMSMOTE', 'BorderlineSMOTE']
DONT_SUPPORT_SAMPLE_WEIGHT = ['EasyEnsemble', 'BalanceCascade',
'FunctionSampler']
SUPPORT_STRING = ['RandomUnderSampler', 'RandomOverSampler']


@@ -73,6 +76,7 @@ def _yield_sampler_checks(name, Estimator):
yield check_samplers_pandas
yield check_samplers_multiclass_ova
yield check_samplers_preserve_dtype
yield check_samplers_resample_sample_weight


def _yield_all_checks(name, estimator):
@@ -298,7 +302,7 @@ def check_samplers_pandas(name, Sampler):
n_informative=4,
weights=[0.2, 0.3, 0.5],
random_state=0)
X_pd, y_pd = pd.DataFrame(X), pd.Series(y)
X_pd = pd.DataFrame(X)
sampler = Sampler()
if isinstance(Sampler(), SMOTE):
samplers = [
@@ -314,7 +318,7 @@ def check_samplers_pandas(name, Sampler):

for sampler in samplers:
set_random_state(sampler)
X_res_pd, y_res_pd = sampler.fit_resample(X_pd, y_pd)
X_res_pd, y_res_pd = sampler.fit_resample(X_pd, y)
X_res, y_res = sampler.fit_resample(X, y)
assert_allclose(X_res_pd, X_res)
assert_allclose(y_res_pd, y_res)
@@ -358,3 +362,19 @@ def check_samplers_preserve_dtype(name, Sampler):
X_res, y_res = sampler.fit_resample(X, y)
assert X.dtype == X_res.dtype
assert y.dtype == y_res.dtype


def check_samplers_resample_sample_weight(name, Sampler):
    """Check that X, y and an additional ``sample_weight`` can be resampled.

    Samplers whose ``name`` is listed in ``DONT_SUPPORT_SAMPLE_WEIGHT`` are
    skipped.

    Parameters
    ----------
    name : str
        Name of the sampler class, compared against the skip list.
    Sampler : class
        Sampler class; it is instantiated with its default parameters.
    """
    if name in DONT_SUPPORT_SAMPLE_WEIGHT:
        return

    X, y = make_classification(
        n_samples=1000,
        n_classes=3,
        n_informative=4,
        weights=[0.2, 0.3, 0.5],
        random_state=0)
    sample_weight = np.ones_like(y)
    sampler = Sampler()
    set_random_state(sampler)
    X_res, y_res, sw_res = sampler.fit_resample(X, y, sample_weight)
    # Called directly rather than inside an ``assert``: the helper reports a
    # length mismatch by raising, and an ``assert`` statement (including the
    # call it wraps) is removed when Python runs with ``-O``.
    check_consistent_length(X_res, y_res, sw_res)