26
26
AVAILABLE_CLASSIFIERS .remove ('meta' )
27
27
28
28
#--------------------------------------------------------------------------------------------------
29
- @pytest .mark .parametrize ('classifier' , AVAILABLE_CLASSIFIERS ) # FIXME: + ['meta']
29
+ @pytest .mark .parametrize ('classifier' , AVAILABLE_CLASSIFIERS + ['meta' ])
30
30
def test_classifiers_train_test (monkeypatch , SHARED_INPUT_DIR , classifier ):
31
31
32
32
stcl = starclass .get_classifier (classifier )
33
33
34
- # Pick out a task to use for testing:
35
- with starclass .TaskManager (SHARED_INPUT_DIR ) as tm :
36
- task1 = tm .get_task (classifier = classifier , change_classifier = False , chunk = 1 )[0 ]
37
- print (task1 )
38
-
39
34
with tempfile .TemporaryDirectory (prefix = 'starclass-testing-' ) as tmpdir :
40
35
if classifier == 'meta' :
41
36
# For the MetaClassifier, we need to manipulate the training-set
42
37
# a little bit before we can train. We have to mimic that
43
38
# all the other classifiers have already been trained and cross-validated
44
39
# in order to fill up the training-set todo-file with probabilities
45
40
# which the MetaClassifier uses for training.
46
- tsetclass = starclass .get_trainingset ('keplerq9v3' )
47
- input_folder = tsetclass .find_input_folder ()
48
-
49
- # Create a copy of the root files of the trainings set (ignore that actual data)
50
- # in the temp. directory:
51
- tsetdir = os .path .join (tmpdir , os .path .basename (input_folder ))
52
- print ("New dummy input folder: %s" % tsetdir )
53
- os .makedirs (tsetdir )
54
- for f in os .listdir (input_folder ):
55
- fpath = os .path .join (input_folder , f )
56
- if os .path .isfile (fpath ) and not f .endswith (('.sqlite' , '.sqlite-journal' )):
57
- shutil .copy (fpath , tsetdir )
58
41
59
42
# Change the environment variable to the temp. dir:
60
43
monkeypatch .setenv ("STARCLASS_TSETS" , tmpdir )
61
44
62
45
# Copy the pre-prepared todo-file to the training-set directory:
63
- prepared_todo = os .path .join (SHARED_INPUT_DIR , 'meta' , 'todo.sqlite' )
46
+ tsetclass = starclass .get_trainingset ('keplerq9v3' )
47
+ prepared_todo = os .path .join (SHARED_INPUT_DIR , 'meta' , 'keplerq9v3-tset.sqlite' )
64
48
new_todo = os .path .join (tsetclass .find_input_folder (), tsetclass ._todo_name + '.sqlite' )
49
+ os .makedirs (os .path .dirname (new_todo ), exist_ok = True )
65
50
shutil .copyfile (prepared_todo , new_todo )
66
51
67
52
# Initialize the training-set in the temp folder,
@@ -72,6 +57,14 @@ def test_classifiers_train_test(monkeypatch, SHARED_INPUT_DIR, classifier):
72
57
tsetclass = starclass .get_trainingset ('testing' )
73
58
tset = tsetclass (tf = 0.2 , random_seed = 42 )
74
59
60
+ print (tset )
61
+ print (tset .fake_metaclassifier )
62
+
63
+ # Pick out a task to use for testing:
64
+ with starclass .TaskManager (tset .todo_file , load_into_memory = False , classes = tset .StellarClasses ) as tm :
65
+ task1 = tm .get_task (classifier = classifier , change_classifier = False , chunk = 1 )[0 ]
66
+ print (task1 )
67
+
75
68
# Initialize the classifier and run training and testing:
76
69
with stcl (tset = tset , features_cache = None , data_dir = tmpdir ) as cl :
77
70
print (cl .data_dir )
@@ -132,27 +125,47 @@ def test_classifiers_train_test(monkeypatch, SHARED_INPUT_DIR, classifier):
132
125
assert results1 [key ] == results2 [key ], "Non-identical results before and after saving/loading model"
133
126
134
127
#--------------------------------------------------------------------------------------------------
135
- @pytest .mark .parametrize ('classifier' , AVAILABLE_CLASSIFIERS )
136
- def test_run_training (PRIVATE_INPUT_DIR , classifier ):
128
+ @pytest .mark .parametrize ('classifier' , AVAILABLE_CLASSIFIERS ) # FIXME: + ['meta']
129
+ def test_run_training_and_starclass (monkeypatch , PRIVATE_INPUT_DIR , classifier ):
130
+ with tempfile .TemporaryDirectory (prefix = 'starclass-testing-' ) as tmpdir :
131
+ if classifier == 'meta' :
132
+ # For the MetaClassifier, we need to manipulate the training-set
133
+ # a little bit before we can train. We have to mimic that
134
+ # all the other classifiers have already been trained and cross-validated
135
+ # in order to fill up the training-set todo-file with probabilities
136
+ # which the MetaClassifier uses for training.
137
+
138
+ # Change the environment variable to the temp. dir:
139
+ monkeypatch .setenv ("STARCLASS_TSETS" , tmpdir )
137
140
138
- tsetclass = starclass .get_trainingset ('testing' )
139
- tset = tsetclass (tf = 0.2 , random_seed = 42 )
141
+ # Copy the pre-prepared todo-file to the training-set directory:
142
+ tsetclass = starclass .get_trainingset ('keplerq9v3' )
143
+ prepared_todo = os .path .join (PRIVATE_INPUT_DIR , 'meta' , 'keplerq9v3-tset.sqlite' )
144
+ new_todo = os .path .join (tsetclass .find_input_folder (), tsetclass ._todo_name + '.sqlite' )
145
+ os .makedirs (os .path .dirname (new_todo ), exist_ok = True )
146
+ shutil .copyfile (prepared_todo , new_todo )
147
+
148
+ tset = tsetclass (tf = 0.2 , random_seed = 42 )
149
+ tset .fake_metaclassifier = True
150
+ else :
151
+ tsetclass = starclass .get_trainingset ('testing' )
152
+ tset = tsetclass (tf = 0.2 , random_seed = 42 )
140
153
141
- with tempfile .TemporaryDirectory (prefix = 'starclass-testing-' ) as tmpdir :
142
154
logfile = os .path .join (tmpdir , 'training.log' )
143
155
todo_file = os .path .join (PRIVATE_INPUT_DIR , 'todo_run.sqlite' )
144
156
145
157
# Train the classifier:
146
158
out , err , exitcode = capture_run_cli ('run_training.py' , [
147
159
'--classifier=' + classifier ,
148
- '--trainingset=testing' ,
160
+ '--trainingset=' + tset . key ,
149
161
'--level=L1' ,
150
162
'--testfraction=0.2' ,
151
163
'--log=' + logfile ,
152
164
'--log-level=info' ,
153
165
'--output=' + tmpdir
154
166
])
155
167
assert exitcode == 0
168
+ assert ' - INFO - Done.' in out
156
169
157
170
# Check that a log-file was indeed generated:
158
171
assert os .path .isfile (logfile ), "Log-file not generated"
@@ -174,7 +187,7 @@ def test_run_training(PRIVATE_INPUT_DIR, classifier):
174
187
'--debug' ,
175
188
'--overwrite' ,
176
189
'--classifier=' + classifier ,
177
- '--trainingset=testing' ,
190
+ '--trainingset=' + tset . key ,
178
191
'--level=L1' ,
179
192
'--datadir=' + tmpdir ,
180
193
todo_file
@@ -188,10 +201,10 @@ def test_run_training(PRIVATE_INPUT_DIR, classifier):
188
201
189
202
cursor .execute ("SELECT * FROM starclass_settings;" )
190
203
row = cursor .fetchall ()
191
- assert len (row ) == 1 , "Only one settings row should exist"
204
+ assert len (row ) == 1 , "Exactly one settings row should exist"
192
205
settings = row [0 ]
193
206
print (dict (settings ))
194
- assert settings ['tset' ] == 'testtset'
207
+ assert settings ['tset' ] == tset . key
195
208
196
209
cursor .execute ("SELECT * FROM starclass_diagnostics WHERE priority=17;" )
197
210
row = cursor .fetchone ()
@@ -208,12 +221,12 @@ def test_run_training(PRIVATE_INPUT_DIR, classifier):
208
221
print (dict (row ))
209
222
assert row ['priority' ] == 17
210
223
assert row ['classifier' ] == classifier
211
- tset .StellarClasses [row ['class' ]] # Will result in KeyError of not correct
224
+ tset .StellarClasses [row ['class' ]] # Will result in KeyError if not correct
212
225
assert 0 <= row ['prob' ] <= 1 , "Invalid probability"
213
226
214
227
cursor .execute ("SELECT * FROM starclass_features_common;" )
215
228
results = cursor .fetchall ()
216
- assert len (results ) == 1
229
+ assert len (results ) == 1 , "Exactly one features_common row should exist"
217
230
row = dict (results [0 ])
218
231
print (row )
219
232
assert row ['priority' ] == 17
@@ -222,7 +235,7 @@ def test_run_training(PRIVATE_INPUT_DIR, classifier):
222
235
if classifier != 'slosh' :
223
236
cursor .execute (f"SELECT * FROM starclass_features_{ classifier :s} ;" )
224
237
results = cursor .fetchall ()
225
- assert len (results ) == 1
238
+ assert len (results ) == 1 , f"Exactly one features_ { classifier :s } row should exist"
226
239
row = dict (results [0 ])
227
240
print (row )
228
241
assert row ['priority' ] == 17
0 commit comments