-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
314 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright (C) 2023 Adam Lugowski. | ||
# Use of this source code is governed by the BSD 2-clause license found in the LICENSE.txt file. | ||
# SPDX-License-Identifier: BSD-2-Clause | ||
|
||
from typing import Any, Iterable | ||
|
||
from . import Driver, MatrixSpyAdapter | ||
|
||
|
||
class NumPyDriver(Driver): | ||
@staticmethod | ||
def get_supported_type_prefixes() -> Iterable[str]: | ||
return ["numpy."] | ||
|
||
@staticmethod | ||
def adapt_spy(mat: Any) -> MatrixSpyAdapter: | ||
from .numpy_impl import NumPySpy | ||
return NumPySpy(mat) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Copyright (C) 2023 Adam Lugowski. | ||
# Use of this source code is governed by the BSD 2-clause license found in the LICENSE.txt file. | ||
# SPDX-License-Identifier: BSD-2-Clause | ||
|
||
import numpy as np | ||
from scipy.sparse import csr_matrix | ||
|
||
from . import describe, MatrixSpyAdapter | ||
from .scipy_impl import SciPySpy | ||
|
||
|
||
class NumPySpy(MatrixSpyAdapter): | ||
def __init__(self, arr): | ||
super().__init__() | ||
if len(arr.shape) != 2: | ||
raise ValueError("Only 2D arrays are supported") | ||
self.arr = arr | ||
|
||
def get_shape(self) -> tuple: | ||
return self.arr.shape | ||
|
||
def describe(self) -> str: | ||
format_name = "array" | ||
|
||
return describe(shape=self.arr.shape, nz_type=self.arr.dtype, | ||
notes=f"{format_name}") | ||
|
||
def get_spy(self, spy_shape: tuple) -> np.array: | ||
precision = self.get_option("precision", None) | ||
|
||
if not precision: | ||
mask = (self.arr != 0) | ||
else: | ||
mask = (self.arr > precision) | (self.arr < -precision) | ||
|
||
if self.arr.dtype == 'object': | ||
mask = mask & (self.arr != np.array([None])) | ||
|
||
return SciPySpy(csr_matrix(mask)).get_spy(spy_shape) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Copyright (C) 2023 Adam Lugowski. | ||
# Use of this source code is governed by the BSD 2-clause license found in the LICENSE.txt file. | ||
# SPDX-License-Identifier: BSD-2-Clause | ||
|
||
import unittest | ||
|
||
import numpy as np | ||
|
||
from matspy import spy_to_mpl, to_sparkline, to_spy_heatmap | ||
|
||
np.random.seed(123) | ||
|
||
|
||
class NumPyTests(unittest.TestCase): | ||
def setUp(self): | ||
self.mats = [ | ||
np.array([[]]), | ||
np.random.random((10, 10)), | ||
] | ||
|
||
def test_no_crash(self): | ||
import matplotlib.pyplot as plt | ||
for mat in self.mats: | ||
fig, ax = spy_to_mpl(mat) | ||
plt.close(fig) | ||
|
||
res = to_sparkline(mat) | ||
self.assertGreater(len(res), 5) | ||
|
||
def test_shape(self): | ||
arr = np.array([]) | ||
with self.assertRaises(ValueError): | ||
spy_to_mpl(arr) | ||
|
||
def test_count(self): | ||
arrs = [ | ||
(1, np.array([[1]])), | ||
(1, np.array([[1, 0], [0, 0]])), | ||
(1, np.array([[1, None], [None, None]])), | ||
(1, np.array([[1, 0], [None, None]])), | ||
] | ||
|
||
for count, arr in arrs: | ||
area = np.prod(arr.shape) | ||
heatmap = to_spy_heatmap(arr, buckets=1, shading="absolute") | ||
self.assertEqual(len(heatmap), 1) | ||
self.assertAlmostEqual(heatmap[0][0], count / area, places=2) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |