-
-
Notifications
You must be signed in to change notification settings - Fork 549
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4c42312
commit d98ea9b
Showing
6 changed files
with
123 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# | ||
# Utility functions and classes for solvers | ||
# | ||
|
||
|
||
class NoMemAllocVertcat: | ||
""" | ||
Acts like a vertcat, but does not allocate new memory. | ||
""" | ||
|
||
def __init__(self, a, b): | ||
arrays = [a, b] | ||
self.arrays = arrays | ||
|
||
for array in arrays: | ||
if not 1 <= len(array.shape) <= 2: | ||
raise ValueError("Only 1D or 2D arrays are supported") | ||
self._ndim = len(array.shape) | ||
|
||
self.len_a = a.shape[0] | ||
shape0 = a.shape[0] + b.shape[0] | ||
|
||
if self._ndim == 1: | ||
self._shape = (shape0,) | ||
self._size = shape0 | ||
else: | ||
if a.shape[1] != b.shape[1]: | ||
raise ValueError("All arrays must have the same number of columns") | ||
shape1 = a.shape[1] | ||
|
||
self._shape = (shape0, shape1) | ||
self._size = shape0 * shape1 | ||
|
||
@property | ||
def shape(self): | ||
return self._shape | ||
|
||
@property | ||
def size(self): | ||
return self._size | ||
|
||
@property | ||
def ndim(self): | ||
return self._ndim | ||
|
||
def __getitem__(self, key): | ||
if self._ndim == 1 or isinstance(key, int): | ||
if key < self.len_a: | ||
return self.arrays[0][key] | ||
else: | ||
return self.arrays[1][key - self.len_a] | ||
|
||
if key[0] == slice(None): | ||
return NoMemAllocVertcat(*[arr[:, key[1]] for arr in self.arrays]) | ||
elif isinstance(key[0], int): | ||
if key[0] < self.len_a: | ||
return self.arrays[0][key[0], key[1]] | ||
else: | ||
return self.arrays[1][key[0] - self.len_a, key[1]] | ||
else: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# | ||
# Tests for the solver utility functions and classes | ||
# | ||
import json | ||
import pybamm | ||
import unittest | ||
import numpy as np | ||
import pandas as pd | ||
from scipy.io import loadmat | ||
from tests import get_discretisation_for_testing | ||
|
||
|
||
class TestSolverUtils(unittest.TestCase): | ||
def test_compare_numpy_vertcat(self): | ||
a0 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) | ||
a1 = np.array([[1, 2, 3]]) | ||
b0 = np.array([[13, 14, 15], [16, 17, 18]]) | ||
|
||
for a, b in zip([a0, b0], [a1, b0]): | ||
pybamm_vertcat = pybamm.NoMemAllocVertcat(a, b) | ||
np_vertcat = np.concatenate((a, b), axis=0) | ||
self.assertEqual(pybamm_vertcat.shape, np_vertcat.shape) | ||
self.assertEqual(pybamm_vertcat.size, np_vertcat.size) | ||
for i in range(pybamm_vertcat.shape[0]): | ||
for j in range(pybamm_vertcat.shape[1]): | ||
self.assertEqual(pybamm_vertcat[i, j], np_vertcat[i, j]) | ||
self.assertEqual(pybamm_vertcat[:, j][i], np_vertcat[:, j][i]) | ||
for i in range(pybamm_vertcat.shape[0]): | ||
np.testing.assert_array_equal(pybamm_vertcat[i, :], np_vertcat[i, :]) | ||
|
||
def test_errors(self): | ||
a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) | ||
b = np.ones((4, 5, 6)) | ||
with self.assertRaisesRegex(ValueError, "Only 1D or 2D arrays are supported"): | ||
pybamm.NoMemAllocVertcat(a, b) | ||
|
||
b = np.array([[10, 11], [13, 14]]) | ||
with self.assertRaisesRegex( | ||
ValueError, "All arrays must have the same number of columns" | ||
): | ||
pybamm.NoMemAllocVertcat(a, b) | ||
|
||
|
||
if __name__ == "__main__": | ||
print("Add -v for more debug output") | ||
import sys | ||
|
||
if "-v" in sys.argv: | ||
debug = True | ||
pybamm.settings.debug_mode = True | ||
unittest.main() |