@@ -615,40 +615,37 @@ def svd(
615615 Vh_odd = odd_tensor .new_zeros ((0 , odd_right ))
616616
617617 n_even , n_odd = S_even .shape [0 ], S_odd .shape [0 ]
618- total = n_even + n_odd
619618
620619 if cutoff is None :
621- cutoff = ( n_even , n_odd )
620+ k_even , k_odd = n_even , n_odd
622621 elif isinstance (cutoff , int ):
623- assert total >= 0 , "Invalid total singular values."
624- if total == 0 :
622+ if n_even == 0 and n_odd == 0 :
625623 raise RuntimeError ("Both parity block are empty. Can not form SVD." )
626- k = min (int (cutoff ), total )
627- assert k > 0 , f"Cutoff must be greater than 0, but got { k } "
628-
629- cutoff = (cutoff , cutoff )
630- cutoff = (min (cutoff [0 ], n_even ), min (cutoff [1 ], n_odd ))
631-
632- if isinstance (cutoff , tuple ):
633- k_even = max (0 , min (int (cutoff [0 ]), n_even ))
634- k_odd = max (0 , min (int (cutoff [1 ]), n_odd ))
624+ assert cutoff > 0 , f"Cutoff must be greater than 0, but got { cutoff } "
625+ k_even = min (cutoff , n_even )
626+ k_odd = min (cutoff , n_odd )
627+ elif isinstance (cutoff , tuple ):
628+ assert len (cutoff ) == 2 , "The length of cutoff must be 2 if cutoff is a tuple."
635629 if n_even == 0 and n_odd == 0 :
636630 raise RuntimeError ("Both parity block are empty. Can not form SVD." )
637- assert (k_even > 0 or n_even == 0 ) and (k_odd > 0 or n_odd == 0 ), (
638- "Per-block cutoff must be compatible with available singulars"
639- )
640-
641- keep_even = torch .zeros (n_even , dtype = torch .bool , device = S_even .device )
642- keep_odd = torch .zeros (n_odd , dtype = torch .bool , device = S_odd .device )
643- if k_even > 0 :
644- keep_even [:k_even ] = True
645- if k_odd > 0 :
646- keep_odd [:k_odd ] = True
631+ k_even = max (0 , min (int (cutoff [0 ]), n_even ))
632+ k_odd = max (0 , min (int (cutoff [1 ]), n_odd ))
647633 else :
648634 raise ValueError (
649635 f"Cutoff must be an integer or a tuple of two integers, but got { cutoff } "
650636 )
651637
638+ assert (k_even > 0 or n_even == 0 ) and (k_odd > 0 or n_odd == 0 ), (
639+ "Per-block cutoff must be compatible with available singulars"
640+ )
641+
642+ keep_even = torch .zeros (n_even , dtype = torch .bool , device = S_even .device )
643+ keep_odd = torch .zeros (n_odd , dtype = torch .bool , device = S_odd .device )
644+ if k_even > 0 :
645+ keep_even [:k_even ] = True
646+ if k_odd > 0 :
647+ keep_odd [:k_odd ] = True
648+
652649 U_even_trunc = U_even [:, keep_even ]
653650 S_even_trunc = S_even [keep_even ]
654651 Vh_even_trunc = Vh_even [keep_even , :]
0 commit comments