14
14
# limitations under the License.
15
15
16
16
17
- import os
18
17
import pickle
19
18
import random
20
19
29
28
30
29
31
30
@pytest .fixture (scope = "module" )
32
- def gen_test_data (tmp_path_factory ):
31
+ def gen_pickle_files (tmp_path_factory ):
33
32
dir_pickles = tmp_path_factory .mktemp ("pickleddatawds" ).as_posix ()
34
- dir_tars_tmp = tmp_path_factory .mktemp ("webdatamodule" ).as_posix ()
35
- dir_tars = {split : f"{ dir_tars_tmp } { str (split ).split ('.' )[- 1 ]} " for split in Split }
36
33
prefix_sample = "sample"
37
34
suffix_sample = "tensor.pyd"
38
- prefix_tar = "tensor"
39
35
n_samples_per_split = 10
40
- n_samples = {split : n_samples_per_split for split in Split }
41
- os .makedirs (dir_pickles , exist_ok = True )
42
36
prefixes = []
43
37
# generate the pickles for train, val, and test
44
38
for i in range (n_samples_per_split * 3 ):
@@ -52,6 +46,24 @@ def gen_test_data(tmp_path_factory):
52
46
Split .val : prefixes [n_samples_per_split : n_samples_per_split * 2 ],
53
47
Split .test : prefixes [n_samples_per_split * 2 : n_samples_per_split * 3 ],
54
48
}
49
+ return (
50
+ dir_pickles ,
51
+ prefix_sample ,
52
+ suffix_sample ,
53
+ prefixes_pickle ,
54
+ n_samples_per_split ,
55
+ )
56
+
57
+
58
+ @pytest .fixture (scope = "module" )
59
+ def gen_test_data (tmp_path_factory , gen_pickle_files ):
60
+ dir_pickles , prefix_sample , suffix_sample , prefixes_pickle , n_samples_per_split = (
61
+ gen_pickle_files
62
+ )
63
+ dir_tars_tmp = tmp_path_factory .mktemp ("webdatamodule" ).as_posix ()
64
+ dir_tars = {split : f"{ dir_tars_tmp } { str (split ).split ('.' )[- 1 ]} " for split in Split }
65
+ prefix_tar = "tensor"
66
+ n_samples = {split : n_samples_per_split for split in Split }
55
67
# generate the tars
56
68
pickles_to_tars (
57
69
dir_pickles ,
0 commit comments