@@ -63,6 +63,35 @@ def update_mask(self) -> ParityTensor:
6363 self ._tensor = torch .where (self .mask , self ._tensor , 0 )
6464 return self
6565
66+ def permute (self , before_by_after : tuple [int , ...]) -> ParityTensor :
67+ """
68+ Permute the indices of the parity tensor.
69+ """
70+ assert set (before_by_after ) == set (range (self .tensor .dim ())), "Permutation indices must cover all dimensions."
71+
72+ edges = tuple (self .edges [i ] for i in before_by_after )
73+ tensor = self .tensor .permute (before_by_after )
74+ parity = tuple (self .parity [i ] for i in before_by_after )
75+ mask = self .mask .permute (before_by_after )
76+
77+ total_parity = functools .reduce (
78+ torch .logical_xor ,
79+ (
80+ torch .logical_and (parity [i ], parity [j ])
81+ for j in range (self .tensor .dim ())
82+ for i in range (0 , j ) # all 0 <= i < j < dim
83+ if before_by_after [i ] > before_by_after [j ]),
84+ torch .zeros ([], dtype = torch .bool ),
85+ )
86+ tensor = torch .where (total_parity , - tensor , + tensor )
87+
88+ return ParityTensor (
89+ _edges = edges ,
90+ _tensor = tensor ,
91+ _parity = parity ,
92+ _mask = mask ,
93+ )
94+
6695 def __post_init__ (self ) -> None :
6796 assert len (self ._edges ) == self ._tensor .dim (), f"Edges length ({ len (self ._edges )} ) must match tensor dimensions ({ self ._tensor .dim ()} )."
6897 for dim , (even , odd ) in zip (self ._tensor .shape , self ._edges ):
0 commit comments