Skip to content

Commit

Permalink
[fix] Add first draft of the PR for issue#349
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Jan 28, 2022
1 parent 224aa44 commit 1431980
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 245 deletions.
34 changes: 20 additions & 14 deletions autoPyTorch/evaluation/abstract_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time
import warnings
from multiprocessing.queues import Queue
from typing import Any, Dict, List, Optional, Tuple, Union, no_type_check
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union, no_type_check

from ConfigSpace import Configuration

Expand Down Expand Up @@ -54,6 +54,16 @@
]


class EvaluationResults(NamedTuple):
opt_loss: Dict[str, float]
train_loss: Dict[str, float]
opt_pred: np.ndarray
status: StatusType
valid_pred: Optional[np.ndarray] = None
test_pred: Optional[np.ndarray] = None
additional_run_info: Optional[Dict] = None


class MyTraditionalTabularClassificationPipeline(BaseEstimator):
"""
A wrapper class that holds a pipeline for traditional classification.
Expand Down Expand Up @@ -662,11 +672,7 @@ def _loss(self, y_true: np.ndarray, y_hat: np.ndarray) -> Dict[str, float]:
return calculate_loss(
y_true, y_hat, self.task_type, metrics)

def finish_up(self, loss: Dict[str, float], train_loss: Dict[str, float],
opt_pred: np.ndarray, valid_pred: Optional[np.ndarray],
test_pred: Optional[np.ndarray], additional_run_info: Optional[Dict],
file_output: bool, status: StatusType
) -> Optional[Tuple[float, float, int, Dict]]:
def finish_up(self, results: EvaluationResults, file_output: bool) -> Optional[Tuple[float, float, int, Dict]]:
"""This function does everything necessary after the fitting is done:
* predicting
Expand Down Expand Up @@ -711,37 +717,37 @@ def finish_up(self, loss: Dict[str, float], train_loss: Dict[str, float],

if file_output:
loss_, additional_run_info_ = self.file_output(
opt_pred, valid_pred, test_pred,
results.opt_pred, results.valid_pred, results.test_pred,
)
else:
loss_ = None
additional_run_info_ = {}

validation_loss, test_loss = self.calculate_auxiliary_losses(
valid_pred, test_pred
results.valid_pred, results.test_pred
)

if loss_ is not None:
return self.duration, loss_, self.seed, additional_run_info_

cost = loss[self.metric.name]
cost = results.opt_loss[self.metric.name]

additional_run_info = (
{} if additional_run_info is None else additional_run_info
{} if results.additional_run_info is None else results.additional_run_info
)
additional_run_info['opt_loss'] = loss
additional_run_info['opt_loss'] = results.opt_loss
additional_run_info['duration'] = self.duration
additional_run_info['num_run'] = self.num_run
if train_loss is not None:
additional_run_info['train_loss'] = train_loss
if results.train_loss is not None:
additional_run_info['train_loss'] = results.train_loss
if validation_loss is not None:
additional_run_info['validation_loss'] = validation_loss
if test_loss is not None:
additional_run_info['test_loss'] = test_loss

rval_dict = {'loss': cost,
'additional_run_info': additional_run_info,
'status': status}
'status': results.status}

self.queue.put(rval_dict)
return None
Expand Down
Loading

0 comments on commit 1431980

Please sign in to comment.