99import dataclasses
1010import functools
1111import typing
12+ import math
1213import torch
1314
1415
@@ -295,9 +296,28 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
295296 tensor = self .tensor .reshape (())
296297 return GrassmannTensor (_arrow = (), _edges = (), _tensor = tensor )
297298
299+ if new_shape == (1 ,) and int (self .tensor .numel ()) == 1 :
300+ eo = self ._calculate_even_odd ()
301+ new_shape = (eo ,)
302+
298303 cursor_plan : int = 0
299304 cursor_self : int = 0
300305 while cursor_plan != len (new_shape ) or cursor_self != self .tensor .dim ():
306+ if cursor_self == self .tensor .dim () and cursor_plan != len (new_shape ):
307+ new_shape_check = new_shape [cursor_plan ]
308+ if (isinstance (new_shape_check , int ) and new_shape_check == 1 ) or (
309+ new_shape_check == (1 , 0 )
310+ ):
311+ arrow .append (False )
312+ edges .append ((1 , 0 ))
313+ shape .append (1 )
314+ cursor_plan += 1
315+ continue
316+ raise AssertionError (
317+ "New shape exceeds after exhausting self dimensions: "
318+ f"edges={ self .edges } , new_shape={ new_shape } "
319+ )
320+
301321 if cursor_plan != len (new_shape ) and new_shape [cursor_plan ] == - 1 :
302322 # Does not change
303323 arrow .append (self .arrow [cursor_self ])
@@ -306,7 +326,11 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
306326 cursor_self += 1
307327 cursor_plan += 1
308328 continue
309- elif cursor_plan != len (new_shape ) and new_shape [cursor_plan ] == (1 , 0 ):
329+ elif (
330+ cursor_plan != len (new_shape )
331+ and new_shape [cursor_plan ] == (1 , 0 )
332+ and cursor_plan < len (new_shape ) - 1
333+ ):
310334 # A trivial plan edge
311335 arrow .append (False )
312336 edges .append ((1 , 0 ))
@@ -532,6 +556,146 @@ def matmul(self, other: GrassmannTensor) -> GrassmannTensor:
532556 _tensor = tensor ,
533557 )
534558
def svd(
    self,
    free_names_u: tuple[int, ...],
    *,
    cutoff: int | None | tuple[int, int] = None,
) -> tuple[GrassmannTensor, GrassmannTensor, GrassmannTensor]:
    """
    Compute a singular value decomposition of this tensor, one parity block at a time.

    The legs listed in ``free_names_u`` form the left group (they end up on U); all
    remaining legs, in increasing order, form the right group (they end up on Vh).
    The decomposition proceeds as follows:

    1. Permute the legs so that the left group comes first.
    2. Merge each group into a single edge, giving a two-leg tensor.
    3. Slice out the even block (top-left) and the odd block (bottom-right);
       the off-diagonal blocks are not read.
    4. Run ``torch.linalg.svd(..., full_matrices=False)`` on every non-empty block.
    5. Truncate each block independently according to ``cutoff``.
    6. Assemble block-diagonal U and Vh and a diagonal S.
    7. Split the merged legs of U and Vh back into the original edges.

    Args:
        free_names_u: Indices of the legs that go to U. Must be distinct and lie in
            ``range(self.tensor.dim())``.
        cutoff: Truncation rule, applied per parity block (not globally).
            ``None`` keeps every singular value. An ``int`` (must be positive) keeps
            at most that many leading singular values in the even block and at most
            that many in the odd block. A tuple ``(even, odd)`` gives a separate
            limit for each block. Every non-empty block must keep at least one value.

    Returns:
        ``(U, S, Vh)``. U carries the left legs followed by a new bond leg, S is a
        two-leg tensor holding the kept singular values on its diagonal (even block
        first, then odd), and Vh carries a bond leg followed by the right legs.

    Raises:
        AssertionError: If the legs are invalid, a tuple cutoff does not have length
            two, an integer cutoff is not positive, or a non-empty block would keep
            no singular value.
        RuntimeError: If a cutoff is given but neither block has any singular value.
        ValueError: If ``cutoff`` is neither ``None``, an ``int`` nor a tuple.

    Note:
        The returned U and Vh are not unique, nor are they continuous with respect to
        self, so different hardware and software may produce different singular
        vectors. Gradients computed through U or Vh are only finite when there are no
        repeated singular values, and become numerically unstable when two singular
        values are close to each other.
    """

    def block_svd(
        block: torch.Tensor, rows: int, cols: int
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Reduced SVD of one parity block; an empty block yields empty factors.
        if block.numel() > 0:
            u, s, vh = torch.linalg.svd(block, full_matrices=False)
            return u, s, vh
        return (
            block.new_zeros((rows, 0)),
            block.new_zeros((0,)),
            block.new_zeros((0, cols)),
        )

    rank = self.tensor.dim()
    left_legs = tuple(int(i) for i in free_names_u)
    right_legs = tuple(i for i in range(rank) if i not in left_legs)
    assert set(left_legs) | set(right_legs) == set(range(rank)), (
        "Left/right must cover all tensor legs."
    )
    # A repeated leg would pass the coverage check above but cannot form a permutation.
    assert len(set(left_legs)) == len(left_legs), "Left legs must not contain duplicates."

    if isinstance(cutoff, tuple):
        assert len(cutoff) == 2, "The length of cutoff must be 2 if cutoff is a tuple."

    # Bring the left group to the front, then merge each group into one edge.
    tensor = self.permute(left_legs + right_legs)
    split_at = len(left_legs)
    left_dim = math.prod(tensor.tensor.shape[:split_at])
    right_dim = math.prod(tensor.tensor.shape[split_at:])
    tensor = tensor.reshape((left_dim, right_dim))

    (even_left, odd_left) = tensor.edges[0]
    (even_right, odd_right) = tensor.edges[1]
    even_tensor = tensor.tensor[:even_left, :even_right]
    odd_tensor = tensor.tensor[even_left:, even_right:]

    U_even, S_even, Vh_even = block_svd(even_tensor, even_left, even_right)
    U_odd, S_odd, Vh_odd = block_svd(odd_tensor, odd_left, odd_right)

    n_even, n_odd = S_even.shape[0], S_odd.shape[0]

    # Decide how many singular values survive in each block.
    if cutoff is None:
        k_even, k_odd = n_even, n_odd
    elif isinstance(cutoff, int):
        if n_even == 0 and n_odd == 0:
            raise RuntimeError("Both parity block are empty. Can not form SVD.")
        assert cutoff > 0, f"Cutoff must be greater than 0, but got {cutoff}"
        # Clamp at zero like the tuple branch so a slice below never sees a negative bound.
        k_even = max(0, min(cutoff, n_even))
        k_odd = max(0, min(cutoff, n_odd))
    elif isinstance(cutoff, tuple):
        if n_even == 0 and n_odd == 0:
            raise RuntimeError("Both parity block are empty. Can not form SVD.")
        k_even = max(0, min(int(cutoff[0]), n_even))
        k_odd = max(0, min(int(cutoff[1]), n_odd))
    else:
        raise ValueError(
            f"Cutoff must be an integer or a tuple of two integers, but got {cutoff}"
        )

    assert (k_even > 0 or n_even == 0) and (k_odd > 0 or n_odd == 0), (
        "Per-block cutoff must be compatible with available singulars"
    )

    # torch.linalg.svd returns singular values in descending order, so keeping the
    # first k entries of each block keeps its k largest singular values.
    U_even_trunc = U_even[:, :k_even]
    S_even_trunc = S_even[:k_even]
    Vh_even_trunc = Vh_even[:k_even, :]

    U_odd_trunc = U_odd[:, :k_odd]
    S_odd_trunc = S_odd[:k_odd]
    Vh_odd_trunc = Vh_odd[:k_odd, :]

    U_tensor = torch.block_diag(U_even_trunc, U_odd_trunc)  # type: ignore[no-untyped-call]
    S_tensor = torch.cat([S_even_trunc, S_odd_trunc], dim=0)
    Vh_tensor = torch.block_diag(Vh_even_trunc, Vh_odd_trunc)  # type: ignore[no-untyped-call]

    # Each edge is an (even, odd) pair of dimensions.
    U_edges = (
        (U_even_trunc.shape[0], U_odd_trunc.shape[0]),
        (U_even_trunc.shape[1], U_odd_trunc.shape[1]),
    )
    S_edges = (
        (U_even_trunc.shape[1], U_odd_trunc.shape[1]),
        (Vh_even_trunc.shape[0], Vh_odd_trunc.shape[0]),
    )
    Vh_edges = (
        (Vh_even_trunc.shape[0], Vh_odd_trunc.shape[0]),
        (Vh_even_trunc.shape[1], Vh_odd_trunc.shape[1]),
    )

    U = GrassmannTensor(_arrow=(True, True), _edges=U_edges, _tensor=U_tensor)
    S = GrassmannTensor(
        _arrow=(False, True),
        _edges=S_edges,
        _tensor=torch.diag(S_tensor),
    )
    Vh = GrassmannTensor(_arrow=(False, True), _edges=Vh_edges, _tensor=Vh_tensor)

    # Split the merged legs back into the edges of the original tensor.
    left_arrow = [self.arrow[i] for i in left_legs]
    left_edges = [self.edges[i] for i in left_legs]

    right_arrow = [self.arrow[i] for i in right_legs]
    right_edges = [self.edges[i] for i in right_legs]

    # After each split, overwrite the arrows with the original leg arrows plus the
    # bond arrow (True on U, False on Vh).
    U = U.reshape((*left_edges, U_edges[1]))
    U._arrow = (*left_arrow, True)

    Vh = Vh.reshape((Vh_edges[0], *right_edges))
    Vh._arrow = (False, *right_arrow)

    return U, S, Vh
698+
535699 def __post_init__ (self ) -> None :
536700 assert len (self ._arrow ) == self ._tensor .dim (), (
537701 f"Arrow length ({ len (self ._arrow )} ) must match tensor dimensions ({ self ._tensor .dim ()} )."
0 commit comments