Skip to content

Commit bfa26e2

Browse files
GaussGauss
authored andcommitted
fix(svd): correct cutoff logic and add test cases for svd
- Correct cutoff logic and support None as no-trunction mode - Add parameterized test cases for svd - Fix missing coverage on exception and boundary paths Signed-off-by: Gauss <[email protected]>
1 parent 30003b4 commit bfa26e2

File tree

2 files changed

+127
-32
lines changed

2 files changed

+127
-32
lines changed

grassmann_tensor/tensor.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ def svd(
503503
2. Merge the tensor with two groups.
504504
3. Split the block tensor into two parts.
505505
4. Compute the singular value decomposition.
506-
5. Use the cutoff to cut off the lower singular values and corresponding U and V.
506+
5. Use cutoff to keep the largest cutoff singular values (globally across even/odd blocks).
507507
6. Contract U, S and Vh.
508508
7. Split the legs into original left and right.
509509
The returned tensors U and V are not unique, nor are they continuous with respect to self.
@@ -534,31 +534,34 @@ def svd(
534534
U_even, S_even, Vh_even = torch.linalg.svd(even_tensor, full_matrices=full_matrices)
535535
U_odd, S_odd, Vh_odd = torch.linalg.svd(odd_tensor, full_matrices=full_matrices)
536536

537-
if cutoff is None:
538-
cutoff = 0
539537
total = S_even.numel() + S_odd.numel()
540-
cutoff = max(0, min(cutoff, total))
541538

542-
if cutoff == 0:
543-
keep_even = torch.ones_like(S_even, dtype=torch.bool)
544-
keep_odd = torch.ones_like(S_odd, dtype=torch.bool)
539+
if cutoff is None:
540+
cutoff = total
541+
else:
542+
cutoff = min(int(cutoff), total)
543+
544+
assert cutoff > 0, f"Cutoff must be greater than 0, but got {cutoff}"
545+
546+
if cutoff == total:
547+
keep_even = torch.ones_like(S_even, dtype=torch.bool, device=S_even.device)
548+
keep_odd = torch.ones_like(S_odd, dtype=torch.bool, device=S_odd.device)
545549
else:
546550
S_cat = torch.cat([S_even, S_odd])
547-
sorted_vals, sorted_indices = S_cat.sort()
548-
cutoff_indices = sorted_indices[:cutoff]
551+
top_vals, top_indices = torch.topk(S_cat, k=cutoff, largest=True, sorted=False)
549552

550-
mask_even_indices = cutoff_indices < S_even.shape[0]
551-
mask_odd_indices = ~mask_even_indices
552-
even_cutoff_indices = cutoff_indices[mask_even_indices]
553-
odd_cutoff_indices = cutoff_indices[mask_odd_indices] - S_even.shape[0]
553+
even_mask = top_indices < S_even.shape[0]
554+
odd_mask = ~even_mask
555+
even_keep_indices = top_indices[even_mask]
556+
odd_keep_indices = top_indices[odd_mask] - S_even.shape[0]
554557

555-
keep_even = torch.ones_like(S_even, dtype=torch.bool)
556-
if even_cutoff_indices.numel() > 0:
557-
keep_even[even_cutoff_indices] = False
558+
keep_even = torch.zeros_like(S_even, dtype=torch.bool).to(S_even.device)
559+
if even_keep_indices.numel() > 0:
560+
keep_even[even_keep_indices] = True
558561

559-
keep_odd = torch.ones_like(S_odd, dtype=torch.bool)
560-
if odd_cutoff_indices.numel() > 0:
561-
keep_odd[odd_cutoff_indices] = False
562+
keep_odd = torch.ones_like(S_odd, dtype=torch.bool).to(S_odd.device)
563+
if odd_keep_indices.numel() > 0:
564+
keep_odd[odd_keep_indices] = True
562565

563566
U_even_trunc = U_even[:, keep_even]
564567
S_even_trunc = S_even[keep_even]

tests/svd_test.py

Lines changed: 105 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,126 @@
11
import torch
2+
import pytest
23
import math
4+
import itertools
5+
from typing import TypeAlias, Iterable
6+
37
from grassmann_tensor import GrassmannTensor
48

9+
Arrow: TypeAlias = tuple[bool, ...]
10+
Edges: TypeAlias = tuple[tuple[int, int], ...]
11+
Tensor: TypeAlias = torch.Tensor
12+
Cutoff: TypeAlias = int
13+
Tau: TypeAlias = float
14+
FreeNamesU: TypeAlias = tuple[int, ...]
15+
16+
SVDCases = Iterable[tuple[Arrow, Edges, Tensor, Cutoff, Tau, FreeNamesU]]
17+
18+
def get_total_singular(edges: Edges, free_names_u: FreeNamesU) -> int:
19+
even, odd = edges[free_names_u[0]]
20+
for i in range(1, len(free_names_u)):
21+
e, o = edges[free_names_u[i]]
22+
even, odd = even * e + odd * o, even * o + odd * e
23+
total_singular = min(even, odd)
24+
25+
set_all = set(range(len(edges)))
26+
right_idx = sorted(set_all - set(free_names_u))
27+
even, odd = edges[right_idx[0]]
28+
for i in range(1, len(right_idx)):
29+
e, o = edges[right_idx[i]]
30+
even, odd = even * e + odd * o, even * o + odd * e
31+
total_singular += min(even, odd)
32+
return total_singular
33+
34+
def tau_for_cutoff(c: int, total: int, alpha: float = 0.8) -> float:
35+
lo, hi = 1e-8, 1e-1
36+
x = (total - c) / max(1, total - 1)
37+
return lo + (hi - lo) * (x ** alpha)
38+
39+
def choose_free_names(n_edges: int, limit: int = 8) -> list[FreeNamesU]:
40+
combos = [tuple(c) for r in range(1, n_edges) for c in itertools.combinations(range(n_edges), r)]
41+
return combos[:limit]
542

6-
def test_svd() -> None:
7-
gt = GrassmannTensor(
8-
(True, True, True, True),
9-
((8, 8), (4, 4), (2, 2), (1, 1)),
10-
torch.randn([16, 8, 4, 2], dtype=torch.float64),
11-
)
12-
U, S, Vh = gt.svd((0, 3), cutoff=1)
43+
BASE_GT_CASES: list[tuple[Arrow, Edges, Tensor]] = [
44+
((True, True), ((2, 2), (4, 4)), torch.randn(4, 8, dtype=torch.float64)),
45+
((True, True, True), ((2, 2), (4, 4), (8, 8)), torch.randn(4, 8, 16, dtype=torch.float64)),
46+
((True, True, True, True), ((2, 2), (4, 4), (8, 8), (16, 16)), torch.randn(4, 8, 16, 32, dtype=torch.float64)),
47+
]
48+
49+
def svd_cases() -> SVDCases:
50+
params = []
51+
for arrow, edges, tensor in BASE_GT_CASES:
52+
for fnu in choose_free_names(len(edges)):
53+
total = get_total_singular(edges, fnu)
54+
for cutoff in [None, total, total - 1, total - 2]:
55+
if cutoff is not None and cutoff < 1:
56+
continue
57+
tau = tau_for_cutoff(cutoff or total, total)
58+
params.append(
59+
pytest.param(
60+
arrow, edges, tensor, cutoff, tau, fnu,
61+
id=f"edges={tuple(edges)}|fnu={fnu}|cut={cutoff}|tau={tau:.2e}"
62+
)
63+
)
64+
return params
65+
66+
@pytest.mark.parametrize(
67+
"arrow, edges, tensor, cutoff, tau, free_names_u",
68+
svd_cases(),
69+
)
70+
@pytest.mark.repeat(20)
71+
def test_svd(
72+
arrow: Arrow,
73+
edges: Edges,
74+
tensor: Tensor,
75+
cutoff: Cutoff,
76+
tau: Tau,
77+
free_names_u: FreeNamesU,
78+
) -> None:
79+
gt = GrassmannTensor(arrow, edges, tensor)
80+
U, S, Vh = gt.svd(free_names_u, cutoff=cutoff)
1381

1482
# reshape U
15-
# left_arrow = U.arrow[:-1]
1683
left_dim = math.prod(U.tensor.shape[:-1])
1784
left_edge = list(U.edges[:-1])
1885
U = U.reshape((left_dim, -1))
1986

2087
# reshape Vh
21-
# right_arrow = Vh.arrow[1:]
2288
right_dim = math.prod(Vh.tensor.shape[1:])
2389
right_edge = list(Vh.edges[1:])
2490
Vh = Vh.reshape((-1, right_dim))
2591

2692
US = GrassmannTensor.matmul(U, S)
2793
USV = GrassmannTensor.matmul(US, Vh)
2894

95+
set_all = set(range(len(edges)))
96+
set_u = set(free_names_u)
97+
set_v = sorted(set_all - set_u)
98+
perm_order = list(free_names_u) + list(set_v)
99+
inv_perm = [perm_order.index(i) for i in range(len(edges))]
100+
29101
USV = USV.reshape(tuple(left_edge + right_edge))
30-
USV = USV.permute((0, 2, 3, 1))
102+
USV = USV.permute(tuple(inv_perm))
103+
104+
masked = gt.update_mask().tensor
105+
den = masked.norm()
106+
eps = torch.finfo(masked.dtype).eps
107+
rel_err = (masked - USV.tensor).norm() / max(den, eps)
108+
assert rel_err <= tau
109+
110+
@pytest.mark.parametrize(
111+
"arrow, edges, tensor, cutoff, tau, free_names_u",
112+
svd_cases(),
113+
)
114+
def test_svd_with_zero_cutoff(
115+
arrow: Arrow,
116+
edges: Edges,
117+
tensor: Tensor,
118+
cutoff: Cutoff,
119+
tau: Tau,
120+
free_names_u: FreeNamesU,
121+
) -> None:
122+
gt = GrassmannTensor(arrow, edges, tensor)
123+
with pytest.raises(AssertionError, match="Cutoff must be greater than 0"):
124+
_, _, _ = gt.svd(free_names_u, cutoff=0)
125+
31126

32-
# assert torch.allclose(gt.update_mask().tensor, USV.tensor)
33-
rel_err = (gt.update_mask().tensor - USV.tensor).norm() / gt.update_mask().tensor.norm()
34-
assert rel_err < 1e-2

0 commit comments

Comments
 (0)