Skip to content

Commit a0ad401

Browse files
committed
Merge branch 'dev/use-replace-function'
2 parents c93d50a + 34a704e commit a0ad401

File tree

1 file changed

+32
-59
lines changed

1 file changed

+32
-59
lines changed

parity_tensor/parity_tensor.py

Lines changed: 32 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ def to(self, device: torch.device) -> ParityTensor:
6060
"""
6161
Copy the tensor to a specified device.
6262
"""
63-
return ParityTensor(
64-
_edges=self._edges,
63+
return dataclasses.replace(
64+
self,
6565
_tensor=self._tensor.to(device),
6666
_parity=tuple(p.to(device) for p in self._parity) if self._parity is not None else None,
6767
_mask=self._mask.to(device) if self._mask is not None else None,
@@ -96,7 +96,8 @@ def permute(self, before_by_after: tuple[int, ...]) -> ParityTensor:
9696
)
9797
tensor = torch.where(total_parity, -tensor, +tensor)
9898

99-
return ParityTensor(
99+
return dataclasses.replace(
100+
self,
100101
_edges=edges,
101102
_tensor=tensor,
102103
_parity=parity,
@@ -128,40 +129,32 @@ def _validate_edge_compatibility(self, other: ParityTensor) -> None:
128129
assert self._edges == other.edges, f"Edges must match for arithmetic operations. Got {self._edges} and {other.edges}."
129130

130131
def __pos__(self) -> ParityTensor:
131-
return ParityTensor(
132-
_edges=self._edges,
132+
return dataclasses.replace(
133+
self,
133134
_tensor=+self._tensor,
134-
_parity=self._parity,
135-
_mask=self._mask,
136135
)
137136

138137
def __neg__(self) -> ParityTensor:
139-
return ParityTensor(
140-
_edges=self._edges,
138+
return dataclasses.replace(
139+
self,
141140
_tensor=-self._tensor,
142-
_parity=self._parity,
143-
_mask=self._mask,
144141
)
145142

146143
def __add__(self, other: typing.Any) -> ParityTensor:
147144
if isinstance(other, ParityTensor):
148145
self._validate_edge_compatibility(other)
149-
return ParityTensor(
150-
_edges=self._edges,
146+
return dataclasses.replace(
147+
self,
151148
_tensor=self._tensor + other._tensor,
152-
_parity=self._parity,
153-
_mask=self._mask,
154149
)
155150
try:
156151
result = self._tensor + other
157152
except TypeError:
158153
return NotImplemented
159154
if isinstance(result, torch.Tensor):
160-
return ParityTensor(
161-
_edges=self._edges,
155+
return dataclasses.replace(
156+
self,
162157
_tensor=result,
163-
_parity=self._parity,
164-
_mask=self._mask,
165158
)
166159
return NotImplemented
167160

@@ -171,11 +164,9 @@ def __radd__(self, other: typing.Any) -> ParityTensor:
171164
except TypeError:
172165
return NotImplemented
173166
if isinstance(result, torch.Tensor):
174-
return ParityTensor(
175-
_edges=self._edges,
167+
return dataclasses.replace(
168+
self,
176169
_tensor=result,
177-
_parity=self._parity,
178-
_mask=self._mask,
179170
)
180171
return NotImplemented
181172

@@ -190,22 +181,18 @@ def __iadd__(self, other: typing.Any) -> ParityTensor:
190181
def __sub__(self, other: typing.Any) -> ParityTensor:
191182
if isinstance(other, ParityTensor):
192183
self._validate_edge_compatibility(other)
193-
return ParityTensor(
194-
_edges=self._edges,
184+
return dataclasses.replace(
185+
self,
195186
_tensor=self._tensor - other._tensor,
196-
_parity=self._parity,
197-
_mask=self._mask,
198187
)
199188
try:
200189
result = self._tensor - other
201190
except TypeError:
202191
return NotImplemented
203192
if isinstance(result, torch.Tensor):
204-
return ParityTensor(
205-
_edges=self._edges,
193+
return dataclasses.replace(
194+
self,
206195
_tensor=result,
207-
_parity=self._parity,
208-
_mask=self._mask,
209196
)
210197
return NotImplemented
211198

@@ -215,11 +202,9 @@ def __rsub__(self, other: typing.Any) -> ParityTensor:
215202
except TypeError:
216203
return NotImplemented
217204
if isinstance(result, torch.Tensor):
218-
return ParityTensor(
219-
_edges=self._edges,
205+
return dataclasses.replace(
206+
self,
220207
_tensor=result,
221-
_parity=self._parity,
222-
_mask=self._mask,
223208
)
224209
return NotImplemented
225210

@@ -234,22 +219,18 @@ def __isub__(self, other: typing.Any) -> ParityTensor:
234219
def __mul__(self, other: typing.Any) -> ParityTensor:
235220
if isinstance(other, ParityTensor):
236221
self._validate_edge_compatibility(other)
237-
return ParityTensor(
238-
_edges=self._edges,
222+
return dataclasses.replace(
223+
self,
239224
_tensor=self._tensor * other._tensor,
240-
_parity=self._parity,
241-
_mask=self._mask,
242225
)
243226
try:
244227
result = self._tensor * other
245228
except TypeError:
246229
return NotImplemented
247230
if isinstance(result, torch.Tensor):
248-
return ParityTensor(
249-
_edges=self._edges,
231+
return dataclasses.replace(
232+
self,
250233
_tensor=result,
251-
_parity=self._parity,
252-
_mask=self._mask,
253234
)
254235
return NotImplemented
255236

@@ -259,11 +240,9 @@ def __rmul__(self, other: typing.Any) -> ParityTensor:
259240
except TypeError:
260241
return NotImplemented
261242
if isinstance(result, torch.Tensor):
262-
return ParityTensor(
263-
_edges=self._edges,
243+
return dataclasses.replace(
244+
self,
264245
_tensor=result,
265-
_parity=self._parity,
266-
_mask=self._mask,
267246
)
268247
return NotImplemented
269248

@@ -278,22 +257,18 @@ def __imul__(self, other: typing.Any) -> ParityTensor:
278257
def __truediv__(self, other: typing.Any) -> ParityTensor:
279258
if isinstance(other, ParityTensor):
280259
self._validate_edge_compatibility(other)
281-
return ParityTensor(
282-
_edges=self._edges,
260+
return dataclasses.replace(
261+
self,
283262
_tensor=self._tensor / other._tensor,
284-
_parity=self._parity,
285-
_mask=self._mask,
286263
)
287264
try:
288265
result = self._tensor / other
289266
except TypeError:
290267
return NotImplemented
291268
if isinstance(result, torch.Tensor):
292-
return ParityTensor(
293-
_edges=self._edges,
269+
return dataclasses.replace(
270+
self,
294271
_tensor=result,
295-
_parity=self._parity,
296-
_mask=self._mask,
297272
)
298273
return NotImplemented
299274

@@ -303,11 +278,9 @@ def __rtruediv__(self, other: typing.Any) -> ParityTensor:
303278
except TypeError:
304279
return NotImplemented
305280
if isinstance(result, torch.Tensor):
306-
return ParityTensor(
307-
_edges=self._edges,
281+
return dataclasses.replace(
282+
self,
308283
_tensor=result,
309-
_parity=self._parity,
310-
_mask=self._mask,
311284
)
312285
return NotImplemented
313286

0 commit comments

Comments
 (0)