From 170b493e1f77d2d52765d34fbb728d868c261694 Mon Sep 17 00:00:00 2001 From: Anmol Srivastava Date: Mon, 31 Jul 2023 10:45:05 -0400 Subject: [PATCH] Fix `LabelProbabilityInjector` data type bug (#148) * fix injector to use preprocessed data * remove extra cls_idx definition --- menelaus/injection/label_manipulation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/menelaus/injection/label_manipulation.py b/menelaus/injection/label_manipulation.py index 1de5e0fd..beb870d3 100644 --- a/menelaus/injection/label_manipulation.py +++ b/menelaus/injection/label_manipulation.py @@ -158,7 +158,8 @@ class probability for 1 or more desired classes # locate each class in window for cls in all_classes: - cls_idx = np.where(data[:, target_col] == cls)[0] + + cls_idx = np.where(ret[:, target_col] == cls)[0] cls_idx = cls_idx[(cls_idx < to_index) & (cls_idx >= from_index)] # each member has p_class / class_size chance, represented as bool to avoid div/0 @@ -178,7 +179,7 @@ class probability for 1 or more desired classes sample_idxs = np.random.choice( sample_idxs_grouped, to_index - from_index, True, self._p_distribution ) - ret[from_index:to_index] = data[sample_idxs] + ret[from_index:to_index] = ret[sample_idxs] # handle data type and return ret = self._postprocess(ret)