@@ -569,7 +569,7 @@ def matmul(self, other: GrassmannTensor) -> GrassmannTensor:
569569 _tensor = tensor ,
570570 )
571571
572- def _group_legs (
572+ def _group_edges (
573573 self ,
574574 left_legs : typing .Iterable [int ],
575575 ) -> tuple [GrassmannTensor , tuple [int , ...], tuple [int , ...]]:
@@ -615,7 +615,7 @@ def svd(
615615 if isinstance (cutoff , tuple ):
616616 assert len (cutoff ) == 2 , "The length of cutoff must be 2 if cutoff is a tuple."
617617
618- tensor , left_legs , right_legs = self ._group_legs (free_names_u )
618+ tensor , left_legs , right_legs = self ._group_edges (free_names_u )
619619
620620 (even_left , odd_left ) = tensor .edges [0 ]
621621 (even_right , odd_right ) = tensor .edges [1 ]
@@ -718,12 +718,18 @@ def svd(
718718
719719 return U , S , Vh
720720
721+ def _get_inv_order (self , order : tuple [int , ...]) -> tuple [int , ...]:
722+ inv = [0 ] * self .tensor .dim ()
723+ for new_pos , orig_idx in enumerate (order ):
724+ inv [orig_idx ] = new_pos
725+ return tuple (inv )
726+
721727 def exponential (self , pairs : tuple [int , ...]) -> GrassmannTensor :
722- tensor , left_legs , right_legs = self ._group_legs (pairs )
728+ tensor , left_legs , right_legs = self ._group_edges (pairs )
723729
724- axes_to_reverse = [i for i in range (2 ) if tensor .arrow [i ]]
725- if axes_to_reverse :
726- tensor = tensor .reverse (tuple (axes_to_reverse ))
730+ edges_to_reverse = [i for i in range (2 ) if tensor .arrow [i ]]
731+ if edges_to_reverse :
732+ tensor = tensor .reverse (tuple (edges_to_reverse ))
727733
728734 left_dim , right_dim = tensor .tensor .shape
729735
@@ -744,17 +750,14 @@ def exponential(self, pairs: tuple[int, ...]) -> GrassmannTensor:
744750
745751 tensor_exp = dataclasses .replace (tensor , _tensor = tensor_exp )
746752
747- if axes_to_reverse :
748- tensor_exp = tensor_exp .reverse (tuple (axes_to_reverse ))
753+ if edges_to_reverse :
754+ tensor_exp = tensor_exp .reverse (tuple (edges_to_reverse ))
749755
750756 order = left_legs + right_legs
751757 edges_after_permute = tuple (self .edges [i ] for i in order )
752758 tensor_exp = tensor_exp .reshape (edges_after_permute )
753759
754- inv = [0 ] * self .tensor .dim ()
755- for new_pos , orig_idx in enumerate (order ):
756- inv [orig_idx ] = new_pos
757- inv_order = tuple (inv )
760+ inv_order = self ._get_inv_order (order )
758761
759762 tensor_exp = tensor_exp .permute (inv_order )
760763
0 commit comments