@@ -569,6 +569,27 @@ def matmul(self, other: GrassmannTensor) -> GrassmannTensor:
569569 _tensor = tensor ,
570570 )
571571
572+ def _group_legs (
573+ self ,
574+ left_legs : typing .Iterable [int ],
575+ ) -> tuple [GrassmannTensor , tuple [int , ...], tuple [int , ...]]:
576+ left_legs = tuple (int (i ) for i in left_legs )
577+ right_legs = tuple (i for i in range (self .tensor .dim ()) if i not in left_legs )
578+ assert set (left_legs ) | set (right_legs ) == set (range (self .tensor .dim ())), (
579+ "Left/right must cover all tensor legs."
580+ )
581+
582+ order = left_legs + right_legs
583+
584+ tensor = self .permute (order )
585+
586+ left_dim = math .prod (tensor .tensor .shape [: len (left_legs )])
587+ right_dim = math .prod (tensor .tensor .shape [len (left_legs ) :])
588+
589+ tensor = tensor .reshape ((left_dim , right_dim ))
590+
591+ return tensor , left_legs , right_legs
592+
572593 def svd (
573594 self ,
574595 free_names_u : tuple [int , ...],
@@ -591,22 +612,10 @@ def svd(
591612 Furthermore, if the distance between any two singular values is close to zero, the gradient
592613 will be numerically unstable, as it depends on the singular values
593614 """
594- left_legs = tuple (int (i ) for i in free_names_u )
595- right_legs = tuple (i for i in range (self .tensor .dim ()) if i not in left_legs )
596- assert set (left_legs ) | set (right_legs ) == set (range (self .tensor .dim ())), (
597- "Left/right must cover all tensor legs."
598- )
599-
600615 if isinstance (cutoff , tuple ):
601616 assert len (cutoff ) == 2 , "The length of cutoff must be 2 if cutoff is a tuple."
602617
603- order = left_legs + right_legs
604- tensor = self .permute (order )
605-
606- left_dim = math .prod (tensor .tensor .shape [: len (left_legs )])
607- right_dim = math .prod (tensor .tensor .shape [len (left_legs ) :])
608-
609- tensor = tensor .reshape ((left_dim , right_dim ))
618+ tensor , left_legs , right_legs = self ._group_legs (free_names_u )
610619
611620 (even_left , odd_left ) = tensor .edges [0 ]
612621 (even_right , odd_right ) = tensor .edges [1 ]
@@ -709,6 +718,48 @@ def svd(
709718
710719 return U , S , Vh
711720
721+ def exponential (self , pairs : tuple [int , ...]) -> GrassmannTensor :
722+ tensor , left_legs , right_legs = self ._group_legs (pairs )
723+
724+ axes_to_reverse = [i for i in range (2 ) if tensor .arrow [i ]]
725+ if axes_to_reverse :
726+ tensor = tensor .reverse (tuple (axes_to_reverse ))
727+
728+ left_dim , right_dim = tensor .tensor .shape
729+
730+ assert left_dim == right_dim , (
731+ f"Exponential requires a square operator, but got { left_dim } x { right_dim } ."
732+ )
733+
734+ (even_left , odd_left ) = tensor .edges [0 ]
735+ (even_right , odd_right ) = tensor .edges [1 ]
736+
737+ even_tensor = tensor .tensor [:even_left , :even_right ]
738+ odd_tensor = tensor .tensor [even_left :, even_right :]
739+
740+ even_tensor_exp = torch .linalg .matrix_exp (even_tensor )
741+ odd_tensor_exp = torch .linalg .matrix_exp (odd_tensor )
742+
743+ tensor_exp = torch .block_diag (even_tensor_exp , odd_tensor_exp ) # type: ignore[no-untyped-call]
744+
745+ tensor_exp = dataclasses .replace (tensor , _tensor = tensor_exp )
746+
747+ if axes_to_reverse :
748+ tensor_exp = tensor_exp .reverse (tuple (axes_to_reverse ))
749+
750+ order = left_legs + right_legs
751+ edges_after_permute = tuple (self .edges [i ] for i in order )
752+ tensor_exp = tensor_exp .reshape (edges_after_permute )
753+
754+ inv = [0 ] * self .tensor .dim ()
755+ for new_pos , orig_idx in enumerate (order ):
756+ inv [orig_idx ] = new_pos
757+ inv_order = tuple (inv )
758+
759+ tensor_exp = tensor_exp .permute (inv_order )
760+
761+ return tensor_exp
762+
712763 def __post_init__ (self ) -> None :
713764 assert len (self ._arrow ) == self ._tensor .dim (), (
714765 f"Arrow length ({ len (self ._arrow )} ) must match tensor dimensions ({ self ._tensor .dim ()} )."
0 commit comments