Skip to content

Commit 288a5d9

Browse files
committed
Add support to permute a tensor.
1 parent 48f6df6 commit 288a5d9

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

parity_tensor/parity_tensor.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,33 @@ def update_mask(self) -> ParityTensor:
6363
self._tensor = torch.where(self.mask, self._tensor, 0)
6464
return self
6565

66+
def permute(self, before_by_after: tuple[int, ...]) -> ParityTensor:
67+
"""
68+
Permute the indices of the parity tensor.
69+
"""
70+
edges = tuple(self.edges[i] for i in before_by_after)
71+
tensor = self.tensor.permute(before_by_after)
72+
parity = tuple(self.parity[i] for i in before_by_after)
73+
mask = self.mask.permute(before_by_after)
74+
75+
total_parity = functools.reduce(
76+
torch.logical_xor,
77+
(
78+
torch.logical_and(parity[i], parity[j])
79+
for j in range(self.tensor.dim())
80+
for i in range(0, j) # all 0 <= i < j < dim
81+
if before_by_after[i] > before_by_after[j]),
82+
torch.zeros([], dtype=torch.bool),
83+
)
84+
tensor = torch.where(total_parity, -tensor, +tensor)
85+
86+
return ParityTensor(
87+
_edges=edges,
88+
_tensor=tensor,
89+
_parity=parity,
90+
_mask=mask,
91+
)
92+
6693
def __post_init__(self) -> None:
6794
assert len(self._edges) == self._tensor.dim(), f"Edges length ({len(self._edges)}) must match tensor dimensions ({self._tensor.dim()})."
6895
for dim, (even, odd) in zip(self._tensor.shape, self._edges):

0 commit comments

Comments
 (0)