@@ -425,7 +425,8 @@ def predict_from_data_iterator(self,
425425 def predict_single_npy_array (self , input_image : np .ndarray , image_properties : dict ,
426426 segmentation_previous_stage : np .ndarray = None ,
427427 output_file_truncated : str = None ,
428- save_or_return_probabilities : bool = False ):
428+ save_or_return_probabilities : bool = False ,
429+ return_logits_per_fold : bool = False ):
429430 """
430431 WARNING: SLOW. ONLY USE THIS IF YOU CANNOT GIVE NNUNET MULTIPLE IMAGES AT ONCE FOR SOME REASON.
431432
@@ -449,7 +450,11 @@ def predict_single_npy_array(self, input_image: np.ndarray, image_properties: di
449450
450451 if self .verbose :
451452 print ('predicting' )
452- predicted_logits = self .predict_logits_from_preprocessed_data (dct ['data' ]).cpu ()
453+ # For getting logits per fold, cpu extraction has to be done for each list element
454+ if return_logits_per_fold :
455+ predicted_logits = [ elem .cpu () for elem in self .predict_logits_from_preprocessed_data (dct ['data' ], return_logits_per_fold = return_logits_per_fold )]
456+ else :
457+ predicted_logits = self .predict_logits_from_preprocessed_data (dct ['data' ], return_logits_per_fold = return_logits_per_fold ).cpu ()
453458
454459 if self .verbose :
455460 print ('resampling to original shape' )
@@ -458,19 +463,34 @@ def predict_single_npy_array(self, input_image: np.ndarray, image_properties: di
458463 self .plans_manager , self .dataset_json , output_file_truncated ,
459464 save_or_return_probabilities )
460465 else :
461- ret = convert_predicted_logits_to_segmentation_with_correct_shape (predicted_logits , self .plans_manager ,
462- self .configuration_manager ,
463- self .label_manager ,
464- dct ['data_properties' ],
465- return_probabilities =
466- save_or_return_probabilities )
466+ if return_logits_per_fold :
467+ ret = []
468+ for elem in predicted_logits :
469+ ret .append (convert_predicted_logits_to_segmentation_with_correct_shape (
470+ elem , self .plans_manager ,
471+ self .configuration_manager ,
472+ self .label_manager ,
473+ dct ['data_properties' ],
474+ return_probabilities = save_or_return_probabilities ))
475+
476+
477+ else :
478+ ret = convert_predicted_logits_to_segmentation_with_correct_shape (predicted_logits , self .plans_manager ,
479+ self .configuration_manager ,
480+ self .label_manager ,
481+ dct ['data_properties' ],
482+ return_probabilities =
483+ save_or_return_probabilities )
467484 if save_or_return_probabilities :
485+ if return_logits_per_fold :
486+ segs , probs = zip (* ret )
487+ ret = [list (segs ), list (probs )]
468488 return ret [0 ], ret [1 ]
469489 else :
470490 return ret
471491
472492 @torch .inference_mode ()
473- def predict_logits_from_preprocessed_data (self , data : torch .Tensor ) -> torch .Tensor :
493+ def predict_logits_from_preprocessed_data (self , data : torch .Tensor , return_logits_per_fold : bool = False ) -> torch .Tensor :
474494 """
475495 IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON
476496 TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE!
@@ -481,6 +501,8 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Ten
481501 n_threads = torch .get_num_threads ()
482502 torch .set_num_threads (default_num_processes if default_num_processes < n_threads else n_threads )
483503 prediction = None
504+ if return_logits_per_fold :
505+ prediction = []
484506
485507 for params in self .list_of_parameters :
486508
@@ -495,10 +517,12 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Ten
495517 # this actually saves computation time
496518 if prediction is None :
497519 prediction = self .predict_sliding_window_return_logits (data ).to ('cpu' )
520+ if return_logits_per_fold :
521+ prediction .append (self .predict_sliding_window_return_logits (data ).to ('cpu' ))
498522 else :
499523 prediction += self .predict_sliding_window_return_logits (data ).to ('cpu' )
500524
501- if len (self .list_of_parameters ) > 1 :
525+ if len (self .list_of_parameters ) > 1 and not return_logits_per_fold :
502526 prediction /= len (self .list_of_parameters )
503527
504528 if self .verbose : print ('Prediction done' )
0 commit comments