@@ -63,6 +63,33 @@ def update_mask(self) -> ParityTensor:
6363 self ._tensor = torch .where (self .mask , self ._tensor , 0 )
6464 return self
6565
66+ def permute (self , before_by_after : tuple [int , ...]) -> ParityTensor :
67+ """
68+ Permute the indices of the parity tensor.
69+ """
70+ edges = tuple (self .edges [i ] for i in before_by_after )
71+ tensor = self .tensor .permute (before_by_after )
72+ parity = tuple (self .parity [i ] for i in before_by_after )
73+ mask = self .mask .permute (before_by_after )
74+
75+ total_parity = functools .reduce (
76+ torch .logical_xor ,
77+ (
78+ torch .logical_and (parity [i ], parity [j ])
79+ for j in range (self .tensor .dim ())
80+ for i in range (0 , j ) # all 0 <= i < j < dim
81+ if before_by_after [i ] > before_by_after [j ]),
82+ torch .zeros ([], dtype = torch .bool ),
83+ )
84+ tensor = torch .where (total_parity , - tensor , + tensor )
85+
86+ return ParityTensor (
87+ _edges = edges ,
88+ _tensor = tensor ,
89+ _parity = parity ,
90+ _mask = mask ,
91+ )
92+
6693 def __post_init__ (self ) -> None :
6794 assert len (self ._edges ) == self ._tensor .dim (), f"Edges length ({ len (self ._edges )} ) must match tensor dimensions ({ self ._tensor .dim ()} )."
6895 for dim , (even , odd ) in zip (self ._tensor .shape , self ._edges ):
0 commit comments