diff --git a/menelaus/injection/label_manipulation.py b/menelaus/injection/label_manipulation.py index 1de5e0fd..beb870d3 100644 --- a/menelaus/injection/label_manipulation.py +++ b/menelaus/injection/label_manipulation.py @@ -158,7 +158,8 @@ class probability for 1 or more desired classes # locate each class in window for cls in all_classes: - cls_idx = np.where(data[:, target_col] == cls)[0] + + cls_idx = np.where(ret[:, target_col] == cls)[0] cls_idx = cls_idx[(cls_idx < to_index) & (cls_idx >= from_index)] # each member has p_class / class_size chance, represented as bool to avoid div/0 @@ -178,7 +179,7 @@ class probability for 1 or more desired classes sample_idxs = np.random.choice( sample_idxs_grouped, to_index - from_index, True, self._p_distribution ) - ret[from_index:to_index] = data[sample_idxs] + ret[from_index:to_index] = ret[sample_idxs] # handle data type and return ret = self._postprocess(ret)