@@ -19,16 +19,54 @@ class ParityTensor:
1919 Each dimension of the tensor is composed of an even and an odd part, represented as a pair of integers.
2020 """
2121
22- edges : tuple [tuple [int , int ], ...]
23- tensor : torch .Tensor
24- mask : torch .Tensor | None = None
22+ _edges : tuple [tuple [int , int ], ...]
23+ _tensor : torch .Tensor
24+ _parity : tuple [torch .Tensor , ...] | None = None
25+ _mask : torch .Tensor | None = None
26+
27+ @property
28+ def edges (self ) -> tuple [tuple [int , int ], ...]:
29+ """
30+ The edges of the tensor, represented as a tuple of pairs (even, odd).
31+ """
32+ return self ._edges
33+
34+ @property
35+ def tensor (self ) -> torch .Tensor :
36+ """
37+ The underlying tensor data.
38+ """
39+ return self ._tensor
40+
41+ @property
42+ def parity (self ) -> tuple [torch .Tensor , ...]:
43+ """
44+ The parity of each edge, represented as a tuple of tensors.
45+ """
46+ if self ._parity is None :
47+ self ._parity = tuple (self ._edge_mask (even , odd ) for (even , odd ) in self ._edges )
48+ return self ._parity
49+
50+ @property
51+ def mask (self ) -> torch .Tensor :
52+ """
53+ The mask of the tensor, which has the same shape as the tensor and indicates which elements could be non-zero based on the parity.
54+ """
55+ if self ._mask is None :
56+ self ._mask = self ._tensor_mask ()
57+ return self ._mask
58+
59+ def update_mask (self ) -> ParityTensor :
60+ """
61+ Update the mask of the tensor based on its parity.
62+ """
63+ self ._tensor = torch .where (self .mask , self ._tensor , 0 )
64+ return self
2565
2666 def __post_init__ (self ) -> None :
27- assert len (self .edges ) == self .tensor .dim (), f"Edges length ({ len (self .edges )} ) must match tensor dimensions ({ self .tensor .dim ()} )."
28- for dim , (even , odd ) in zip (self .tensor .shape , self .edges ):
67+ assert len (self ._edges ) == self ._tensor .dim (), f"Edges length ({ len (self ._edges )} ) must match tensor dimensions ({ self ._tensor .dim ()} )."
68+ for dim , (even , odd ) in zip (self ._tensor .shape , self ._edges ):
2969 assert even >= 0 and odd >= 0 and dim == even + odd , f"Dimension { dim } must equal sum of even ({ even } ) and odd ({ odd } ) parts, and both must be non-negative."
30- if self .mask is None :
31- self .mask = self ._tensor_mask ()
3270
3371 @classmethod
3472 def _unqueeze (cls , tensor : torch .Tensor , index : int , dim : int ) -> torch .Tensor :
@@ -41,176 +79,204 @@ def _edge_mask(cls, even: int, odd: int) -> torch.Tensor:
4179 def _tensor_mask (self ) -> torch .Tensor :
4280 return functools .reduce (
4381 torch .logical_xor ,
44- (self ._unqueeze (self . _edge_mask ( even , odd ), index , self .tensor .dim ()) for index , ( even , odd ) in enumerate (self .edges )),
45- torch .ones_like (self .tensor , dtype = torch .bool ),
82+ (self ._unqueeze (parity , index , self ._tensor .dim ()) for index , parity in enumerate (self .parity )),
83+ torch .ones_like (self ._tensor , dtype = torch .bool ),
4684 )
4785
4886 def _validate_edge_compatibility (self , other : ParityTensor ) -> None :
4987 """
5088 Validate that the edges of two ParityTensor instances are compatible for arithmetic operations.
5189 """
52- assert self .edges == other .edges , f"Edges must match for arithmetic operations. Got { self .edges } and { other .edges } ."
90+ assert self ._edges == other .edges , f"Edges must match for arithmetic operations. Got { self ._edges } and { other .edges } ."
5391
5492 def __pos__ (self ) -> ParityTensor :
5593 return ParityTensor (
56- edges = self .edges ,
57- tensor = + self .tensor ,
94+ _edges = self ._edges ,
95+ _tensor = + self ._tensor ,
96+ _parity = self ._parity ,
97+ _mask = self ._mask ,
5898 )
5999
60100 def __neg__ (self ) -> ParityTensor :
61101 return ParityTensor (
62- edges = self .edges ,
63- tensor = - self .tensor ,
102+ _edges = self ._edges ,
103+ _tensor = - self ._tensor ,
104+ _parity = self ._parity ,
105+ _mask = self ._mask ,
64106 )
65107
66108 def __add__ (self , other : typing .Any ) -> ParityTensor :
67109 if isinstance (other , ParityTensor ):
68110 self ._validate_edge_compatibility (other )
69111 return ParityTensor (
70- edges = self .edges ,
71- tensor = self .tensor + other .tensor ,
112+ _edges = self ._edges ,
113+ _tensor = self ._tensor + other ._tensor ,
114+ _parity = self ._parity ,
115+ _mask = self ._mask ,
72116 )
73117 try :
74- result = self .tensor + other
118+ result = self ._tensor + other
75119 except TypeError :
76120 return NotImplemented
77121 if isinstance (result , torch .Tensor ):
78122 return ParityTensor (
79- edges = self .edges ,
80- tensor = result ,
123+ _edges = self ._edges ,
124+ _tensor = result ,
125+ _parity = self ._parity ,
126+ _mask = self ._mask ,
81127 )
82128 return NotImplemented
83129
84130 def __radd__ (self , other : typing .Any ) -> ParityTensor :
85131 try :
86- result = other + self .tensor
132+ result = other + self ._tensor
87133 except TypeError :
88134 return NotImplemented
89135 if isinstance (result , torch .Tensor ):
90136 return ParityTensor (
91- edges = self .edges ,
92- tensor = result ,
137+ _edges = self ._edges ,
138+ _tensor = result ,
139+ _parity = self ._parity ,
140+ _mask = self ._mask ,
93141 )
94142 return NotImplemented
95143
96144 def __iadd__ (self , other : typing .Any ) -> ParityTensor :
97145 if isinstance (other , ParityTensor ):
98146 self ._validate_edge_compatibility (other )
99- self .tensor += other .tensor
147+ self ._tensor += other ._tensor
100148 else :
101- self .tensor += other
149+ self ._tensor += other
102150 return self
103151
104152 def __sub__ (self , other : typing .Any ) -> ParityTensor :
105153 if isinstance (other , ParityTensor ):
106154 self ._validate_edge_compatibility (other )
107155 return ParityTensor (
108- edges = self .edges ,
109- tensor = self .tensor - other .tensor ,
156+ _edges = self ._edges ,
157+ _tensor = self ._tensor - other ._tensor ,
158+ _parity = self ._parity ,
159+ _mask = self ._mask ,
110160 )
111161 try :
112- result = self .tensor - other
162+ result = self ._tensor - other
113163 except TypeError :
114164 return NotImplemented
115165 if isinstance (result , torch .Tensor ):
116166 return ParityTensor (
117- edges = self .edges ,
118- tensor = result ,
167+ _edges = self ._edges ,
168+ _tensor = result ,
169+ _parity = self ._parity ,
170+ _mask = self ._mask ,
119171 )
120172 return NotImplemented
121173
122174 def __rsub__ (self , other : typing .Any ) -> ParityTensor :
123175 try :
124- result = other - self .tensor
176+ result = other - self ._tensor
125177 except TypeError :
126178 return NotImplemented
127179 if isinstance (result , torch .Tensor ):
128180 return ParityTensor (
129- edges = self .edges ,
130- tensor = result ,
181+ _edges = self ._edges ,
182+ _tensor = result ,
183+ _parity = self ._parity ,
184+ _mask = self ._mask ,
131185 )
132186 return NotImplemented
133187
134188 def __isub__ (self , other : typing .Any ) -> ParityTensor :
135189 if isinstance (other , ParityTensor ):
136190 self ._validate_edge_compatibility (other )
137- self .tensor -= other .tensor
191+ self ._tensor -= other ._tensor
138192 else :
139- self .tensor -= other
193+ self ._tensor -= other
140194 return self
141195
142196 def __mul__ (self , other : typing .Any ) -> ParityTensor :
143197 if isinstance (other , ParityTensor ):
144198 self ._validate_edge_compatibility (other )
145199 return ParityTensor (
146- edges = self .edges ,
147- tensor = self .tensor * other .tensor ,
200+ _edges = self ._edges ,
201+ _tensor = self ._tensor * other ._tensor ,
202+ _parity = self ._parity ,
203+ _mask = self ._mask ,
148204 )
149205 try :
150- result = self .tensor * other
206+ result = self ._tensor * other
151207 except TypeError :
152208 return NotImplemented
153209 if isinstance (result , torch .Tensor ):
154210 return ParityTensor (
155- edges = self .edges ,
156- tensor = result ,
211+ _edges = self ._edges ,
212+ _tensor = result ,
213+ _parity = self ._parity ,
214+ _mask = self ._mask ,
157215 )
158216 return NotImplemented
159217
160218 def __rmul__ (self , other : typing .Any ) -> ParityTensor :
161219 try :
162- result = other * self .tensor
220+ result = other * self ._tensor
163221 except TypeError :
164222 return NotImplemented
165223 if isinstance (result , torch .Tensor ):
166224 return ParityTensor (
167- edges = self .edges ,
168- tensor = result ,
225+ _edges = self ._edges ,
226+ _tensor = result ,
227+ _parity = self ._parity ,
228+ _mask = self ._mask ,
169229 )
170230 return NotImplemented
171231
172232 def __imul__ (self , other : typing .Any ) -> ParityTensor :
173233 if isinstance (other , ParityTensor ):
174234 self ._validate_edge_compatibility (other )
175- self .tensor *= other .tensor
235+ self ._tensor *= other ._tensor
176236 else :
177- self .tensor *= other
237+ self ._tensor *= other
178238 return self
179239
180240 def __truediv__ (self , other : typing .Any ) -> ParityTensor :
181241 if isinstance (other , ParityTensor ):
182242 self ._validate_edge_compatibility (other )
183243 return ParityTensor (
184- edges = self .edges ,
185- tensor = self .tensor / other .tensor ,
244+ _edges = self ._edges ,
245+ _tensor = self ._tensor / other ._tensor ,
246+ _parity = self ._parity ,
247+ _mask = self ._mask ,
186248 )
187249 try :
188- result = self .tensor / other
250+ result = self ._tensor / other
189251 except TypeError :
190252 return NotImplemented
191253 if isinstance (result , torch .Tensor ):
192254 return ParityTensor (
193- edges = self .edges ,
194- tensor = result ,
255+ _edges = self ._edges ,
256+ _tensor = result ,
257+ _parity = self ._parity ,
258+ _mask = self ._mask ,
195259 )
196260 return NotImplemented
197261
198262 def __rtruediv__ (self , other : typing .Any ) -> ParityTensor :
199263 try :
200- result = other / self .tensor
264+ result = other / self ._tensor
201265 except TypeError :
202266 return NotImplemented
203267 if isinstance (result , torch .Tensor ):
204268 return ParityTensor (
205- edges = self .edges ,
206- tensor = result ,
269+ _edges = self ._edges ,
270+ _tensor = result ,
271+ _parity = self ._parity ,
272+ _mask = self ._mask ,
207273 )
208274 return NotImplemented
209275
210276 def __itruediv__ (self , other : typing .Any ) -> ParityTensor :
211277 if isinstance (other , ParityTensor ):
212278 self ._validate_edge_compatibility (other )
213- self .tensor /= other .tensor
279+ self ._tensor /= other ._tensor
214280 else :
215- self .tensor /= other
281+ self ._tensor /= other
216282 return self
0 commit comments