Skip to content

Commit

Permalink
Add numpy support
Browse files Browse the repository at this point in the history
  • Loading branch information
alugowski committed Aug 31, 2023
1 parent b84a78a commit 894ed69
Show file tree
Hide file tree
Showing 7 changed files with 314 additions and 4 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
# MatSpy

Sparse matrix spy plot and sparkline renderer. Supports:
* **SciPy** - sparse matrices and arrays like `csr_matrix` and `coo_array`
* **[Python-graphblas](https://github.com/python-graphblas/python-graphblas)** - `gb.Matrix`. [See demo.](demo-python-graphblas.ipynb)
* **SciPy** - sparse matrices and arrays like `csr_matrix` and `coo_array` [(demo)](demo.ipynb)
* **NumPy** - `ndarray` [(demo)](demo-numpy.ipynb)
* **[Python-graphblas](https://github.com/python-graphblas/python-graphblas)** - `gb.Matrix` [(demo)](demo-python-graphblas.ipynb)

Features:
* Simple `spy()` method, similar to MatLAB's spy.
Expand Down Expand Up @@ -53,6 +54,7 @@ All methods take the same arguments. Apart from the matrix itself:
* `shading`: `binary`, `relative`, `absolute`.
* `buckets`: spy plot pixels (longest side).
* `dpi`: determine `buckets` relative to figure size.
* `precision`: For numpy arrays, magnitude less than this is considered zero. Like [matplotlib.pyplot.spy()](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.spy.html)'s `precision`.

### Overriding defaults
`matspy.params` contains the default values for all arguments.
Expand Down
181 changes: 181 additions & 0 deletions demo-numpy.ipynb

Large diffs are not rendered by default.

13 changes: 12 additions & 1 deletion matspy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ class MatSpyParams:
buckets: int = None
"""Pixel count of longest side of spy image. If None then computed from size and DPI."""

precision: float = None
"""
Applies to dense matrices like numpy arrays. If None or 0, nonzero values are plotted. Else only values with
absolute value > `precision` are plotted.
Behaves like `matplotlib.pyplot.spy`'s `precision` argument, but for dense arrays only.
"""

spy_aa_tweaks_enabled: bool = None
"""
Whether to_sparkline() may tweak parameters like bucket count to prevent visible aliasing artifacts.
Expand Down Expand Up @@ -117,6 +125,9 @@ def _register_bundled():
from .adapters.scipy_driver import SciPyDriver
register_driver(SciPyDriver)

from .adapters.numpy_driver import NumPyDriver
register_driver(NumPyDriver)

from .adapters.graphblas_driver import GraphBLASDriver
register_driver(GraphBLASDriver)

Expand All @@ -125,7 +136,7 @@ def _register_bundled():


def _get_driver(mat):
type_str = ".".join((mat.__module__, mat.__class__.__name__))
type_str = ".".join((type(mat).__module__, type(mat).__name__))
for prefix, driver in _driver_prefixes.items():
if type_str.startswith(prefix):
return driver
Expand Down
18 changes: 18 additions & 0 deletions matspy/adapters/numpy_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (C) 2023 Adam Lugowski.
# Use of this source code is governed by the BSD 2-clause license found in the LICENSE.txt file.
# SPDX-License-Identifier: BSD-2-Clause

from typing import Any, Iterable

from . import Driver, MatrixSpyAdapter


class NumPyDriver(Driver):
@staticmethod
def get_supported_type_prefixes() -> Iterable[str]:
return ["numpy."]

@staticmethod
def adapt_spy(mat: Any) -> MatrixSpyAdapter:
from .numpy_impl import NumPySpy
return NumPySpy(mat)
39 changes: 39 additions & 0 deletions matspy/adapters/numpy_impl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (C) 2023 Adam Lugowski.
# Use of this source code is governed by the BSD 2-clause license found in the LICENSE.txt file.
# SPDX-License-Identifier: BSD-2-Clause

import numpy as np
from scipy.sparse import csr_matrix

from . import describe, MatrixSpyAdapter
from .scipy_impl import SciPySpy


class NumPySpy(MatrixSpyAdapter):
def __init__(self, arr):
super().__init__()
if len(arr.shape) != 2:
raise ValueError("Only 2D arrays are supported")
self.arr = arr

def get_shape(self) -> tuple:
return self.arr.shape

def describe(self) -> str:
format_name = "array"

return describe(shape=self.arr.shape, nz_type=self.arr.dtype,
notes=f"{format_name}")

def get_spy(self, spy_shape: tuple) -> np.array:
precision = self.get_option("precision", None)

if not precision:
mask = (self.arr != 0)
else:
mask = (self.arr > precision) | (self.arr < -precision)

if self.arr.dtype == 'object':
mask = mask & (self.arr != np.array([None]))

return SciPySpy(csr_matrix(mask)).get_spy(spy_shape)
10 changes: 9 additions & 1 deletion matspy/spy_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,16 @@ def _rescale(arr, from_range, to_range):

# noinspection PyUnusedLocal
def get_spy_heatmap(adapter: MatrixSpyAdapter, buckets, shading, shading_absolute_min,
shading_relative_min, shading_relative_max_percentile, **kwargs):
shading_relative_min, shading_relative_max_percentile, precision, **kwargs):
# find spy matrix shape
mat_shape = adapter.get_shape()
if mat_shape[0] == 0 or mat_shape[1] == 0:
return np.array([[]])

ratio = buckets / max(mat_shape)
spy_shape = tuple(max(1, int(ratio * x)) for x in mat_shape)

adapter.set_option("precision", precision)
dense = adapter.get_spy(spy_shape=spy_shape)

dense[dense < 0] = 0
Expand Down Expand Up @@ -236,6 +240,10 @@ def to_sparkline(mat, retscale=False, scale=None, html_border="1px solid black",
repeat = int(repeat) if repeat >= 2 else 1

heatmap = to_spy_heatmap(adapter, **options.to_kwargs())
if heatmap.size == 0:
# zero-size
return "&#9643;" # a single character that is an empty square

if repeat > 1:
heatmap = heatmap.repeat(repeat, axis=0)
heatmap = heatmap.repeat(repeat, axis=1)
Expand Down
51 changes: 51 additions & 0 deletions tests/test_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (C) 2023 Adam Lugowski.
# Use of this source code is governed by the BSD 2-clause license found in the LICENSE.txt file.
# SPDX-License-Identifier: BSD-2-Clause

import unittest

import numpy as np

from matspy import spy_to_mpl, to_sparkline, to_spy_heatmap

np.random.seed(123)


class NumPyTests(unittest.TestCase):
def setUp(self):
self.mats = [
np.array([[]]),
np.random.random((10, 10)),
]

def test_no_crash(self):
import matplotlib.pyplot as plt
for mat in self.mats:
fig, ax = spy_to_mpl(mat)
plt.close(fig)

res = to_sparkline(mat)
self.assertGreater(len(res), 5)

def test_shape(self):
arr = np.array([])
with self.assertRaises(ValueError):
spy_to_mpl(arr)

def test_count(self):
arrs = [
(1, np.array([[1]])),
(1, np.array([[1, 0], [0, 0]])),
(1, np.array([[1, None], [None, None]])),
(1, np.array([[1, 0], [None, None]])),
]

for count, arr in arrs:
area = np.prod(arr.shape)
heatmap = to_spy_heatmap(arr, buckets=1, shading="absolute")
self.assertEqual(len(heatmap), 1)
self.assertAlmostEqual(heatmap[0][0], count / area, places=2)


if __name__ == '__main__':
unittest.main()

0 comments on commit 894ed69

Please sign in to comment.