Skip to content

Commit d3f990b

Browse files
committed
fix metrics bug
1 parent e56421f commit d3f990b

File tree

3 files changed

+11
-9
lines changed

3 files changed

+11
-9
lines changed

minetorch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .miner import Miner
22
from .plugin import Plugin
33

4-
__version__ = "0.6.15"
4+
__version__ = "0.6.16"
55

66
__all__ = ["Miner", "Plugin"]

minetorch/metrics.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,22 @@ def before_init(self):
3737

3838
def before_epoch_start(self, epoch):
3939
self.raw_output = []
40-
self.predicts = np.array([]).astype(np.float)
41-
self.targets = np.array([]).astype(np.int)
40+
self.predicts = []
41+
self.targets = []
4242

4343
def after_val_iteration_ended(self, predicts, data, **ignore):
4444
raw_output = predicts.detach().cpu().numpy()
4545
predicts = np.argmax(raw_output, axis=1)
46-
predicts = predicts.reshape([-1])
47-
targets = data[1].cpu().numpy().reshape([-1])
46+
targets = data[1].cpu().numpy()
4847

4948
self.raw_output.append(raw_output)
50-
self.predicts = np.concatenate((self.predicts, predicts))
51-
self.targets = np.concatenate((self.targets, targets))
49+
self.predicts.append(predicts)
50+
self.targets.append(targets)
5251

5352
def after_epoch_end(self, val_loss, **ignore):
53+
self.predicts = np.concatenate(self.predicts)
54+
self.targets = np.concatenate(self.targets)
55+
5456
self._save_results()
5557
self.accuracy and self._accuracy()
5658
self.confusion_matrix and self._confusion_matrix()
@@ -119,5 +121,5 @@ def _kappa_score(self):
119121

120122
def _save_results(self):
121123
file_name = self.plugin_file(f"result.{self.current_epoch}.npz")
122-
raw_output = np.stack(self.raw_output)
124+
raw_output = np.concatenate(self.raw_output)
123125
np.savez_compressed(file_name, predicts=self.predicts, targets=self.targets, raw_output=raw_output)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
setuptools.setup(
77
name='Minetorch',
88
description='A tools collection for pytorch users',
9-
version='0.6.15',
9+
version='0.6.16',
1010
packages=setuptools.find_packages(),
1111
include_package_data=True,
1212
url="https://github.com/minetorch/minetorch",

0 commit comments

Comments
 (0)