@@ -569,6 +569,27 @@ def matmul(self, other: GrassmannTensor) -> GrassmannTensor:
569569 _tensor = tensor ,
570570 )
571571
572+ def _group_edges (
573+ self ,
574+ left_legs : typing .Iterable [int ],
575+ ) -> tuple [GrassmannTensor , tuple [int , ...], tuple [int , ...]]:
576+ left_legs = tuple (int (i ) for i in left_legs )
577+ right_legs = tuple (i for i in range (self .tensor .dim ()) if i not in left_legs )
578+ assert set (left_legs ) | set (right_legs ) == set (range (self .tensor .dim ())), (
579+ "Left/right must cover all tensor legs."
580+ )
581+
582+ order = left_legs + right_legs
583+
584+ tensor = self .permute (order )
585+
586+ left_dim = math .prod (tensor .tensor .shape [: len (left_legs )])
587+ right_dim = math .prod (tensor .tensor .shape [len (left_legs ) :])
588+
589+ tensor = tensor .reshape ((left_dim , right_dim ))
590+
591+ return tensor , left_legs , right_legs
592+
572593 def svd (
573594 self ,
574595 free_names_u : tuple [int , ...],
@@ -591,22 +612,10 @@ def svd(
591612 Furthermore, if the distance between any two singular values is close to zero, the gradient
592613 will be numerically unstable, as it depends on the singular values
593614 """
594- left_legs = tuple (int (i ) for i in free_names_u )
595- right_legs = tuple (i for i in range (self .tensor .dim ()) if i not in left_legs )
596- assert set (left_legs ) | set (right_legs ) == set (range (self .tensor .dim ())), (
597- "Left/right must cover all tensor legs."
598- )
599-
600615 if isinstance (cutoff , tuple ):
601616 assert len (cutoff ) == 2 , "The length of cutoff must be 2 if cutoff is a tuple."
602617
603- order = left_legs + right_legs
604- tensor = self .permute (order )
605-
606- left_dim = math .prod (tensor .tensor .shape [: len (left_legs )])
607- right_dim = math .prod (tensor .tensor .shape [len (left_legs ) :])
608-
609- tensor = tensor .reshape ((left_dim , right_dim ))
618+ tensor , left_legs , right_legs = self ._group_edges (free_names_u )
610619
611620 (even_left , odd_left ) = tensor .edges [0 ]
612621 (even_right , odd_right ) = tensor .edges [1 ]
@@ -709,6 +718,51 @@ def svd(
709718
710719 return U , S , Vh
711720
721+ def _get_inv_order (self , order : tuple [int , ...]) -> tuple [int , ...]:
722+ inv = [0 ] * self .tensor .dim ()
723+ for new_position , origin_idx in enumerate (order ):
724+ inv [origin_idx ] = new_position
725+ return tuple (inv )
726+
727+ def exponential (self , pairs : tuple [int , ...]) -> GrassmannTensor :
728+ tensor , left_legs , right_legs = self ._group_edges (pairs )
729+
730+ edges_to_reverse = [i for i in range (2 ) if tensor .arrow [i ]]
731+ if edges_to_reverse :
732+ tensor = tensor .reverse (tuple (edges_to_reverse ))
733+
734+ left_dim , right_dim = tensor .tensor .shape
735+
736+ assert left_dim == right_dim , (
737+ f"Exponential requires a square operator, but got { left_dim } x { right_dim } ."
738+ )
739+
740+ (even_left , odd_left ) = tensor .edges [0 ]
741+ (even_right , odd_right ) = tensor .edges [1 ]
742+
743+ even_tensor = tensor .tensor [:even_left , :even_right ]
744+ odd_tensor = tensor .tensor [even_left :, even_right :]
745+
746+ even_tensor_exp = torch .linalg .matrix_exp (even_tensor )
747+ odd_tensor_exp = torch .linalg .matrix_exp (odd_tensor )
748+
749+ tensor_exp = torch .block_diag (even_tensor_exp , odd_tensor_exp ) # type: ignore[no-untyped-call]
750+
751+ tensor_exp = dataclasses .replace (tensor , _tensor = tensor_exp )
752+
753+ if edges_to_reverse :
754+ tensor_exp = tensor_exp .reverse (tuple (edges_to_reverse ))
755+
756+ order = left_legs + right_legs
757+ edges_after_permute = tuple (self .edges [i ] for i in order )
758+ tensor_exp = tensor_exp .reshape (edges_after_permute )
759+
760+ inv_order = self ._get_inv_order (order )
761+
762+ tensor_exp = tensor_exp .permute (inv_order )
763+
764+ return tensor_exp
765+
712766 def __post_init__ (self ) -> None :
713767 assert len (self ._arrow ) == self ._tensor .dim (), (
714768 f"Arrow length ({ len (self ._arrow )} ) must match tensor dimensions ({ self ._tensor .dim ()} )."
0 commit comments