Skip to content

Commit 76e0b56

Browse files
committed
no more safe-copy
1 parent 373e426 commit 76e0b56

File tree

5 files changed

+114
-159
lines changed

5 files changed

+114
-159
lines changed

gustaf/edges.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ def __init__(
8585
vertices=None,
8686
edges=None,
8787
elements=None,
88-
copy=True,
8988
):
9089
"""Edges. It has vertices and edges. Also known as lines.
9190
@@ -94,7 +93,7 @@ def __init__(
9493
vertices: (n, d) np.ndarray
9594
edges: (n, 2) np.ndarray
9695
"""
97-
super().__init__(vertices=vertices, copy=copy)
96+
super().__init__(vertices=vertices)
9897

9998
if edges is not None:
10099
self.edges = edges
@@ -132,7 +131,7 @@ def edges(self, es):
132131
self._logd("setting edges")
133132

134133
self._edges = helpers.data.make_tracked_array(
135-
es, settings.INT_DTYPE, self.setter_copies
134+
es, settings.INT_DTYPE, copy=False
136135
)
137136

138137
# shape check

gustaf/faces.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ def __init__(
9797
vertices=None,
9898
faces=None,
9999
elements=None,
100-
copy=True,
101100
):
102101
"""Faces. It has vertices and faces. Faces could be triangles or
103102
quadrilaterals.
@@ -107,7 +106,7 @@ def __init__(
107106
vertices: (n, d) np.ndarray
108107
faces: (n, 3) or (n, 4) np.ndarray
109108
"""
110-
super().__init__(vertices=vertices, copy=copy)
109+
super().__init__(vertices=vertices)
111110
if faces is not None:
112111
self.faces = faces
113112

@@ -176,9 +175,7 @@ def whatareyou(cls, face_obj):
176175
)
177176

178177
@property
179-
def faces(
180-
self,
181-
):
178+
def faces(self):
182179
"""Returns faces.
183180
184181
Parameters
@@ -209,7 +206,7 @@ def faces(self, fs):
209206
self._faces = helpers.data.make_tracked_array(
210207
fs,
211208
settings.INT_DTYPE,
212-
self.setter_copies,
209+
copy=False,
213210
)
214211
# shape check
215212
if fs is not None:

gustaf/helpers/data.py

Lines changed: 105 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -13,52 +13,66 @@
1313

1414

1515
class TrackedArray(np.ndarray):
16-
"""Taken from nice implementations of `trimesh` (see LICENSE.txt).
17-
`https://github.com/mikedh/trimesh/blob/main/trimesh/caching.py`. Minor
18-
adaption, since we don't have hashing functionalities.
19-
20-
All the inplace functions will set modified flag and if some operations
21-
has potential to cause un-trackable behavior, writeable flags will be set
22-
to False.
23-
24-
Note, if you really really want, it is possible to change the tracked
25-
array without setting modified flag.
16+
"""numpy array object that keeps mirroring inplace changes to the source.
17+
Meant to help control_points.
2618
"""
2719

28-
__slots__ = ("_modified", "_source")
20+
__slots__ = (
21+
"_super_arr",
22+
"_modified",
23+
)
2924

3025
def __array_finalize__(self, obj):
3126
"""Sets default flags for any arrays that maybe generated based on
32-
tracked array."""
27+
physical space array. For more information,
28+
see https://numpy.org/doc/stable/user/basics.subclassing.html"""
29+
self._super_arr = None
3330
self._modified = True
34-
self._source = 0
3531

32+
# for arrays created based on this subclass
3633
if isinstance(obj, type(self)):
37-
if isinstance(obj._source, int):
38-
self._source = obj
39-
else:
40-
self._source = obj._source
34+
# this is copy. nothing to worry here
35+
if self.base is None:
36+
return None
37+
38+
# first child array
39+
if self.base is obj:
40+
# make sure this is not a recursively born child
41+
# for example, `arr[[1,2]][:,2]`
42+
# we should have set _super_arr to True
43+
# if we made this array using `make_tracked_array`
44+
if obj._super_arr is True:
45+
self._super_arr = obj
46+
return None
47+
48+
# multi generation child array
49+
if obj._super_arr is not None and self.base is obj.base:
50+
self._super_arr = obj._super_arr
51+
return None
52+
53+
return None
4154

4255
@property
43-
def mutable(self):
44-
return self.flags["WRITEABLE"]
56+
def modified(self):
57+
"""
58+
Modified flag getter
59+
"""
60+
# have super arr and self is not super_arr,
61+
if self._super_arr is not None and self._super_arr is not True:
62+
return self._super_arr._modified
4563

46-
@mutable.setter
47-
def mutable(self, value):
48-
self.flags.writeable = value
64+
return self._modified
4965

50-
def _set_modified(self):
51-
"""set modified flags to itself and to the source."""
52-
self._modified = True
53-
if isinstance(self._source, type(self)):
54-
self._source._modified = True
55-
56-
def copy(self, *_args, **_kwargs):
57-
"""copy gives np.ndarray.
66+
@modified.setter
67+
def modified(self, m):
68+
if self._super_arr is not None and self._super_arr is not True:
69+
self._super_arr._modified = m
70+
else:
71+
self._modified = m
5872

59-
no more tracking.
60-
"""
61-
return np.array(self, copy=True)
73+
def copy(self, *args, **kwargs):
74+
"""copy creates regular numpy array"""
75+
return np.array(self, *args, copy=True, **kwargs)
6276

6377
def view(self, *args, **kwargs):
6478
"""Set writeable flags to False for the view."""
@@ -67,89 +81,88 @@ def view(self, *args, **kwargs):
6781
return v
6882

6983
def __iadd__(self, *args, **kwargs):
70-
self._set_modified()
71-
return super(self.__class__, self).__iadd__(*args, **kwargs)
84+
sr = super(self.__class__, self).__iadd__(*args, **kwargs)
85+
self.modified = True
86+
return sr
7287

7388
def __isub__(self, *args, **kwargs):
74-
self._set_modified()
75-
return super(self.__class__, self).__isub__(*args, **kwargs)
89+
sr = super(self.__class__, self).__isub__(*args, **kwargs)
90+
self.modified = True
91+
return sr
7692

7793
def __imul__(self, *args, **kwargs):
78-
self._set_modified()
79-
return super(self.__class__, self).__imul__(*args, **kwargs)
94+
sr = super(self.__class__, self).__imul__(*args, **kwargs)
95+
self.modified = True
96+
return sr
8097

8198
def __idiv__(self, *args, **kwargs):
82-
self._set_modified()
83-
return super(self.__class__, self).__idiv__(*args, **kwargs)
99+
sr = super(self.__class__, self).__idiv__(*args, **kwargs)
100+
self.modified = True
101+
return sr
84102

85103
def __itruediv__(self, *args, **kwargs):
86-
self._set_modified()
87-
return super(self.__class__, self).__itruediv__(*args, **kwargs)
104+
sr = super(self.__class__, self).__itruediv__(*args, **kwargs)
105+
self.modified = True
106+
return sr
88107

89108
def __imatmul__(self, *args, **kwargs):
90-
self._set_modified()
91-
return super(self.__class__, self).__imatmul__(*args, **kwargs)
109+
sr = super(self.__class__, self).__imatmul__(*args, **kwargs)
110+
self.modified = True
111+
return sr
92112

93113
def __ipow__(self, *args, **kwargs):
94-
self._set_modified()
95-
return super(self.__class__, self).__ipow__(*args, **kwargs)
114+
sr = super(self.__class__, self).__ipow__(*args, **kwargs)
115+
self.modified = True
116+
return sr
96117

97118
def __imod__(self, *args, **kwargs):
98-
self._set_modified()
99-
return super(self.__class__, self).__imod__(*args, **kwargs)
119+
sr = super(self.__class__, self).__imod__(*args, **kwargs)
120+
self.modified = True
121+
return sr
100122

101123
def __ifloordiv__(self, *args, **kwargs):
102-
self._set_modified()
103-
return super(self.__class__, self).__ifloordiv__(*args, **kwargs)
124+
sr = super(self.__class__, self).__ifloordiv__(*args, **kwargs)
125+
self.modified = True
126+
return sr
104127

105128
def __ilshift__(self, *args, **kwargs):
106-
self._set_modified()
107-
return super(self.__class__, self).__ilshift__(*args, **kwargs)
129+
sr = super(self.__class__, self).__ilshift__(*args, **kwargs)
130+
self.modified = True
131+
return sr
108132

109133
def __irshift__(self, *args, **kwargs):
110-
self._set_modified()
111-
return super(self.__class__, self).__irshift__(*args, **kwargs)
134+
sr = super(self.__class__, self).__irshift__(*args, **kwargs)
135+
self.modified = True
136+
return sr
112137

113138
def __iand__(self, *args, **kwargs):
114-
self._set_modified()
115-
return super(self.__class__, self).__iand__(*args, **kwargs)
139+
sr = super(self.__class__, self).__iand__(*args, **kwargs)
140+
self.modified = True
141+
return sr
116142

117143
def __ixor__(self, *args, **kwargs):
118-
self._set_modified()
119-
return super(self.__class__, self).__ixor__(*args, **kwargs)
144+
sr = super(self.__class__, self).__ixor__(*args, **kwargs)
145+
self.modified = True
146+
return sr
120147

121148
def __ior__(self, *args, **kwargs):
122-
self._set_modified()
123-
return super(self.__class__, self).__ior__(*args, **kwargs)
124-
125-
def __setitem__(self, *args, **kwargs):
126-
self._set_modified()
127-
super(self.__class__, self).__setitem__(*args, **kwargs)
128-
129-
def __setslice__(self, *args, **kwargs):
130-
self._set_modified()
131-
super(self.__class__, self).__setslice__(*args, **kwargs)
149+
sr = super(self.__class__, self).__ior__(*args, **kwargs)
150+
self.modified = True
151+
return sr
132152

133-
def __getslice__(self, *args, **kwargs):
134-
self._set_modified()
135-
"""
136-
return slices I am pretty sure np.ndarray does not have __*slice__
137-
"""
138-
slices = super(self.__class__, self).__getitem__(*args, **kwargs)
139-
if isinstance(slices, np.ndarray):
140-
slices.flags.writeable = False
141-
return slices
153+
def __setitem__(self, key, value):
154+
# set first. invalid setting will cause error
155+
sr = super(self.__class__, self).__setitem__(key, value)
156+
self.modified = True
157+
return sr
142158

143159

144160
def make_tracked_array(array, dtype=None, copy=True):
145-
"""Taken from nice implementations of `trimesh` (see LICENSE.txt).
161+
"""Motivated by nice implementations of `trimesh` (see LICENSE.txt).
146162
`https://github.com/mikedh/trimesh/blob/main/trimesh/caching.py`.
147163
148-
``Properly subclass a numpy ndarray to track changes.
149-
Avoids some pitfalls of subclassing by forcing contiguous
150-
arrays and does a view into a TrackedArray.``
151-
152164
Factory-like wrapper function for TrackedArray.
165+
If you want to use TrackedArray, it is recommended to use this function.
153166
154167
Parameters
155168
------------
@@ -168,16 +181,16 @@ def make_tracked_array(array, dtype=None, copy=True):
168181
# if someone passed us None, just create an empty array
169182
if array is None:
170183
array = []
171-
# make sure it is contiguous then view it as our subclass
172-
tracked = np.ascontiguousarray(array, dtype=dtype)
173-
tracked = (
174-
tracked.copy().view(TrackedArray)
175-
if copy
176-
else tracked.view(TrackedArray)
177-
)
178184

179-
# should always be contiguous here
180-
assert tracked.flags["C_CONTIGUOUS"]
185+
if copy:
186+
array = np.array(array, dtype=dtype)
187+
else:
188+
array = np.asanyarray(array, dtype=dtype)
189+
190+
tracked = array.view(TrackedArray)
191+
192+
# this marks original array
193+
tracked._super_arr = True
181194

182195
return tracked
183196

0 commit comments

Comments
 (0)