Skip to content

Commit c0dcd6d

Browse files
committed
dev(svd): improve code readability
- Improve code readability - Remove redundant code
1 parent df748b1 commit c0dcd6d

File tree

2 files changed

+24
-46
lines changed

2 files changed

+24
-46
lines changed

grassmann_tensor/tensor.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -615,40 +615,37 @@ def svd(
615615
Vh_odd = odd_tensor.new_zeros((0, odd_right))
616616

617617
n_even, n_odd = S_even.shape[0], S_odd.shape[0]
618-
total = n_even + n_odd
619618

620619
if cutoff is None:
621-
cutoff = (n_even, n_odd)
620+
k_even, k_odd = n_even, n_odd
622621
elif isinstance(cutoff, int):
623-
assert total >= 0, "Invalid total singular values."
624-
if total == 0:
622+
if n_even == 0 and n_odd == 0:
625623
raise RuntimeError("Both parity block are empty. Can not form SVD.")
626-
k = min(int(cutoff), total)
627-
assert k > 0, f"Cutoff must be greater than 0, but got {k}"
628-
629-
cutoff = (cutoff, cutoff)
630-
cutoff = (min(cutoff[0], n_even), min(cutoff[1], n_odd))
631-
632-
if isinstance(cutoff, tuple):
633-
k_even = max(0, min(int(cutoff[0]), n_even))
634-
k_odd = max(0, min(int(cutoff[1]), n_odd))
624+
assert cutoff > 0, f"Cutoff must be greater than 0, but got {cutoff}"
625+
k_even = min(cutoff, n_even)
626+
k_odd = min(cutoff, n_odd)
627+
elif isinstance(cutoff, tuple):
628+
assert len(cutoff) == 2, "The length of cutoff must be 2 if cutoff is a tuple."
635629
if n_even == 0 and n_odd == 0:
636630
raise RuntimeError("Both parity block are empty. Can not form SVD.")
637-
assert (k_even > 0 or n_even == 0) and (k_odd > 0 or n_odd == 0), (
638-
"Per-block cutoff must be compatible with available singulars"
639-
)
640-
641-
keep_even = torch.zeros(n_even, dtype=torch.bool, device=S_even.device)
642-
keep_odd = torch.zeros(n_odd, dtype=torch.bool, device=S_odd.device)
643-
if k_even > 0:
644-
keep_even[:k_even] = True
645-
if k_odd > 0:
646-
keep_odd[:k_odd] = True
631+
k_even = max(0, min(int(cutoff[0]), n_even))
632+
k_odd = max(0, min(int(cutoff[1]), n_odd))
647633
else:
648634
raise ValueError(
649635
f"Cutoff must be an integer or a tuple of two integers, but got {cutoff}"
650636
)
651637

638+
assert (k_even > 0 or n_even == 0) and (k_odd > 0 or n_odd == 0), (
639+
"Per-block cutoff must be compatible with available singulars"
640+
)
641+
642+
keep_even = torch.zeros(n_even, dtype=torch.bool, device=S_even.device)
643+
keep_odd = torch.zeros(n_odd, dtype=torch.bool, device=S_odd.device)
644+
if k_even > 0:
645+
keep_even[:k_even] = True
646+
if k_odd > 0:
647+
keep_odd[:k_odd] = True
648+
652649
U_even_trunc = U_even[:, keep_even]
653650
S_even_trunc = S_even[keep_even]
654651
Vh_even_trunc = Vh_even[keep_even, :]

tests/svd_test.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,30 +18,17 @@
1818

1919

2020
def get_total_singular(edges: Edges, free_names_u: FreeNamesU) -> tuple[int, int]:
21-
even, odd = edges[free_names_u[0]]
22-
for i in range(1, len(free_names_u)):
23-
e, o = edges[free_names_u[i]]
24-
even, odd = even * e + odd * o, even * o + odd * e
25-
even_singular = min(even, odd)
26-
21+
even_singular = min(GrassmannTensor.calculate_even_odd(tuple(edges[i] for i in free_names_u)))
2722
set_all = set(range(len(edges)))
28-
right_idx = sorted(set_all - set(free_names_u))
29-
even, odd = edges[right_idx[0]]
30-
for i in range(1, len(right_idx)):
31-
e, o = edges[right_idx[i]]
32-
even, odd = even * e + odd * o, even * o + odd * e
33-
odd_singular = min(even, odd)
23+
remain_idx = sorted(set_all - set(free_names_u))
24+
odd_singular = min(GrassmannTensor.calculate_even_odd(tuple(edges[i] for i in remain_idx)))
3425
return even_singular, odd_singular
3526

3627

37-
def tau_for_cutoff(
38-
c: int | tuple[int, int], total: int, alpha: float = 0.8, slack: float = 1.05
39-
) -> float:
28+
def tau_for_cutoff(c: int, total: int, alpha: float = 0.8, slack: float = 1.05) -> float:
4029
cut = 0
4130
if isinstance(c, int):
4231
cut = c
43-
elif isinstance(c, tuple):
44-
cut = sum(c)
4532
lo, hi = 1e-8, 1e-1
4633
x = (total - cut) / max(1, total - 1)
4734
return (lo + (hi - lo) * (x**alpha)) * slack
@@ -80,12 +67,6 @@ def svd_cases() -> SVDCases:
8067
(even_singular, odd_singular),
8168
]
8269
for cutoff in cutoff_list:
83-
if isinstance(cutoff, int):
84-
if cutoff is not None and cutoff < 1:
85-
continue
86-
elif isinstance(cutoff, tuple):
87-
if cutoff[0] < 1 or cutoff[1] < 1:
88-
continue
8970
if cutoff is None:
9071
kept = total
9172
elif isinstance(cutoff, int):

0 commit comments

Comments
 (0)