Skip to content

Commit ae77625

Browse files
committed
Add field parity, and copy parity and mask in arithmetic operators.
1 parent da41a99 commit ae77625

File tree

1 file changed

+120
-54
lines changed

1 file changed

+120
-54
lines changed

parity_tensor/parity_tensor.py

Lines changed: 120 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,54 @@ class ParityTensor:
1919
Each dimension of the tensor is composed of an even and an odd part, represented as a pair of integers.
2020
"""
2121

22-
edges: tuple[tuple[int, int], ...]
23-
tensor: torch.Tensor
24-
mask: torch.Tensor | None = None
22+
_edges: tuple[tuple[int, int], ...]
23+
_tensor: torch.Tensor
24+
_parity: tuple[torch.Tensor, ...] | None = None
25+
_mask: torch.Tensor | None = None
26+
27+
@property
28+
def edges(self) -> tuple[tuple[int, int], ...]:
29+
"""
30+
The edges of the tensor, represented as a tuple of pairs (even, odd).
31+
"""
32+
return self._edges
33+
34+
@property
35+
def tensor(self) -> torch.Tensor:
36+
"""
37+
The underlying tensor data.
38+
"""
39+
return self._tensor
40+
41+
@property
42+
def parity(self) -> tuple[torch.Tensor, ...]:
43+
"""
44+
The parity of each edge, represented as a tuple of tensors.
45+
"""
46+
if self._parity is None:
47+
self._parity = tuple(self._edge_mask(even, odd) for (even, odd) in self._edges)
48+
return self._parity
49+
50+
@property
51+
def mask(self) -> torch.Tensor:
52+
"""
53+
The mask of the tensor, which has the same shape as the tensor and indicates which elements could be non-zero based on the parity.
54+
"""
55+
if self._mask is None:
56+
self._mask = self._tensor_mask()
57+
return self._mask
58+
59+
def update_mask(self) -> ParityTensor:
60+
"""
61+
Update the mask of the tensor based on its parity.
62+
"""
63+
self._tensor = torch.where(self.mask, self._tensor, 0)
64+
return self
2565

2666
def __post_init__(self) -> None:
27-
assert len(self.edges) == self.tensor.dim(), f"Edges length ({len(self.edges)}) must match tensor dimensions ({self.tensor.dim()})."
28-
for dim, (even, odd) in zip(self.tensor.shape, self.edges):
67+
assert len(self._edges) == self._tensor.dim(), f"Edges length ({len(self._edges)}) must match tensor dimensions ({self._tensor.dim()})."
68+
for dim, (even, odd) in zip(self._tensor.shape, self._edges):
2969
assert even >= 0 and odd >= 0 and dim == even + odd, f"Dimension {dim} must equal sum of even ({even}) and odd ({odd}) parts, and both must be non-negative."
30-
if self.mask is None:
31-
self.mask = self._tensor_mask()
3270

3371
@classmethod
3472
def _unqueeze(cls, tensor: torch.Tensor, index: int, dim: int) -> torch.Tensor:
@@ -41,176 +79,204 @@ def _edge_mask(cls, even: int, odd: int) -> torch.Tensor:
4179
def _tensor_mask(self) -> torch.Tensor:
4280
return functools.reduce(
4381
torch.logical_xor,
44-
(self._unqueeze(self._edge_mask(even, odd), index, self.tensor.dim()) for index, (even, odd) in enumerate(self.edges)),
45-
torch.ones_like(self.tensor, dtype=torch.bool),
82+
(self._unqueeze(parity, index, self._tensor.dim()) for index, parity in enumerate(self.parity)),
83+
torch.ones_like(self._tensor, dtype=torch.bool),
4684
)
4785

4886
def _validate_edge_compatibility(self, other: ParityTensor) -> None:
4987
"""
5088
Validate that the edges of two ParityTensor instances are compatible for arithmetic operations.
5189
"""
52-
assert self.edges == other.edges, f"Edges must match for arithmetic operations. Got {self.edges} and {other.edges}."
90+
assert self._edges == other.edges, f"Edges must match for arithmetic operations. Got {self._edges} and {other.edges}."
5391

5492
def __pos__(self) -> ParityTensor:
5593
return ParityTensor(
56-
edges=self.edges,
57-
tensor=+self.tensor,
94+
_edges=self._edges,
95+
_tensor=+self._tensor,
96+
_parity=self._parity,
97+
_mask=self._mask,
5898
)
5999

60100
def __neg__(self) -> ParityTensor:
61101
return ParityTensor(
62-
edges=self.edges,
63-
tensor=-self.tensor,
102+
_edges=self._edges,
103+
_tensor=-self._tensor,
104+
_parity=self._parity,
105+
_mask=self._mask,
64106
)
65107

66108
def __add__(self, other: typing.Any) -> ParityTensor:
67109
if isinstance(other, ParityTensor):
68110
self._validate_edge_compatibility(other)
69111
return ParityTensor(
70-
edges=self.edges,
71-
tensor=self.tensor + other.tensor,
112+
_edges=self._edges,
113+
_tensor=self._tensor + other._tensor,
114+
_parity=self._parity,
115+
_mask=self._mask,
72116
)
73117
try:
74-
result = self.tensor + other
118+
result = self._tensor + other
75119
except TypeError:
76120
return NotImplemented
77121
if isinstance(result, torch.Tensor):
78122
return ParityTensor(
79-
edges=self.edges,
80-
tensor=result,
123+
_edges=self._edges,
124+
_tensor=result,
125+
_parity=self._parity,
126+
_mask=self._mask,
81127
)
82128
return NotImplemented
83129

84130
def __radd__(self, other: typing.Any) -> ParityTensor:
85131
try:
86-
result = other + self.tensor
132+
result = other + self._tensor
87133
except TypeError:
88134
return NotImplemented
89135
if isinstance(result, torch.Tensor):
90136
return ParityTensor(
91-
edges=self.edges,
92-
tensor=result,
137+
_edges=self._edges,
138+
_tensor=result,
139+
_parity=self._parity,
140+
_mask=self._mask,
93141
)
94142
return NotImplemented
95143

96144
def __iadd__(self, other: typing.Any) -> ParityTensor:
97145
if isinstance(other, ParityTensor):
98146
self._validate_edge_compatibility(other)
99-
self.tensor += other.tensor
147+
self._tensor += other._tensor
100148
else:
101-
self.tensor += other
149+
self._tensor += other
102150
return self
103151

104152
def __sub__(self, other: typing.Any) -> ParityTensor:
105153
if isinstance(other, ParityTensor):
106154
self._validate_edge_compatibility(other)
107155
return ParityTensor(
108-
edges=self.edges,
109-
tensor=self.tensor - other.tensor,
156+
_edges=self._edges,
157+
_tensor=self._tensor - other._tensor,
158+
_parity=self._parity,
159+
_mask=self._mask,
110160
)
111161
try:
112-
result = self.tensor - other
162+
result = self._tensor - other
113163
except TypeError:
114164
return NotImplemented
115165
if isinstance(result, torch.Tensor):
116166
return ParityTensor(
117-
edges=self.edges,
118-
tensor=result,
167+
_edges=self._edges,
168+
_tensor=result,
169+
_parity=self._parity,
170+
_mask=self._mask,
119171
)
120172
return NotImplemented
121173

122174
def __rsub__(self, other: typing.Any) -> ParityTensor:
123175
try:
124-
result = other - self.tensor
176+
result = other - self._tensor
125177
except TypeError:
126178
return NotImplemented
127179
if isinstance(result, torch.Tensor):
128180
return ParityTensor(
129-
edges=self.edges,
130-
tensor=result,
181+
_edges=self._edges,
182+
_tensor=result,
183+
_parity=self._parity,
184+
_mask=self._mask,
131185
)
132186
return NotImplemented
133187

134188
def __isub__(self, other: typing.Any) -> ParityTensor:
135189
if isinstance(other, ParityTensor):
136190
self._validate_edge_compatibility(other)
137-
self.tensor -= other.tensor
191+
self._tensor -= other._tensor
138192
else:
139-
self.tensor -= other
193+
self._tensor -= other
140194
return self
141195

142196
def __mul__(self, other: typing.Any) -> ParityTensor:
143197
if isinstance(other, ParityTensor):
144198
self._validate_edge_compatibility(other)
145199
return ParityTensor(
146-
edges=self.edges,
147-
tensor=self.tensor * other.tensor,
200+
_edges=self._edges,
201+
_tensor=self._tensor * other._tensor,
202+
_parity=self._parity,
203+
_mask=self._mask,
148204
)
149205
try:
150-
result = self.tensor * other
206+
result = self._tensor * other
151207
except TypeError:
152208
return NotImplemented
153209
if isinstance(result, torch.Tensor):
154210
return ParityTensor(
155-
edges=self.edges,
156-
tensor=result,
211+
_edges=self._edges,
212+
_tensor=result,
213+
_parity=self._parity,
214+
_mask=self._mask,
157215
)
158216
return NotImplemented
159217

160218
def __rmul__(self, other: typing.Any) -> ParityTensor:
161219
try:
162-
result = other * self.tensor
220+
result = other * self._tensor
163221
except TypeError:
164222
return NotImplemented
165223
if isinstance(result, torch.Tensor):
166224
return ParityTensor(
167-
edges=self.edges,
168-
tensor=result,
225+
_edges=self._edges,
226+
_tensor=result,
227+
_parity=self._parity,
228+
_mask=self._mask,
169229
)
170230
return NotImplemented
171231

172232
def __imul__(self, other: typing.Any) -> ParityTensor:
173233
if isinstance(other, ParityTensor):
174234
self._validate_edge_compatibility(other)
175-
self.tensor *= other.tensor
235+
self._tensor *= other._tensor
176236
else:
177-
self.tensor *= other
237+
self._tensor *= other
178238
return self
179239

180240
def __truediv__(self, other: typing.Any) -> ParityTensor:
181241
if isinstance(other, ParityTensor):
182242
self._validate_edge_compatibility(other)
183243
return ParityTensor(
184-
edges=self.edges,
185-
tensor=self.tensor / other.tensor,
244+
_edges=self._edges,
245+
_tensor=self._tensor / other._tensor,
246+
_parity=self._parity,
247+
_mask=self._mask,
186248
)
187249
try:
188-
result = self.tensor / other
250+
result = self._tensor / other
189251
except TypeError:
190252
return NotImplemented
191253
if isinstance(result, torch.Tensor):
192254
return ParityTensor(
193-
edges=self.edges,
194-
tensor=result,
255+
_edges=self._edges,
256+
_tensor=result,
257+
_parity=self._parity,
258+
_mask=self._mask,
195259
)
196260
return NotImplemented
197261

198262
def __rtruediv__(self, other: typing.Any) -> ParityTensor:
199263
try:
200-
result = other / self.tensor
264+
result = other / self._tensor
201265
except TypeError:
202266
return NotImplemented
203267
if isinstance(result, torch.Tensor):
204268
return ParityTensor(
205-
edges=self.edges,
206-
tensor=result,
269+
_edges=self._edges,
270+
_tensor=result,
271+
_parity=self._parity,
272+
_mask=self._mask,
207273
)
208274
return NotImplemented
209275

210276
def __itruediv__(self, other: typing.Any) -> ParityTensor:
211277
if isinstance(other, ParityTensor):
212278
self._validate_edge_compatibility(other)
213-
self.tensor /= other.tensor
279+
self._tensor /= other._tensor
214280
else:
215-
self.tensor /= other
281+
self._tensor /= other
216282
return self

0 commit comments

Comments
 (0)