11import torch
22import pytest
3+ from _pytest .mark .structures import ParameterSet
34import math
45import itertools
56from typing import TypeAlias , Iterable
1314Tau : TypeAlias = float
1415FreeNamesU : TypeAlias = tuple [int , ...]
1516
16- SVDCases = Iterable [tuple [Arrow , Edges , Tensor , Cutoff , Tau , FreeNamesU ]]
17+ SVDCases = Iterable [ParameterSet ]
18+
1719
1820def get_total_singular (edges : Edges , free_names_u : FreeNamesU ) -> int :
1921 even , odd = edges [free_names_u [0 ]]
@@ -31,21 +33,31 @@ def get_total_singular(edges: Edges, free_names_u: FreeNamesU) -> int:
3133 total_singular += min (even , odd )
3234 return total_singular
3335
36+
3437def tau_for_cutoff (c : int , total : int , alpha : float = 0.8 ) -> float :
3538 lo , hi = 1e-8 , 1e-1
3639 x = (total - c ) / max (1 , total - 1 )
37- return lo + (hi - lo ) * (x ** alpha )
40+ return lo + (hi - lo ) * (x ** alpha )
41+
3842
3943def choose_free_names (n_edges : int , limit : int = 8 ) -> list [FreeNamesU ]:
40- combos = [tuple (c ) for r in range (1 , n_edges ) for c in itertools .combinations (range (n_edges ), r )]
44+ combos = [
45+ tuple (c ) for r in range (1 , n_edges ) for c in itertools .combinations (range (n_edges ), r )
46+ ]
4147 return combos [:limit ]
4248
49+
4350BASE_GT_CASES : list [tuple [Arrow , Edges , Tensor ]] = [
4451 ((True , True ), ((2 , 2 ), (4 , 4 )), torch .randn (4 , 8 , dtype = torch .float64 )),
4552 ((True , True , True ), ((2 , 2 ), (4 , 4 ), (8 , 8 )), torch .randn (4 , 8 , 16 , dtype = torch .float64 )),
46- ((True , True , True , True ), ((2 , 2 ), (4 , 4 ), (8 , 8 ), (16 , 16 )), torch .randn (4 , 8 , 16 , 32 , dtype = torch .float64 )),
53+ (
54+ (True , True , True , True ),
55+ ((2 , 2 ), (4 , 4 ), (8 , 8 ), (16 , 16 )),
56+ torch .randn (4 , 8 , 16 , 32 , dtype = torch .float64 ),
57+ ),
4758]
4859
60+
4961def svd_cases () -> SVDCases :
5062 params = []
5163 for arrow , edges , tensor in BASE_GT_CASES :
@@ -57,24 +69,30 @@ def svd_cases() -> SVDCases:
5769 tau = tau_for_cutoff (cutoff or total , total )
5870 params .append (
5971 pytest .param (
60- arrow , edges , tensor , cutoff , tau , fnu ,
61- id = f"edges={ tuple (edges )} |fnu={ fnu } |cut={ cutoff } |tau={ tau :.2e} "
72+ arrow ,
73+ edges ,
74+ tensor ,
75+ cutoff ,
76+ tau ,
77+ fnu ,
78+ id = f"edges={ tuple (edges )} |fnu={ fnu } |cut={ cutoff } |tau={ tau :.2e} " ,
6279 )
6380 )
6481 return params
6582
83+
6684@pytest .mark .parametrize (
6785 "arrow, edges, tensor, cutoff, tau, free_names_u" ,
6886 svd_cases (),
6987)
7088@pytest .mark .repeat (20 )
7189def test_svd (
72- arrow : Arrow ,
73- edges : Edges ,
74- tensor : Tensor ,
75- cutoff : Cutoff ,
76- tau : Tau ,
77- free_names_u : FreeNamesU ,
90+ arrow : Arrow ,
91+ edges : Edges ,
92+ tensor : Tensor ,
93+ cutoff : Cutoff ,
94+ tau : Tau ,
95+ free_names_u : FreeNamesU ,
7896) -> None :
7997 gt = GrassmannTensor (arrow , edges , tensor )
8098 U , S , Vh = gt .svd (free_names_u , cutoff = cutoff )
@@ -107,20 +125,19 @@ def test_svd(
107125 rel_err = (masked - USV .tensor ).norm () / max (den , eps )
108126 assert rel_err <= tau
109127
128+
110129@pytest .mark .parametrize (
111130 "arrow, edges, tensor, cutoff, tau, free_names_u" ,
112131 svd_cases (),
113132)
114133def test_svd_with_zero_cutoff (
115- arrow : Arrow ,
116- edges : Edges ,
117- tensor : Tensor ,
118- cutoff : Cutoff ,
119- tau : Tau ,
120- free_names_u : FreeNamesU ,
134+ arrow : Arrow ,
135+ edges : Edges ,
136+ tensor : Tensor ,
137+ cutoff : Cutoff ,
138+ tau : Tau ,
139+ free_names_u : FreeNamesU ,
121140) -> None :
122141 gt = GrassmannTensor (arrow , edges , tensor )
123142 with pytest .raises (AssertionError , match = "Cutoff must be greater than 0" ):
124143 _ , _ , _ = gt .svd (free_names_u , cutoff = 0 )
125-
126-
0 commit comments