Skip to content

Commit a001c92

Browse files
plbenvenistejoshuacwnewton
authored andcommitted
predict_from_raw_data.py: Add return_logits_per_fold arg
1 parent 511204e commit a001c92

File tree

1 file changed

+34
-10
lines changed

1 file changed

+34
-10
lines changed

nnunetv2/inference/predict_from_raw_data.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,8 @@ def predict_from_data_iterator(self,
425425
def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict,
426426
segmentation_previous_stage: np.ndarray = None,
427427
output_file_truncated: str = None,
428-
save_or_return_probabilities: bool = False):
428+
save_or_return_probabilities: bool = False,
429+
return_logits_per_fold: bool = False):
429430
"""
430431
WARNING: SLOW. ONLY USE THIS IF YOU CANNOT GIVE NNUNET MULTIPLE IMAGES AT ONCE FOR SOME REASON.
431432
@@ -449,7 +450,11 @@ def predict_single_npy_array(self, input_image: np.ndarray, image_properties: di
449450

450451
if self.verbose:
451452
print('predicting')
452-
predicted_logits = self.predict_logits_from_preprocessed_data(dct['data']).cpu()
453+
# For getting logits per fold, cpu extraction has to be done for each list element
454+
if return_logits_per_fold:
455+
predicted_logits = [ elem.cpu() for elem in self.predict_logits_from_preprocessed_data(dct['data'], return_logits_per_fold=return_logits_per_fold)]
456+
else:
457+
predicted_logits = self.predict_logits_from_preprocessed_data(dct['data'], return_logits_per_fold=return_logits_per_fold).cpu()
453458

454459
if self.verbose:
455460
print('resampling to original shape')
@@ -458,19 +463,34 @@ def predict_single_npy_array(self, input_image: np.ndarray, image_properties: di
458463
self.plans_manager, self.dataset_json, output_file_truncated,
459464
save_or_return_probabilities)
460465
else:
461-
ret = convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager,
462-
self.configuration_manager,
463-
self.label_manager,
464-
dct['data_properties'],
465-
return_probabilities=
466-
save_or_return_probabilities)
466+
if return_logits_per_fold:
467+
ret = []
468+
for elem in predicted_logits:
469+
ret.append(convert_predicted_logits_to_segmentation_with_correct_shape(
470+
elem, self.plans_manager,
471+
self.configuration_manager,
472+
self.label_manager,
473+
dct['data_properties'],
474+
return_probabilities=save_or_return_probabilities))
475+
476+
477+
else:
478+
ret = convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager,
479+
self.configuration_manager,
480+
self.label_manager,
481+
dct['data_properties'],
482+
return_probabilities=
483+
save_or_return_probabilities)
467484
if save_or_return_probabilities:
485+
if return_logits_per_fold:
486+
segs, probs = zip(*ret)
487+
ret = [list(segs), list(probs)]
468488
return ret[0], ret[1]
469489
else:
470490
return ret
471491

472492
@torch.inference_mode()
473-
def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
493+
def predict_logits_from_preprocessed_data(self, data: torch.Tensor, return_logits_per_fold: bool = False) -> torch.Tensor:
474494
"""
475495
IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON
476496
TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE!
@@ -481,6 +501,8 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Ten
481501
n_threads = torch.get_num_threads()
482502
torch.set_num_threads(default_num_processes if default_num_processes < n_threads else n_threads)
483503
prediction = None
504+
if return_logits_per_fold:
505+
prediction = []
484506

485507
for params in self.list_of_parameters:
486508

@@ -495,10 +517,12 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Ten
495517
# this actually saves computation time
496518
if prediction is None:
497519
prediction = self.predict_sliding_window_return_logits(data).to('cpu')
520+
if return_logits_per_fold:
521+
prediction.append(self.predict_sliding_window_return_logits(data).to('cpu'))
498522
else:
499523
prediction += self.predict_sliding_window_return_logits(data).to('cpu')
500524

501-
if len(self.list_of_parameters) > 1:
525+
if len(self.list_of_parameters) > 1 and not return_logits_per_fold:
502526
prediction /= len(self.list_of_parameters)
503527

504528
if self.verbose: print('Prediction done')

0 commit comments

Comments
 (0)