|
1 | 1 | import torch |
| 2 | +import pytest |
2 | 3 | import math |
| 4 | +import itertools |
| 5 | +from typing import TypeAlias, Iterable |
| 6 | + |
3 | 7 | from grassmann_tensor import GrassmannTensor |
4 | 8 |
|
| 9 | +Arrow: TypeAlias = tuple[bool, ...] |
| 10 | +Edges: TypeAlias = tuple[tuple[int, int], ...] |
| 11 | +Tensor: TypeAlias = torch.Tensor |
| 12 | +Cutoff: TypeAlias = int |
| 13 | +Tau: TypeAlias = float |
| 14 | +FreeNamesU: TypeAlias = tuple[int, ...] |
| 15 | + |
| 16 | +SVDCases = Iterable[tuple[Arrow, Edges, Tensor, Cutoff, Tau, FreeNamesU]] |
| 17 | + |
| 18 | +def get_total_singular(edges: Edges, free_names_u: FreeNamesU) -> int: |
| 19 | + even, odd = edges[free_names_u[0]] |
| 20 | + for i in range(1, len(free_names_u)): |
| 21 | + e, o = edges[free_names_u[i]] |
| 22 | + even, odd = even * e + odd * o, even * o + odd * e |
| 23 | + total_singular = min(even, odd) |
| 24 | + |
| 25 | + set_all = set(range(len(edges))) |
| 26 | + right_idx = sorted(set_all - set(free_names_u)) |
| 27 | + even, odd = edges[right_idx[0]] |
| 28 | + for i in range(1, len(right_idx)): |
| 29 | + e, o = edges[right_idx[i]] |
| 30 | + even, odd = even * e + odd * o, even * o + odd * e |
| 31 | + total_singular += min(even, odd) |
| 32 | + return total_singular |
| 33 | + |
| 34 | +def tau_for_cutoff(c: int, total: int, alpha: float = 0.8) -> float: |
| 35 | + lo, hi = 1e-8, 1e-1 |
| 36 | + x = (total - c) / max(1, total - 1) |
| 37 | + return lo + (hi - lo) * (x ** alpha) |
| 38 | + |
| 39 | +def choose_free_names(n_edges: int, limit: int = 8) -> list[FreeNamesU]: |
| 40 | + combos = [tuple(c) for r in range(1, n_edges) for c in itertools.combinations(range(n_edges), r)] |
| 41 | + return combos[:limit] |
5 | 42 |
|
6 | | -def test_svd() -> None: |
7 | | - gt = GrassmannTensor( |
8 | | - (True, True, True, True), |
9 | | - ((8, 8), (4, 4), (2, 2), (1, 1)), |
10 | | - torch.randn([16, 8, 4, 2], dtype=torch.float64), |
11 | | - ) |
12 | | - U, S, Vh = gt.svd((0, 3), cutoff=1) |
| 43 | +BASE_GT_CASES: list[tuple[Arrow, Edges, Tensor]] = [ |
| 44 | + ((True, True), ((2, 2), (4, 4)), torch.randn(4, 8, dtype=torch.float64)), |
| 45 | + ((True, True, True), ((2, 2), (4, 4), (8, 8)), torch.randn(4, 8, 16, dtype=torch.float64)), |
| 46 | + ((True, True, True, True), ((2, 2), (4, 4), (8, 8), (16, 16)), torch.randn(4, 8, 16, 32, dtype=torch.float64)), |
| 47 | +] |
| 48 | + |
| 49 | +def svd_cases() -> SVDCases: |
| 50 | + params = [] |
| 51 | + for arrow, edges, tensor in BASE_GT_CASES: |
| 52 | + for fnu in choose_free_names(len(edges)): |
| 53 | + total = get_total_singular(edges, fnu) |
| 54 | + for cutoff in [None, total, total - 1, total - 2]: |
| 55 | + if cutoff is not None and cutoff < 1: |
| 56 | + continue |
| 57 | + tau = tau_for_cutoff(cutoff or total, total) |
| 58 | + params.append( |
| 59 | + pytest.param( |
| 60 | + arrow, edges, tensor, cutoff, tau, fnu, |
| 61 | + id=f"edges={tuple(edges)}|fnu={fnu}|cut={cutoff}|tau={tau:.2e}" |
| 62 | + ) |
| 63 | + ) |
| 64 | + return params |
| 65 | + |
| 66 | +@pytest.mark.parametrize( |
| 67 | + "arrow, edges, tensor, cutoff, tau, free_names_u", |
| 68 | + svd_cases(), |
| 69 | +) |
| 70 | +@pytest.mark.repeat(20) |
| 71 | +def test_svd( |
| 72 | + arrow: Arrow, |
| 73 | + edges: Edges, |
| 74 | + tensor: Tensor, |
| 75 | + cutoff: Cutoff, |
| 76 | + tau: Tau, |
| 77 | + free_names_u: FreeNamesU, |
| 78 | +) -> None: |
| 79 | + gt = GrassmannTensor(arrow, edges, tensor) |
| 80 | + U, S, Vh = gt.svd(free_names_u, cutoff=cutoff) |
13 | 81 |
|
14 | 82 | # reshape U |
15 | | - # left_arrow = U.arrow[:-1] |
16 | 83 | left_dim = math.prod(U.tensor.shape[:-1]) |
17 | 84 | left_edge = list(U.edges[:-1]) |
18 | 85 | U = U.reshape((left_dim, -1)) |
19 | 86 |
|
20 | 87 | # reshape Vh |
21 | | - # right_arrow = Vh.arrow[1:] |
22 | 88 | right_dim = math.prod(Vh.tensor.shape[1:]) |
23 | 89 | right_edge = list(Vh.edges[1:]) |
24 | 90 | Vh = Vh.reshape((-1, right_dim)) |
25 | 91 |
|
26 | 92 | US = GrassmannTensor.matmul(U, S) |
27 | 93 | USV = GrassmannTensor.matmul(US, Vh) |
28 | 94 |
|
| 95 | + set_all = set(range(len(edges))) |
| 96 | + set_u = set(free_names_u) |
| 97 | + set_v = sorted(set_all - set_u) |
| 98 | + perm_order = list(free_names_u) + list(set_v) |
| 99 | + inv_perm = [perm_order.index(i) for i in range(len(edges))] |
| 100 | + |
29 | 101 | USV = USV.reshape(tuple(left_edge + right_edge)) |
30 | | - USV = USV.permute((0, 2, 3, 1)) |
| 102 | + USV = USV.permute(tuple(inv_perm)) |
| 103 | + |
| 104 | + masked = gt.update_mask().tensor |
| 105 | + den = masked.norm() |
| 106 | + eps = torch.finfo(masked.dtype).eps |
| 107 | + rel_err = (masked - USV.tensor).norm() / max(den, eps) |
| 108 | + assert rel_err <= tau |
| 109 | + |
| 110 | +@pytest.mark.parametrize( |
| 111 | + "arrow, edges, tensor, cutoff, tau, free_names_u", |
| 112 | + svd_cases(), |
| 113 | +) |
| 114 | +def test_svd_with_zero_cutoff( |
| 115 | + arrow: Arrow, |
| 116 | + edges: Edges, |
| 117 | + tensor: Tensor, |
| 118 | + cutoff: Cutoff, |
| 119 | + tau: Tau, |
| 120 | + free_names_u: FreeNamesU, |
| 121 | +) -> None: |
| 122 | + gt = GrassmannTensor(arrow, edges, tensor) |
| 123 | + with pytest.raises(AssertionError, match="Cutoff must be greater than 0"): |
| 124 | + _, _, _ = gt.svd(free_names_u, cutoff=0) |
| 125 | + |
31 | 126 |
|
32 | | - # assert torch.allclose(gt.update_mask().tensor, USV.tensor) |
33 | | - rel_err = (gt.update_mask().tensor - USV.tensor).norm() / gt.update_mask().tensor.norm() |
34 | | - assert rel_err < 1e-2 |
|
0 commit comments