@@ -60,8 +60,8 @@ def to(self, device: torch.device) -> ParityTensor:
6060 """
6161 Copy the tensor to a specified device.
6262 """
63- return ParityTensor (
64- _edges = self . _edges ,
63+ return dataclasses . replace (
64+ self ,
6565 _tensor = self ._tensor .to (device ),
6666 _parity = tuple (p .to (device ) for p in self ._parity ) if self ._parity is not None else None ,
6767 _mask = self ._mask .to (device ) if self ._mask is not None else None ,
@@ -96,7 +96,8 @@ def permute(self, before_by_after: tuple[int, ...]) -> ParityTensor:
9696 )
9797 tensor = torch .where (total_parity , - tensor , + tensor )
9898
99- return ParityTensor (
99+ return dataclasses .replace (
100+ self ,
100101 _edges = edges ,
101102 _tensor = tensor ,
102103 _parity = parity ,
@@ -128,40 +129,32 @@ def _validate_edge_compatibility(self, other: ParityTensor) -> None:
128129 assert self ._edges == other .edges , f"Edges must match for arithmetic operations. Got { self ._edges } and { other .edges } ."
129130
130131 def __pos__ (self ) -> ParityTensor :
131- return ParityTensor (
132- _edges = self . _edges ,
132+ return dataclasses . replace (
133+ self ,
133134 _tensor = + self ._tensor ,
134- _parity = self ._parity ,
135- _mask = self ._mask ,
136135 )
137136
138137 def __neg__ (self ) -> ParityTensor :
139- return ParityTensor (
140- _edges = self . _edges ,
138+ return dataclasses . replace (
139+ self ,
141140 _tensor = - self ._tensor ,
142- _parity = self ._parity ,
143- _mask = self ._mask ,
144141 )
145142
146143 def __add__ (self , other : typing .Any ) -> ParityTensor :
147144 if isinstance (other , ParityTensor ):
148145 self ._validate_edge_compatibility (other )
149- return ParityTensor (
150- _edges = self . _edges ,
146+ return dataclasses . replace (
147+ self ,
151148 _tensor = self ._tensor + other ._tensor ,
152- _parity = self ._parity ,
153- _mask = self ._mask ,
154149 )
155150 try :
156151 result = self ._tensor + other
157152 except TypeError :
158153 return NotImplemented
159154 if isinstance (result , torch .Tensor ):
160- return ParityTensor (
161- _edges = self . _edges ,
155+ return dataclasses . replace (
156+ self ,
162157 _tensor = result ,
163- _parity = self ._parity ,
164- _mask = self ._mask ,
165158 )
166159 return NotImplemented
167160
@@ -171,11 +164,9 @@ def __radd__(self, other: typing.Any) -> ParityTensor:
171164 except TypeError :
172165 return NotImplemented
173166 if isinstance (result , torch .Tensor ):
174- return ParityTensor (
175- _edges = self . _edges ,
167+ return dataclasses . replace (
168+ self ,
176169 _tensor = result ,
177- _parity = self ._parity ,
178- _mask = self ._mask ,
179170 )
180171 return NotImplemented
181172
@@ -190,22 +181,18 @@ def __iadd__(self, other: typing.Any) -> ParityTensor:
190181 def __sub__ (self , other : typing .Any ) -> ParityTensor :
191182 if isinstance (other , ParityTensor ):
192183 self ._validate_edge_compatibility (other )
193- return ParityTensor (
194- _edges = self . _edges ,
184+ return dataclasses . replace (
185+ self ,
195186 _tensor = self ._tensor - other ._tensor ,
196- _parity = self ._parity ,
197- _mask = self ._mask ,
198187 )
199188 try :
200189 result = self ._tensor - other
201190 except TypeError :
202191 return NotImplemented
203192 if isinstance (result , torch .Tensor ):
204- return ParityTensor (
205- _edges = self . _edges ,
193+ return dataclasses . replace (
194+ self ,
206195 _tensor = result ,
207- _parity = self ._parity ,
208- _mask = self ._mask ,
209196 )
210197 return NotImplemented
211198
@@ -215,11 +202,9 @@ def __rsub__(self, other: typing.Any) -> ParityTensor:
215202 except TypeError :
216203 return NotImplemented
217204 if isinstance (result , torch .Tensor ):
218- return ParityTensor (
219- _edges = self . _edges ,
205+ return dataclasses . replace (
206+ self ,
220207 _tensor = result ,
221- _parity = self ._parity ,
222- _mask = self ._mask ,
223208 )
224209 return NotImplemented
225210
@@ -234,22 +219,18 @@ def __isub__(self, other: typing.Any) -> ParityTensor:
234219 def __mul__ (self , other : typing .Any ) -> ParityTensor :
235220 if isinstance (other , ParityTensor ):
236221 self ._validate_edge_compatibility (other )
237- return ParityTensor (
238- _edges = self . _edges ,
222+ return dataclasses . replace (
223+ self ,
239224 _tensor = self ._tensor * other ._tensor ,
240- _parity = self ._parity ,
241- _mask = self ._mask ,
242225 )
243226 try :
244227 result = self ._tensor * other
245228 except TypeError :
246229 return NotImplemented
247230 if isinstance (result , torch .Tensor ):
248- return ParityTensor (
249- _edges = self . _edges ,
231+ return dataclasses . replace (
232+ self ,
250233 _tensor = result ,
251- _parity = self ._parity ,
252- _mask = self ._mask ,
253234 )
254235 return NotImplemented
255236
@@ -259,11 +240,9 @@ def __rmul__(self, other: typing.Any) -> ParityTensor:
259240 except TypeError :
260241 return NotImplemented
261242 if isinstance (result , torch .Tensor ):
262- return ParityTensor (
263- _edges = self . _edges ,
243+ return dataclasses . replace (
244+ self ,
264245 _tensor = result ,
265- _parity = self ._parity ,
266- _mask = self ._mask ,
267246 )
268247 return NotImplemented
269248
@@ -278,22 +257,18 @@ def __imul__(self, other: typing.Any) -> ParityTensor:
278257 def __truediv__ (self , other : typing .Any ) -> ParityTensor :
279258 if isinstance (other , ParityTensor ):
280259 self ._validate_edge_compatibility (other )
281- return ParityTensor (
282- _edges = self . _edges ,
260+ return dataclasses . replace (
261+ self ,
283262 _tensor = self ._tensor / other ._tensor ,
284- _parity = self ._parity ,
285- _mask = self ._mask ,
286263 )
287264 try :
288265 result = self ._tensor / other
289266 except TypeError :
290267 return NotImplemented
291268 if isinstance (result , torch .Tensor ):
292- return ParityTensor (
293- _edges = self . _edges ,
269+ return dataclasses . replace (
270+ self ,
294271 _tensor = result ,
295- _parity = self ._parity ,
296- _mask = self ._mask ,
297272 )
298273 return NotImplemented
299274
@@ -303,11 +278,9 @@ def __rtruediv__(self, other: typing.Any) -> ParityTensor:
303278 except TypeError :
304279 return NotImplemented
305280 if isinstance (result , torch .Tensor ):
306- return ParityTensor (
307- _edges = self . _edges ,
281+ return dataclasses . replace (
282+ self ,
308283 _tensor = result ,
309- _parity = self ._parity ,
310- _mask = self ._mask ,
311284 )
312285 return NotImplemented
313286
0 commit comments