created todo comment for stratification fix #29

Merged
merged 3 commits into from
Aug 27, 2024

lshpaner
Collaborator

Description:

Currently, the train_val_test_split method allows for stratification either by y (stratify_y) or by specified columns (stratify_cols), but not both at the same time. There are use cases where stratification by both the target variable (y) and specific columns is necessary to ensure a balanced and representative split across different data segments.

Proposed Enhancement:

Modify the method to support simultaneous stratification by both y and stratify_cols. This can be achieved by combining the stratification keys or implementing logic that ensures both y and the specified columns are considered during the stratification process.
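A minimal sketch of the "combining the stratification keys" option, assuming the combined key is built by concatenating y onto the stratify_cols subset of X. The helper name build_stratify_key and the synthetic data are hypothetical and only for illustration; this is not the code merged in this pull request.

```python
from itertools import product

import pandas as pd
from sklearn.model_selection import train_test_split

# Synthetic, perfectly balanced data (an assumption for illustration):
# every (gender, race, cid) combination appears exactly 50 times.
rows = list(product([0, 1], repeat=3)) * 50
df = pd.DataFrame(rows, columns=["gender", "race", "cid"])
df["feature"] = range(len(df))

X = df[["gender", "race", "feature"]]
y = df["cid"]
stratify_cols = ["gender", "race"]
stratify_y = True
train_size, validation_size, test_size = 0.6, 0.2, 0.2


def build_stratify_key(X, y, stratify_cols, stratify_y):
    """Hypothetical helper: pick the stratification key for one split."""
    if stratify_cols and stratify_y:
        # Both requested: stratify on the joint columns plus the target.
        return pd.concat([X[stratify_cols], y], axis=1)
    if stratify_cols:
        return X[stratify_cols]
    if stratify_y:
        return y
    return None


# First split: train vs. (validation + test).
X_train, X_valid_test, y_train, y_valid_test = train_test_split(
    X,
    y,
    test_size=1 - train_size,
    stratify=build_stratify_key(X, y, stratify_cols, stratify_y),
    random_state=3,
)

# Second split: rebuild the key from the remaining rows only.
proportion = test_size / (validation_size + test_size)
X_valid, X_test, y_valid, y_test = train_test_split(
    X_valid_test,
    y_valid_test,
    test_size=proportion,
    stratify=build_stratify_key(
        X_valid_test, y_valid_test, stratify_cols, stratify_y
    ),
    random_state=3,
)

print(len(X_train), len(X_valid), len(X_test))
```

Each distinct row of a multi-column key is treated as its own stratum, so every combination of values needs at least two members in each split; rare combinations can make the split fail.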

Current Method Implementation:

def train_val_test_split(
    self,
    X,
    y,
    stratify_y,
    train_size,
    validation_size,
    test_size,
    random_state,
    stratify_cols,
    calibrate,
):

    # if calibrate:
    #     X = X.join(self.dropped_strat_cols)
    # Determine the stratify parameter based on stratify and stratify_cols
    if stratify_cols:
        # Creating stratification columns out of stratify_cols list
        stratify_key = X[stratify_cols]
    elif stratify_y:
        stratify_key = y
    else:
        stratify_key = None

    if self.drop_strat_feat:
        self.dropped_strat_cols = X[self.drop_strat_feat]
        X = X.drop(columns=self.drop_strat_feat)

    X_train, X_valid_test, y_train, y_valid_test = train_test_split(
        X,
        y,
        test_size=1 - train_size,
        stratify=stratify_key,  # Use stratify_key here
        random_state=random_state,
    )

    # Determine the proportion of validation to test size in the remaining dataset
    proportion = test_size / (validation_size + test_size)

    if stratify_cols:
        strat_key_val_test = X_valid_test[stratify_cols]
    elif stratify_y:
        strat_key_val_test = y_valid_test
    else:
        strat_key_val_test = None

    # Further split (validation + test) set into validation and test sets
    X_valid, X_test, y_valid, y_test = train_test_split(
        X_valid_test,
        y_valid_test,
        test_size=proportion,
        stratify=strat_key_val_test,
        random_state=random_state,
    )

    return X_train, X_valid, X_test, y_train, y_valid, y_test

Collaborator Author

@lshpaner lshpaner left a comment

@elemets @panas89 I went ahead and debugged these changes using the AIDS research example code notebook by doing the following:

  • passing stratify_cols=["gender", "race"],
  • passing stratify_y=True,
  • adding a print statement at the end of each of the two if stratify_cols and stratify_y: blocks:
    • first: print(stratify_key)
    • second: print(strat_key_val_test)

The run raised no errors, exceptions, or warnings and produced the following output:

      gender  race  cid
0          0     0    0
1          0     0    1
2          1     0    0
3          1     0    0
4          1     0    0
...      ...   ...  ...
2134       1     0    0
2135       1     1    0
2136       1     1    0
2137       1     0    1
2138       1     0    0

[2139 rows x 3 columns]
      gender  race  cid
1943       0     1    0
1583       1     0    0
1891       1     0    0
1316       1     0    1
1117       1     0    1
...      ...   ...  ...
6          1     0    1
1135       1     1    0
125        1     0    0
359        1     0    1
492        1     0    0

[856 rows x 3 columns]
100%|██████████| 324/324 [01:20<00:00,  4.02it/s]
Best score/param set found on validation set:
{'params': {'selectKBest__k': 6,
            'xgb__colsample_bytree': 1.0,
            'xgb__early_stopping_rounds': 10,
            'xgb__eval_metric': 'logloss',
            'xgb__learning_rate': 0.01,
            'xgb__max_depth': 7,
            'xgb__n_estimators': 99,
            'xgb__subsample': 1.0},
 'score': 0.9262226970560304}
Best roc_auc: 0.926 

The printed output confirms that the target column ("cid", included because stratify_y=True) is concatenated to the stratify_cols subset of X ("gender", "race"), so the split is now stratified on both the target and the specified columns.
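Beyond printing the key, one way to check the effect is to compare the share of each stratum across the resulting splits. The sketch below is illustrative only: strata_proportions is a hypothetical helper, and the balanced synthetic key stands in for the notebook's data.

```python
from itertools import product

import pandas as pd
from sklearn.model_selection import train_test_split


def strata_proportions(key: pd.DataFrame) -> pd.Series:
    """Share of rows in each combination of the key columns."""
    return key.value_counts(normalize=True).sort_index()


# Synthetic combined key (assumption): 8 strata with 25 rows each.
rows = list(product([0, 1], repeat=3)) * 25
key = pd.DataFrame(rows, columns=["gender", "race", "cid"])

# Stratify the split on the combined key itself.
key_train, key_rest = train_test_split(
    key, test_size=0.4, stratify=key, random_state=3
)

# Largest difference in any stratum's share between the two splits.
max_gap = (
    strata_proportions(key_train) - strata_proportions(key_rest)
).abs().max()
print(max_gap)
```

On real data the gap will generally be small rather than exactly zero, because stratum counts rarely divide evenly between the splits.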

@@ -941,8 +946,14 @@ def train_val_test_split(
# if calibrate:
Collaborator

remove unnecessary comments and TODO

@panas89
Collaborator

panas89 commented Aug 27, 2024

Line 1207, in get_cross_validate(): the variable stratify is redundant.

@panas89
Collaborator

panas89 commented Aug 27, 2024

We need to make a note in the documentation that stratify_cols cannot be used when using cross_validation

@panas89
Collaborator

panas89 commented Aug 27, 2024

Checked the code changes with the debugger. Works!

@lshpaner lshpaner merged commit f9f7cff into main Aug 27, 2024
@lshpaner lshpaner deleted the stratify_fix branch August 27, 2024 18:40
@lshpaner
Collaborator Author

We need to make a note in the documentation that stratify_cols cannot be used when using cross_validation

done
