Skip to content

Commit 1d4f421

Browse files
committed
Merge branch 'dev/permute'
2 parents 48f6df6 + d02bdc5 commit 1d4f421

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

parity_tensor/parity_tensor.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,35 @@ def update_mask(self) -> ParityTensor:
6363
self._tensor = torch.where(self.mask, self._tensor, 0)
6464
return self
6565

66+
def permute(self, before_by_after: tuple[int, ...]) -> ParityTensor:
67+
"""
68+
Permute the indices of the parity tensor.
69+
"""
70+
assert set(before_by_after) == set(range(self.tensor.dim())), "Permutation indices must cover all dimensions."
71+
72+
edges = tuple(self.edges[i] for i in before_by_after)
73+
tensor = self.tensor.permute(before_by_after)
74+
parity = tuple(self.parity[i] for i in before_by_after)
75+
mask = self.mask.permute(before_by_after)
76+
77+
total_parity = functools.reduce(
78+
torch.logical_xor,
79+
(
80+
torch.logical_and(parity[i], parity[j])
81+
for j in range(self.tensor.dim())
82+
for i in range(0, j) # all 0 <= i < j < dim
83+
if before_by_after[i] > before_by_after[j]),
84+
torch.zeros([], dtype=torch.bool),
85+
)
86+
tensor = torch.where(total_parity, -tensor, +tensor)
87+
88+
return ParityTensor(
89+
_edges=edges,
90+
_tensor=tensor,
91+
_parity=parity,
92+
_mask=mask,
93+
)
94+
6695
def __post_init__(self) -> None:
6796
assert len(self._edges) == self._tensor.dim(), f"Edges length ({len(self._edges)}) must match tensor dimensions ({self._tensor.dim()})."
6897
for dim, (even, odd) in zip(self._tensor.shape, self._edges):

0 commit comments

Comments
 (0)