|
| 1 | +""" Helper functions for training image classification models with AutoGluon and using cross-validation. """ |
| 2 | + |
| 3 | +import sys |
| 4 | + |
| 5 | +sys.path.insert(0, "../") |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import pandas as pd |
| 9 | +import pickle |
| 10 | +import datetime |
| 11 | +import os |
| 12 | +from pathlib import Path |
| 13 | +from typing import Dict, Tuple |
| 14 | + |
| 15 | +import cleanlab |
| 16 | +from autogluon.vision import ImagePredictor, ImageDataset |
| 17 | +from sklearn.model_selection import StratifiedKFold |
| 18 | + |
| 19 | + |
| 20 | +def cross_val_predict_autogluon_image_dataset( |
| 21 | + dataset: ImageDataset, |
| 22 | + out_folder: str = "./cross_val_predict_run/", |
| 23 | + *, |
| 24 | + n_splits: int = 5, |
| 25 | + model_params: Dict = {"epochs": 1, "holdout_frac": 0.2}, |
| 26 | + ngpus_per_trial: int = 1, |
| 27 | + time_limit: int = 7200, |
| 28 | + random_state: int = 123, |
| 29 | + verbose: int = 0, |
| 30 | +) -> Tuple: |
| 31 | + """Run stratified K-folds cross-validation with AutoGluon image model. |
| 32 | + Parameters |
| 33 | + ---------- |
| 34 | + dataset : gluoncv.auto.data.dataset.ImageClassificationDataset |
| 35 | + AutoGluon dataset for image classification. |
| 36 | + out_folder : str, default="./cross_val_predict_run/" |
| 37 | + Folder to save cross-validation results. Save results after each split (each K in K-fold). |
| 38 | + n_splits : int, default=3 |
| 39 | + Number of splits for stratified K-folds cross-validation. |
| 40 | + model_params : Dict, default={"epochs": 1, "holdout_frac": 0.2} |
| 41 | + Passed into AutoGluon's `ImagePredictor().fit()` method. |
| 42 | + ngpus_per_trial : int, default=1 |
| 43 | + Passed into AutoGluon's `ImagePredictor().fit()` method. |
| 44 | + time_limit : int, default=7200 |
| 45 | + Passed into AutoGluon's `ImagePredictor().fit()` method. |
| 46 | + random_state : int, default=123 |
| 47 | + Passed into AutoGluon's `ImagePredictor().fit()` method. |
| 48 | + Returns |
| 49 | + ------- |
| 50 | + None |
| 51 | + """ |
| 52 | + |
| 53 | + # stratified K-folds |
| 54 | + skf = StratifiedKFold(n_splits=n_splits, shuffle=False) |
| 55 | + skf_splits = [ |
| 56 | + [train_index, test_index] |
| 57 | + for train_index, test_index in skf.split(X=dataset, y=dataset.label) |
| 58 | + ] |
| 59 | + |
| 60 | + for split_num, split in enumerate(skf_splits): |
| 61 | + |
| 62 | + print("----") |
| 63 | + print(f"Running Cross-Validation on Split: {split_num}") |
| 64 | + |
| 65 | + # split from stratified K-folds |
| 66 | + train_index, test_index = split |
| 67 | + |
| 68 | + # init model |
| 69 | + predictor = ImagePredictor(verbosity=0) |
| 70 | + |
| 71 | + # train model on train indices in this split |
| 72 | + predictor.fit( |
| 73 | + train_data=dataset.iloc[train_index], |
| 74 | + ngpus_per_trial=ngpus_per_trial, |
| 75 | + hyperparameters=model_params, |
| 76 | + time_limit=time_limit, |
| 77 | + random_state=random_state, |
| 78 | + ) |
| 79 | + |
| 80 | + # predict on test indices in this split |
| 81 | + |
| 82 | + # predicted probabilities for test split |
| 83 | + pred_probs = predictor.predict_proba( |
| 84 | + data=dataset.iloc[test_index], as_pandas=False |
| 85 | + ) |
| 86 | + |
| 87 | + # predicted features (aka embeddings) for test split |
| 88 | + # why does autogluon predict_feature return array of array for the features? |
| 89 | + # need to use stack to convert to 2d array (https://stackoverflow.com/questions/50971123/converty-numpy-array-of-arrays-to-2d-array) |
| 90 | + pred_features = np.stack( |
| 91 | + predictor.predict_feature(data=dataset.iloc[test_index], as_pandas=False)[ |
| 92 | + :, 0 |
| 93 | + ] |
| 94 | + ) |
| 95 | + |
| 96 | + # save output of model + split in pickle file |
| 97 | + |
| 98 | + out_subfolder = f"{out_folder}split_{split_num}/" |
| 99 | + |
| 100 | + try: |
| 101 | + os.makedirs(out_subfolder, exist_ok=False) |
| 102 | + except OSError: |
| 103 | + print(f"Folder {out_subfolder} already exists!") |
| 104 | + finally: |
| 105 | + |
| 106 | + # save to pickle files |
| 107 | + |
| 108 | + get_pickle_file_name = ( |
| 109 | + lambda object_name: f"{out_subfolder}_{object_name}_split_{split_num}" |
| 110 | + ) |
| 111 | + |
| 112 | + _save_to_pickle(pred_probs, get_pickle_file_name("test_pred_probs")) |
| 113 | + _save_to_pickle(pred_features, get_pickle_file_name("test_pred_features")) |
| 114 | + _save_to_pickle( |
| 115 | + dataset.iloc[test_index].label.values, |
| 116 | + get_pickle_file_name("test_labels"), |
| 117 | + ) |
| 118 | + _save_to_pickle( |
| 119 | + dataset.iloc[test_index].image.values, |
| 120 | + get_pickle_file_name("test_image_files"), |
| 121 | + ) |
| 122 | + _save_to_pickle(test_index, get_pickle_file_name("test_indices")) |
| 123 | + |
| 124 | + # save model trained on this split |
| 125 | + predictor.save(f"{out_subfolder}predictor.ag") |
| 126 | + |
| 127 | + return predictor |
| 128 | + |
| 129 | + |
| 130 | +def _save_to_pickle(object, pickle_file_name): |
| 131 | + """Save object to pickle file""" |
| 132 | + |
| 133 | + print(f"Saving {pickle_file_name}") |
| 134 | + |
| 135 | + # save to pickle file |
| 136 | + with open(pickle_file_name, "wb") as handle: |
| 137 | + pickle.dump(object, handle, protocol=pickle.HIGHEST_PROTOCOL) |
| 138 | + |
| 139 | + |
| 140 | +def train_model( |
| 141 | + model_type, |
| 142 | + data, |
| 143 | + model_results_folder, |
| 144 | + *, |
| 145 | + num_cv_folds=5, |
| 146 | + verbose=0, |
| 147 | + epochs=1, |
| 148 | + holdout_frac=0.2, |
| 149 | + time_limit=60, |
| 150 | + random_state=123, |
| 151 | +): |
| 152 | + """Trains AutoGluon image model with stratified K-folds cross-validation and saves data in model_results_folder. |
| 153 | + Parameters |
| 154 | + ---------- |
| 155 | + model_type: str |
| 156 | + Type of backend architecture for Autogluon |
| 157 | +
|
| 158 | + data : gluoncv.auto.data.dataset.ImageClassificationDataset |
| 159 | + AutoGluon dataset for image classification. |
| 160 | + model_results_folder : str |
| 161 | + Folder to save cross-validation results. Save results after each split (each K in K-fold). |
| 162 | + num_cv_folds : int, default=5 |
| 163 | + Number of splits for stratified K-folds cross-validation. |
| 164 | + model_params : Dict, default={"epochs": 1, "holdout_frac": 0.2, "verbose": 1} |
| 165 | + Passed into AutoGluon's `ImagePredictor().fit()` method. |
| 166 | + time_limit : int, default=7200 |
| 167 | + Passed into AutoGluon's `ImagePredictor().fit()` method. |
| 168 | + random_state : int, default=123 |
| 169 | + Passed into AutoGluon's `ImagePredictor().fit()` method. |
| 170 | + Returns |
| 171 | + ------- |
| 172 | + None |
| 173 | + """ |
| 174 | + |
| 175 | + # run xvalidation |
| 176 | + print("----") |
| 177 | + print(f"Running cross-validation for model: {model_type}") |
| 178 | + |
| 179 | + MODEL_PARAMS = { |
| 180 | + "model": model_type, |
| 181 | + "epochs": epochs, |
| 182 | + "holdout_frac": holdout_frac, |
| 183 | + } |
| 184 | + |
| 185 | + # results of cross-validation will be saved to pickle files for each model/fold |
| 186 | + predictor = cross_val_predict_autogluon_image_dataset( |
| 187 | + dataset=data, |
| 188 | + out_folder=f"{model_results_folder}_{model_type}/", # save results of cross-validation in pickle files for each fold |
| 189 | + n_splits=num_cv_folds, |
| 190 | + model_params=MODEL_PARAMS, |
| 191 | + time_limit=time_limit, |
| 192 | + random_state=random_state, |
| 193 | + verbose=verbose, |
| 194 | + ) |
| 195 | + return predictor |
| 196 | + |
| 197 | + |
| 198 | +# load pickle file util |
| 199 | +def _load_pickle(pickle_file_name, verbose=1): |
| 200 | + """Load pickle file""" |
| 201 | + if verbose: |
| 202 | + print(f"Loading {pickle_file_name}") |
| 203 | + with open(pickle_file_name, "rb") as handle: |
| 204 | + out = pickle.load(handle) |
| 205 | + return out |
| 206 | + |
| 207 | + |
| 208 | +def sum_xval_folds(model, model_results_folder, num_cv_folds=5, verbose=1, **kwargs): |
| 209 | + # get original label name to idx mapping |
| 210 | + label_name_to_idx_map = { |
| 211 | + "airplane": 0, |
| 212 | + "automobile": 1, |
| 213 | + "bird": 2, |
| 214 | + "cat": 3, |
| 215 | + "deer": 4, |
| 216 | + "dog": 5, |
| 217 | + "frog": 6, |
| 218 | + "horse": 7, |
| 219 | + "ship": 8, |
| 220 | + "truck": 9, |
| 221 | + } |
| 222 | + results_list = [] |
| 223 | + |
| 224 | + # get shapes of arrays (this is dumb way to do it what is better?) |
| 225 | + pred_probs_shape = [] |
| 226 | + features_shape = [] |
| 227 | + labels_shape = [] |
| 228 | + for split_num in range(num_cv_folds): |
| 229 | + |
| 230 | + out_subfolder = f"{model_results_folder}_{model}/split_{split_num}/" |
| 231 | + |
| 232 | + # pickle file name to read |
| 233 | + get_pickle_file_name = ( |
| 234 | + lambda object_name: f"{out_subfolder}_{object_name}_split_{split_num}" |
| 235 | + ) |
| 236 | + |
| 237 | + # NOTE: the "test_" prefix in the pickle name correspond to the "test" split during cross-validation. |
| 238 | + pred_probs_split = _load_pickle( |
| 239 | + get_pickle_file_name("test_pred_probs"), verbose=verbose |
| 240 | + ) |
| 241 | + labels_split = _load_pickle( |
| 242 | + get_pickle_file_name("test_labels"), verbose=verbose |
| 243 | + ) |
| 244 | + test_pred_features_split = _load_pickle( |
| 245 | + get_pickle_file_name("test_pred_features"), verbose=verbose |
| 246 | + ) |
| 247 | + |
| 248 | + pred_probs_shape.append(pred_probs_split) |
| 249 | + features_shape.append(test_pred_features_split) |
| 250 | + labels_shape.append(labels_split) |
| 251 | + |
| 252 | + pred_probs_shape = np.vstack(pred_probs_shape) |
| 253 | + labels_shape = np.hstack(labels_shape) |
| 254 | + |
| 255 | + pred_probs = np.zeros_like(pred_probs_shape) |
| 256 | + labels = np.zeros_like(labels_shape) |
| 257 | + images = np.empty((labels_shape.shape[0],), dtype=object) |
| 258 | + |
| 259 | + for split_num in range(num_cv_folds): |
| 260 | + |
| 261 | + out_subfolder = f"{model_results_folder}_{model}/split_{split_num}/" |
| 262 | + |
| 263 | + # pickle file name to read |
| 264 | + get_pickle_file_name = ( |
| 265 | + lambda object_name: f"{out_subfolder}_{object_name}_split_{split_num}" |
| 266 | + ) |
| 267 | + |
| 268 | + # NOTE: the "test_" prefix in the pickle name correspond to the "test" split during cross-validation. |
| 269 | + pred_probs_split = _load_pickle( |
| 270 | + get_pickle_file_name("test_pred_probs"), verbose=verbose |
| 271 | + ) |
| 272 | + labels_split = _load_pickle( |
| 273 | + get_pickle_file_name("test_labels"), verbose=verbose |
| 274 | + ) |
| 275 | + images_split = _load_pickle( |
| 276 | + get_pickle_file_name("test_image_files"), verbose=verbose |
| 277 | + ) |
| 278 | + indices_split = _load_pickle( |
| 279 | + get_pickle_file_name("test_indices"), verbose=verbose |
| 280 | + ) |
| 281 | + indices_split = np.array(indices_split) |
| 282 | + |
| 283 | + pred_probs[indices_split] = pred_probs_split |
| 284 | + labels[indices_split] = labels_split |
| 285 | + images[indices_split] = np.array(images_split) |
| 286 | + |
| 287 | + return pred_probs, labels, images |
0 commit comments