Skip to content

Commit

Permalink
fix metrics bug
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-she committed Oct 4, 2021
1 parent e56421f commit 73932dd
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
2 changes: 1 addition & 1 deletion minetorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .miner import Miner
from .plugin import Plugin

__version__ = "0.6.15"
__version__ = "0.6.16"

__all__ = ["Miner", "Plugin"]
13 changes: 8 additions & 5 deletions minetorch/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def before_init(self):

def before_epoch_start(self, epoch):
self.raw_output = []
self.predicts = np.array([]).astype(np.float)
self.targets = np.array([]).astype(np.int)
self.predicts = []
self.targets = []

def after_val_iteration_ended(self, predicts, data, **ignore):
raw_output = predicts.detach().cpu().numpy()
Expand All @@ -47,10 +47,13 @@ def after_val_iteration_ended(self, predicts, data, **ignore):
targets = data[1].cpu().numpy().reshape([-1])

self.raw_output.append(raw_output)
self.predicts = np.concatenate((self.predicts, predicts))
self.targets = np.concatenate((self.targets, targets))
self.predicts.append(predicts)
self.targets.append(targets)

def after_epoch_end(self, val_loss, **ignore):
self.predicts = np.concatenate(self.predicts)
self.targets = np.concatenate(self.targets)

self._save_results()
self.accuracy and self._accuracy()
self.confusion_matrix and self._confusion_matrix()
Expand Down Expand Up @@ -119,5 +122,5 @@ def _kappa_score(self):

def _save_results(self):
file_name = self.plugin_file(f"result.{self.current_epoch}.npz")
raw_output = np.stack(self.raw_output)
raw_output = np.concatenate(self.raw_output)
np.savez_compressed(file_name, predicts=self.predicts, targets=self.targets, raw_output=raw_output)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
setuptools.setup(
name='Minetorch',
description='A tools collection for pytorch users',
version='0.6.15',
version='0.6.16',
packages=setuptools.find_packages(),
include_package_data=True,
url="https://github.com/minetorch/minetorch",
Expand Down

0 comments on commit 73932dd

Please sign in to comment.