Skip to content

Commit 1987587

Browse files
authored
Merge pull request #205 from tataratat/ft-helper-tests
added missing test
2 parents 3a71fdd + b41c1e4 commit 1987587

File tree

2 files changed

+316
-0
lines changed

2 files changed

+316
-0
lines changed

gustaf/helpers/data.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,20 @@ def __contains__(self, key):
251251
"""
252252
return key in self._saved
253253

254+
def __len__(self):
255+
"""
256+
Returns number of items.
257+
258+
Parameters
259+
----------
260+
None
261+
262+
Returns
263+
-------
264+
len: int
265+
"""
266+
return len(self._saved)
267+
254268
def pop(self, key, default=None):
255269
"""
256270
Applied pop() to saved data

tests/test_helpers/test_data.py

Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
import sys
2+
3+
import numpy as np
4+
import pytest
5+
6+
import gustaf
7+
8+
9+
def new_tracked_array(dtype=float):
10+
"""
11+
create new tracked array and checks if default flags are set correctly.
12+
Then sets modified to False, to give an easy start for testing
13+
"""
14+
ta = gustaf.helpers.data.make_tracked_array(
15+
[
16+
[0, 1, 2],
17+
[3, 4, 5],
18+
[6, 7, 8],
19+
],
20+
dtype=dtype,
21+
)
22+
23+
assert ta.modified
24+
assert ta._super_arr
25+
26+
ta.modified = False
27+
28+
return ta
29+
30+
31+
def test_TrackedArray():
32+
"""test if modified flag is well set"""
33+
# 1. set item
34+
ta = new_tracked_array()
35+
ta[0] = 1
36+
assert ta.modified
37+
38+
ta = new_tracked_array()
39+
ta[1, 1] = 2
40+
assert ta.modified
41+
42+
# in place
43+
ta = new_tracked_array()
44+
ta += 5
45+
assert ta.modified
46+
47+
ta = new_tracked_array()
48+
ta -= 3
49+
assert ta.modified
50+
51+
ta = new_tracked_array()
52+
ta *= 1
53+
assert ta.modified
54+
55+
ta = new_tracked_array()
56+
ta /= 1.5
57+
assert ta.modified
58+
59+
# old distributions of numpy does not have this feature
60+
if sys.version_info > (3, 9):
61+
ta = new_tracked_array()
62+
ta @= ta
63+
assert ta.modified
64+
65+
ta = new_tracked_array()
66+
ta **= 2
67+
assert ta.modified
68+
69+
ta = new_tracked_array()
70+
ta %= 3
71+
assert ta.modified
72+
73+
ta = new_tracked_array()
74+
ta //= 2
75+
assert ta.modified
76+
77+
ta = new_tracked_array(int)
78+
ta <<= 3
79+
assert ta.modified
80+
81+
ta = new_tracked_array(int)
82+
ta >>= 1
83+
assert ta.modified
84+
85+
ta = new_tracked_array(int)
86+
ta |= 3
87+
assert ta.modified
88+
89+
ta = new_tracked_array(int)
90+
ta &= 3
91+
assert ta.modified
92+
93+
ta = new_tracked_array(int)
94+
ta ^= 3
95+
assert ta.modified
96+
97+
# child array modification
98+
ta = new_tracked_array()
99+
ta_child = ta[0]
100+
assert ta_child.base is ta
101+
ta_child += 5
102+
assert ta.modified
103+
assert ta_child.modified
104+
105+
# copy returns normal np.ndarray
106+
assert isinstance(new_tracked_array().copy(), np.ndarray)
107+
108+
109+
def test_DataHolder():
110+
"""Base class of dataholder types"""
111+
112+
class Helpee:
113+
pass
114+
115+
helpee = Helpee()
116+
117+
dataholder = gustaf.helpers.data.DataHolder(helpee)
118+
119+
# setitem is pure abstract
120+
with pytest.raises(NotImplementedError):
121+
dataholder["somedata"] = []
122+
123+
# test other functions by injecting some keys and values directly to the
124+
# member
125+
dataholder._saved.update(a=1, b=2, c=3)
126+
127+
# getitem
128+
assert dataholder["a"] == 1
129+
assert dataholder["b"] == 2
130+
assert dataholder["c"] == 3
131+
with pytest.raises(KeyError):
132+
dataholder["d"]
133+
134+
# contains
135+
assert "a" in dataholder
136+
assert "b" in dataholder
137+
assert "c" in dataholder
138+
assert "d" not in dataholder
139+
assert "e" not in dataholder
140+
141+
# len
142+
assert len(dataholder) == 3
143+
144+
# pop
145+
assert dataholder.pop("c") == 3
146+
assert "c" not in dataholder
147+
148+
# get
149+
# 1. key
150+
assert dataholder.get("a") == 1
151+
# 2. key and default
152+
assert dataholder.get("b", 2) == 2
153+
# 3. key and wrong default
154+
assert dataholder.get("b", 3) == 2
155+
# 4. empty key - always None
156+
assert dataholder.get("c") is None
157+
# 5. empty key and default
158+
assert dataholder.get("c", "123") == "123"
159+
160+
# keys
161+
assert len(set(dataholder.keys()).difference({"a", "b"})) == 0
162+
163+
# values
164+
assert len(set(dataholder.values()).difference({1, 2})) == 0
165+
166+
# items
167+
for k, v in dataholder.items():
168+
assert k in dataholder.keys() # noqa SIM118
169+
assert v in dataholder.values()
170+
171+
# update
172+
dataholder.update(b=22, c=33, d=44)
173+
assert dataholder["a"] == 1
174+
assert dataholder["b"] == 22
175+
assert dataholder["c"] == 33
176+
assert dataholder["d"] == 44
177+
178+
# clear
179+
dataholder.clear()
180+
assert "a" not in dataholder
181+
assert "b" not in dataholder
182+
assert "c" not in dataholder
183+
assert "d" not in dataholder
184+
assert len(dataholder) == 0
185+
186+
187+
@pytest.mark.parametrize(
188+
"grid", ("edges", "faces_tri", "faces_quad", "volumes_tet", "volumes_hexa")
189+
)
190+
def test_ComputedData(grid, request):
191+
grid = request.getfixturevalue(grid)
192+
193+
# vertex related data
194+
v_data = (
195+
"unique_vertices",
196+
"bounds",
197+
"bounds_diagonal",
198+
"bounds_diagonal_norm",
199+
)
200+
201+
# element related data
202+
e_data = (
203+
"sorted_edges",
204+
"unique_edges",
205+
"single_edges",
206+
"edges",
207+
"sorted_faces",
208+
"unique_faces",
209+
"single_faces",
210+
"faces",
211+
"sorted_volumes",
212+
"unique_volumes",
213+
)
214+
215+
# for both
216+
both_data = ("centers", "referenced_vertices")
217+
218+
# entities before modification
219+
data_dependency = {"vertex": v_data, "element": e_data, "both": both_data}
220+
before = {}
221+
for dependency, attributes in data_dependency.items():
222+
# init
223+
before[dependency] = {}
224+
for attr in attributes:
225+
func = getattr(grid, attr, None)
226+
if attr is not None and callable(func):
227+
before[dependency][attr] = func()
228+
229+
# ensure that func is called at least once
230+
assert len(before[dependency]) != 0
231+
232+
# loop to check if you get the saved data
233+
for attributes in before.values():
234+
for attr, value in attributes.items():
235+
func = getattr(grid, attr, None)
236+
assert value is func()
237+
238+
# change vertices - assign new vertices
239+
grid.vertices = grid.vertices.copy()
240+
for dependency, attributes in before.items():
241+
if dependency == "element":
242+
continue
243+
for attr, value in attributes.items():
244+
func = getattr(grid, attr, None)
245+
assert value is not func() # should be different object
246+
247+
# change elements - assign new elements
248+
grid.elements = grid.elements.copy()
249+
for dependency, attributes in before.items():
250+
if dependency == "vertex":
251+
continue
252+
for attr, value in attributes.items():
253+
func = getattr(grid, attr, None)
254+
assert value is not func()
255+
256+
257+
@pytest.mark.parametrize(
258+
"grid", ("edges", "faces_tri", "faces_quad", "volumes_tet", "volumes_hexa")
259+
)
260+
def test_VertexData(grid, request):
261+
grid = request.getfixturevalue(grid)
262+
263+
key = "vertices"
264+
265+
# set data
266+
grid.vertex_data[key] = grid.vertices
267+
268+
# get_data - data is viewed as TrackedArray, so check against base
269+
assert grid.vertices is grid.vertex_data[key].base
270+
271+
# scalar extraction should return a norm
272+
assert np.allclose(
273+
grid.vertex_data.as_scalar(key).ravel(),
274+
np.linalg.norm(grid.vertex_data.get(key), axis=1),
275+
)
276+
277+
# norms should be saved, as long as data array isn't changed
278+
assert grid.vertex_data.as_scalar(key) is grid.vertex_data.as_scalar(key)
279+
280+
before = grid.vertex_data.as_scalar(key)
281+
# trigger modified flag on data - either reset or inplace change
282+
# reset first - with copy, just so that we can try to make inplace changes
283+
# later
284+
grid.vertex_data[key] = grid.vertex_data[key].copy()
285+
assert before is not grid.vertex_data.as_scalar(key)
286+
assert grid.vertex_data.as_scalar(key) is grid.vertex_data.as_scalar(key)
287+
288+
grid.vertex_data[key][0] = grid.vertex_data[key][0]
289+
assert before is not grid.vertex_data.as_scalar(key)
290+
assert grid.vertex_data.as_scalar(key) is grid.vertex_data.as_scalar(key)
291+
292+
# check arrow data
293+
assert grid.vertex_data[key] is grid.vertex_data.as_arrow(key)
294+
295+
# check wrong length assignment
296+
with pytest.raises(ValueError):
297+
grid.vertex_data["bad"] = np.vstack((grid.vertices, grid.vertices))
298+
299+
# check wrong arrow data request
300+
with pytest.raises(ValueError):
301+
grid.vertex_data["norm"] = grid.vertex_data.as_scalar(key)
302+
grid.vertex_data.as_arrow("norm")

0 commit comments

Comments
 (0)