Skip to content

Commit

Permalink
Fix LabelProbabilityInjector data type bug (#148)
Browse files Browse the repository at this point in the history
* fix injector to use preprocessed data

* remove extra cls_idx definition
  • Loading branch information
Anmol-Srivastava committed Jul 31, 2023
1 parent ea015f1 commit 170b493
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions menelaus/injection/label_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ class probability for 1 or more desired classes

# locate each class in window
for cls in all_classes:
cls_idx = np.where(data[:, target_col] == cls)[0]

cls_idx = np.where(ret[:, target_col] == cls)[0]
cls_idx = cls_idx[(cls_idx < to_index) & (cls_idx >= from_index)]

# each member has p_class / class_size chance, represented as bool to avoid div/0
Expand All @@ -178,7 +179,7 @@ class probability for 1 or more desired classes
sample_idxs = np.random.choice(
sample_idxs_grouped, to_index - from_index, True, self._p_distribution
)
ret[from_index:to_index] = data[sample_idxs]
ret[from_index:to_index] = ret[sample_idxs]

# handle data type and return
ret = self._postprocess(ret)
Expand Down

0 comments on commit 170b493

Please sign in to comment.