diff --git a/minetorch/__init__.py b/minetorch/__init__.py index 23af3a2..068ee94 100644 --- a/minetorch/__init__.py +++ b/minetorch/__init__.py @@ -1,6 +1,6 @@ from .miner import Miner from .plugin import Plugin -__version__ = "0.6.15" +__version__ = "0.6.16" __all__ = ["Miner", "Plugin"] diff --git a/minetorch/metrics.py b/minetorch/metrics.py index c5d31bd..3cdcf2b 100644 --- a/minetorch/metrics.py +++ b/minetorch/metrics.py @@ -37,8 +37,8 @@ def before_init(self): def before_epoch_start(self, epoch): self.raw_output = [] - self.predicts = np.array([]).astype(np.float) - self.targets = np.array([]).astype(np.int) + self.predicts = [] + self.targets = [] def after_val_iteration_ended(self, predicts, data, **ignore): raw_output = predicts.detach().cpu().numpy() @@ -47,10 +47,13 @@ def after_val_iteration_ended(self, predicts, data, **ignore): targets = data[1].cpu().numpy().reshape([-1]) self.raw_output.append(raw_output) - self.predicts = np.concatenate((self.predicts, predicts)) - self.targets = np.concatenate((self.targets, targets)) + self.predicts.append(predicts) + self.targets.append(targets) def after_epoch_end(self, val_loss, **ignore): + self.predicts = np.concatenate(self.predicts) + self.targets = np.concatenate(self.targets) + self._save_results() self.accuracy and self._accuracy() self.confusion_matrix and self._confusion_matrix() @@ -119,5 +122,5 @@ def _kappa_score(self): def _save_results(self): file_name = self.plugin_file(f"result.{self.current_epoch}.npz") - raw_output = np.stack(self.raw_output) + raw_output = np.concatenate(self.raw_output) np.savez_compressed(file_name, predicts=self.predicts, targets=self.targets, raw_output=raw_output) diff --git a/setup.py b/setup.py index 7f95e0a..aa02d0f 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setuptools.setup( name='Minetorch', description='A tools collection for pytorch users', - version='0.6.15', + version='0.6.16', packages=setuptools.find_packages(), include_package_data=True, url="https://github.com/minetorch/minetorch",