Skip to content

Commit 754dce6

Browse files
GaussGauss
authored andcommitted
fix(svd): resolve coverage and type annotation issues
- Resolve coverage issues of svd function - Correct type annotations in test cases - Reformat the code of svd Signed-off-by: Gauss <[email protected]>
1 parent bfa26e2 commit 754dce6

File tree

2 files changed

+39
-24
lines changed

2 files changed

+39
-24
lines changed

grassmann_tensor/tensor.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -556,12 +556,10 @@ def svd(
556556
odd_keep_indices = top_indices[odd_mask] - S_even.shape[0]
557557

558558
keep_even = torch.zeros_like(S_even, dtype=torch.bool).to(S_even.device)
559-
if even_keep_indices.numel() > 0:
560-
keep_even[even_keep_indices] = True
559+
keep_even[even_keep_indices] = True
561560

562561
keep_odd = torch.ones_like(S_odd, dtype=torch.bool).to(S_odd.device)
563-
if odd_keep_indices.numel() > 0:
564-
keep_odd[odd_keep_indices] = True
562+
keep_odd[odd_keep_indices] = True
565563

566564
U_even_trunc = U_even[:, keep_even]
567565
S_even_trunc = S_even[keep_even]

tests/svd_test.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
import pytest
3+
from _pytest.mark.structures import ParameterSet
34
import math
45
import itertools
56
from typing import TypeAlias, Iterable
@@ -13,7 +14,8 @@
1314
Tau: TypeAlias = float
1415
FreeNamesU: TypeAlias = tuple[int, ...]
1516

16-
SVDCases = Iterable[tuple[Arrow, Edges, Tensor, Cutoff, Tau, FreeNamesU]]
17+
SVDCases = Iterable[ParameterSet]
18+
1719

1820
def get_total_singular(edges: Edges, free_names_u: FreeNamesU) -> int:
1921
even, odd = edges[free_names_u[0]]
@@ -31,21 +33,31 @@ def get_total_singular(edges: Edges, free_names_u: FreeNamesU) -> int:
3133
total_singular += min(even, odd)
3234
return total_singular
3335

36+
3437
def tau_for_cutoff(c: int, total: int, alpha: float = 0.8) -> float:
3538
lo, hi = 1e-8, 1e-1
3639
x = (total - c) / max(1, total - 1)
37-
return lo + (hi - lo) * (x ** alpha)
40+
return lo + (hi - lo) * (x**alpha)
41+
3842

3943
def choose_free_names(n_edges: int, limit: int = 8) -> list[FreeNamesU]:
40-
combos = [tuple(c) for r in range(1, n_edges) for c in itertools.combinations(range(n_edges), r)]
44+
combos = [
45+
tuple(c) for r in range(1, n_edges) for c in itertools.combinations(range(n_edges), r)
46+
]
4147
return combos[:limit]
4248

49+
4350
BASE_GT_CASES: list[tuple[Arrow, Edges, Tensor]] = [
4451
((True, True), ((2, 2), (4, 4)), torch.randn(4, 8, dtype=torch.float64)),
4552
((True, True, True), ((2, 2), (4, 4), (8, 8)), torch.randn(4, 8, 16, dtype=torch.float64)),
46-
((True, True, True, True), ((2, 2), (4, 4), (8, 8), (16, 16)), torch.randn(4, 8, 16, 32, dtype=torch.float64)),
53+
(
54+
(True, True, True, True),
55+
((2, 2), (4, 4), (8, 8), (16, 16)),
56+
torch.randn(4, 8, 16, 32, dtype=torch.float64),
57+
),
4758
]
4859

60+
4961
def svd_cases() -> SVDCases:
5062
params = []
5163
for arrow, edges, tensor in BASE_GT_CASES:
@@ -57,24 +69,30 @@ def svd_cases() -> SVDCases:
5769
tau = tau_for_cutoff(cutoff or total, total)
5870
params.append(
5971
pytest.param(
60-
arrow, edges, tensor, cutoff, tau, fnu,
61-
id=f"edges={tuple(edges)}|fnu={fnu}|cut={cutoff}|tau={tau:.2e}"
72+
arrow,
73+
edges,
74+
tensor,
75+
cutoff,
76+
tau,
77+
fnu,
78+
id=f"edges={tuple(edges)}|fnu={fnu}|cut={cutoff}|tau={tau:.2e}",
6279
)
6380
)
6481
return params
6582

83+
6684
@pytest.mark.parametrize(
6785
"arrow, edges, tensor, cutoff, tau, free_names_u",
6886
svd_cases(),
6987
)
7088
@pytest.mark.repeat(20)
7189
def test_svd(
72-
arrow: Arrow,
73-
edges: Edges,
74-
tensor: Tensor,
75-
cutoff: Cutoff,
76-
tau: Tau,
77-
free_names_u: FreeNamesU,
90+
arrow: Arrow,
91+
edges: Edges,
92+
tensor: Tensor,
93+
cutoff: Cutoff,
94+
tau: Tau,
95+
free_names_u: FreeNamesU,
7896
) -> None:
7997
gt = GrassmannTensor(arrow, edges, tensor)
8098
U, S, Vh = gt.svd(free_names_u, cutoff=cutoff)
@@ -107,20 +125,19 @@ def test_svd(
107125
rel_err = (masked - USV.tensor).norm() / max(den, eps)
108126
assert rel_err <= tau
109127

128+
110129
@pytest.mark.parametrize(
111130
"arrow, edges, tensor, cutoff, tau, free_names_u",
112131
svd_cases(),
113132
)
114133
def test_svd_with_zero_cutoff(
115-
arrow: Arrow,
116-
edges: Edges,
117-
tensor: Tensor,
118-
cutoff: Cutoff,
119-
tau: Tau,
120-
free_names_u: FreeNamesU,
134+
arrow: Arrow,
135+
edges: Edges,
136+
tensor: Tensor,
137+
cutoff: Cutoff,
138+
tau: Tau,
139+
free_names_u: FreeNamesU,
121140
) -> None:
122141
gt = GrassmannTensor(arrow, edges, tensor)
123142
with pytest.raises(AssertionError, match="Cutoff must be greater than 0"):
124143
_, _, _ = gt.svd(free_names_u, cutoff=0)
125-
126-

0 commit comments

Comments
 (0)