diff --git a/MaskHIT_Prep/03_get_patches_meta.py b/MaskHIT_Prep/03_get_patches_meta.py index 3699b94..085ad1b 100644 --- a/MaskHIT_Prep/03_get_patches_meta.py +++ b/MaskHIT_Prep/03_get_patches_meta.py @@ -108,6 +108,9 @@ def extract_tcga_info(path): # TODO: This should work but need to check with real data. df['id_patient'], df['type'] = zip(*df['file'].apply(extract_tcga_info)) else: + #TODO: This behavior seems dangerous. Shouldn't we have a separated script + #and describe the input files definition so we don't need to update code + #based on a dataset? df['id_patient'] = df['id_svs'] df['type'] = '01Z' diff --git a/MaskHIT_Prep/04_feature_extraction.py b/MaskHIT_Prep/04_feature_extraction.py index 00cbe86..da80965 100644 --- a/MaskHIT_Prep/04_feature_extraction.py +++ b/MaskHIT_Prep/04_feature_extraction.py @@ -100,6 +100,21 @@ def main(): model = create_model(num_layers, True, 1) model.fc = nn.Identity() + + # User-defined feature extractor + if config.patch.pretrained_model_path: + # Load the pretrained model state dict + pretrained_dict = torch.load(config.patch.pretrained_model_path) + + if isinstance(pretrained_dict, nn.Module): + # If it's a full model, convert to a state dict from it + pretrained_dict = pretrained_dict.state_dict() + + # Filter out `fc` layer from the pretrained_dict + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k not in ['fc.weight', 'fc.bias']} + # Load the filtered state dict into your model + model.load_state_dict(pretrained_dict, strict=False) + model.cuda() model.eval() diff --git a/MaskHIT_Prep/05_post_process.py b/MaskHIT_Prep/05_post_process.py index 98d6452..89dbd9a 100644 --- a/MaskHIT_Prep/05_post_process.py +++ b/MaskHIT_Prep/05_post_process.py @@ -153,6 +153,7 @@ def get_counts_chunck(data, delta): df_svs = df_svs.merge(df_sum, on='id_svs', how='inner') df_svs = df_svs.loc[df_svs.valid > 25] df_svs.to_pickle(svs_meta) + df_svs.head(2).to_csv(str(svs_meta).replace('.pickle', '_preview.txt'), index=False, 
sep='\t') print("") print("=" * 40) diff --git a/MaskHIT_Prep/configs/config_default.yaml b/MaskHIT_Prep/configs/config_default.yaml index 92e32c4..c6b880c 100644 --- a/MaskHIT_Prep/configs/config_default.yaml +++ b/MaskHIT_Prep/configs/config_default.yaml @@ -10,6 +10,12 @@ # (custom): more likely to need modification for each user #--------------------------------------------------- +#TODO: Clarify which values are minimally required. +#Suggestions: +# for required values. +# null for optional values. When using null, don't add a type hint. +# For nullable values, we should not add a type hint. +# Instead, note the type in the comment: (optional, string) study: ## Study Configuration @@ -22,23 +28,29 @@ study: # (1) # Indicates a dataset class for json processing pipeline # See datasets/ directory for available classes. - # Available: TCGA, GramStains, IBD - # Please implement a dataset file if necessary + # Currently available: TCGA, GramStains, IBD, RCCp + # Please implement a dataset file if necessary. dataset_type: !!str # (0,1) # Path to the folder containing svs data # Example: /pool2/data/WSI_TCGA/Colorectal # Set to "tiles" if using tile preprocessing + # This can be left unset if you have a JSON file with each slide path + # that is properly parsed by your dataset class. svs_dir: !!str # (1) # Path to the json file with dataset info. + # Please read the README to prepare a JSON file. json_path: !!str # (1) - # Dataset field to use with stratification. - stratify_by: !!str status + # Dataset field to use with stratification (Optional[String]). + # Set to null for no stratification. + #TODO: Change the default value from 'status'. It should be + #and add examples?. + stratify_by: null # (1) # Number of folds for kfold cross validation @@ -100,6 +112,7 @@ patch: # (2, 3, 5) # Path to the pickle file containing metadata for the .svs files # Example: meta/dhmc_rcc_svs.pickle + # TODO: This is an output (save) path, right? We should clarify that this is not an input.
svs_meta: !!str # (2, 4) @@ -108,7 +121,17 @@ patch: # (4, 5) # Backbone model for feature extraction - backbone: resnet_18 + # Currently, only models from the ResNet family are available. + # Default: resnet_18, using ImageNet features. + # Replace with another ResNet model if necessary. + backbone: !!str resnet_18 + + # (4, 5) + # Path to a custom feature extractor snapshot (optional, string) + # Optional: Set this only if you have a specialized feature extractor, + # possibly pretrained on a relevant dataset. + # Default: Uses the standard ImageNet pretrained model. + pretrained_model_path: null # (4) # Batch size for processing slide patches @@ -122,6 +145,7 @@ patch: # (2) # Mask magnification for color filtering during patch extraction. # A lower magnification will run faster but will be less precise + #TODO: Ask why this value? mag_mask: !!float 0.3125 feature: