Skip to content

Commit 80e5a03

Browse files
committed
Bugfix and testing
1 parent 37d68f5 commit 80e5a03

File tree

7 files changed

+59
-41
lines changed

7 files changed

+59
-41
lines changed

run_training.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import argparse
1212
import os
13+
import sys
1314
import logging
1415
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
1516
import tensorflow as tf
@@ -60,7 +61,7 @@ def main():
6061

6162
# Setup logging:
6263
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
63-
console = logging.StreamHandler()
64+
console = logging.StreamHandler(sys.stdout)
6465
console.setFormatter(formatter)
6566
console.setLevel(logging_level)
6667
logger = logging.getLogger('starclass')

starclass/taskmanager.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def get_number_tasks(self, classifier=None):
316316
return num
317317

318318
#----------------------------------------------------------------------------------------------
319-
def _query_task(self, classifier=None, priority=None, chunk=1):
319+
def _query_task(self, classifier=None, priority=None, chunk=1, ignore_existing=False):
320320

321321
search_joins = []
322322
search_query = []
@@ -330,7 +330,7 @@ def _query_task(self, classifier=None, priority=None, chunk=1):
330330
search_query.append(f'temp.starclass_todolist.priority={priority:d}')
331331

332332
# If a classifier is specified, constrain to only that classifier:
333-
if classifier is not None:
333+
if classifier is not None and not ignore_existing:
334334
search_joins.append(f"LEFT JOIN starclass_diagnostics ON starclass_diagnostics.priority=temp.starclass_todolist.priority AND starclass_diagnostics.classifier='{classifier:s}'")
335335
search_query.append("starclass_diagnostics.status IS NULL")
336336

@@ -411,7 +411,7 @@ def _query_task(self, classifier=None, priority=None, chunk=1):
411411
return None
412412

413413
#----------------------------------------------------------------------------------------------
414-
def get_task(self, priority=None, classifier=None, change_classifier=True, chunk=1):
414+
def get_task(self, priority=None, classifier=None, change_classifier=True, chunk=1, ignore_existing=False):
415415
"""
416416
Get next task to be processed.
417417
@@ -432,7 +432,7 @@ def get_task(self, priority=None, classifier=None, change_classifier=True, chunk
432432
.. codeauthor:: Rasmus Handberg <[email protected]>
433433
"""
434434

435-
task = self._query_task(classifier=classifier, priority=priority, chunk=chunk)
435+
task = self._query_task(classifier=classifier, priority=priority, chunk=chunk, ignore_existing=ignore_existing)
436436

437437
# If no task is returned for the given classifier, find another
438438
# classifier where tasks are available:
@@ -441,7 +441,7 @@ def get_task(self, priority=None, classifier=None, change_classifier=True, chunk
441441
# task for all of them:
442442
all_tasks = []
443443
for cl in self.all_classifiers.difference([classifier]):
444-
task = self._query_task(classifier=cl, priority=priority, chunk=chunk)
444+
task = self._query_task(classifier=cl, priority=priority, chunk=chunk, ignore_existing=ignore_existing)
445445
if task is not None:
446446
all_tasks.append(task)
447447

@@ -454,7 +454,7 @@ def get_task(self, priority=None, classifier=None, change_classifier=True, chunk
454454

455455
# If this is reached, all classifiers are done, and we can
456456
# start running the MetaClassifier:
457-
task = self._query_task(classifier='meta', priority=priority, chunk=chunk)
457+
task = self._query_task(classifier='meta', priority=priority, chunk=chunk, ignore_existing=ignore_existing)
458458

459459
return task
460460

starclass/training_sets/testing_tset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class testing_tset(TrainingSet):
2121
.. codeauthor:: Rasmus Handberg <[email protected]>
2222
"""
2323
# Class constants:
24-
key = 'testtset'
24+
key = 'testing'
2525
datadir = 'keplerq9v3'
2626
_todo_name = 'todo-testing'
2727

starclass/training_sets/training_set.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def features(self):
459459
with BaseClassifier(tset=self, features_cache=self.features_cache) as stcl:
460460
for rowidx in self.train_idx:
461461
task = self.tm.get_task(priority=rowidx+1, classifier=cl,
462-
change_classifier=False, chunk=1)
462+
change_classifier=False, chunk=1, ignore_existing=True)
463463

464464
# Lightcurve file to load:
465465
# We do not use the one from the database because in the simulations the
@@ -487,7 +487,7 @@ def features_test(self):
487487
# when opened several times in parallel.
488488
for rowidx in self.test_idx:
489489
task = self.tm.get_task(priority=rowidx+1, classifier=cl,
490-
change_classifier=False, chunk=1)
490+
change_classifier=False, chunk=1, ignore_existing=True)
491491

492492
# Lightcurve file to load:
493493
# We do not use the one from the database because in the simulations the

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def capture_run_cli(cli, params=[], mpiexec=False):
4141
if mpiexec:
4242
cmd = ['mpiexec', '-n', '2'] + cmd
4343

44+
print("Running command: " + ' '.join(cmd))
4445
proc = subprocess.Popen(cmd,
4546
cwd=os.path.join(os.path.dirname(__file__), '..'),
4647
stdout=subprocess.PIPE,
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:c483fbd9dae17610a6c08931b3590d43f45d44eb27ad883bcabbb80c39b92ef0
3+
size 23379968

tests/test_classifiers.py

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -26,42 +26,27 @@
2626
AVAILABLE_CLASSIFIERS.remove('meta')
2727

2828
#--------------------------------------------------------------------------------------------------
29-
@pytest.mark.parametrize('classifier', AVAILABLE_CLASSIFIERS) # FIXME: + ['meta']
29+
@pytest.mark.parametrize('classifier', AVAILABLE_CLASSIFIERS + ['meta'])
3030
def test_classifiers_train_test(monkeypatch, SHARED_INPUT_DIR, classifier):
3131

3232
stcl = starclass.get_classifier(classifier)
3333

34-
# Pick out a task to use for testing:
35-
with starclass.TaskManager(SHARED_INPUT_DIR) as tm:
36-
task1 = tm.get_task(classifier=classifier, change_classifier=False, chunk=1)[0]
37-
print(task1)
38-
3934
with tempfile.TemporaryDirectory(prefix='starclass-testing-') as tmpdir:
4035
if classifier == 'meta':
4136
# For the MetaClassifier, we need to manipulate the training-set
4237
# a little bit before we can train. We have to mimic that
4338
# all the other classifiers have already been trained and cross-validated
4439
# in order to fill up the training-set todo-file with probabilities
4540
# which the MetaClassifier uses for training.
46-
tsetclass = starclass.get_trainingset('keplerq9v3')
47-
input_folder = tsetclass.find_input_folder()
48-
49-
# Create a copy of the root files of the trainings set (ignore that actual data)
50-
# in the temp. directory:
51-
tsetdir = os.path.join(tmpdir, os.path.basename(input_folder))
52-
print("New dummy input folder: %s" % tsetdir)
53-
os.makedirs(tsetdir)
54-
for f in os.listdir(input_folder):
55-
fpath = os.path.join(input_folder, f)
56-
if os.path.isfile(fpath) and not f.endswith(('.sqlite', '.sqlite-journal')):
57-
shutil.copy(fpath, tsetdir)
5841

5942
# Change the environment variable to the temp. dir:
6043
monkeypatch.setenv("STARCLASS_TSETS", tmpdir)
6144

6245
# Copy the pre-prepared todo-file to the training-set directory:
63-
prepared_todo = os.path.join(SHARED_INPUT_DIR, 'meta', 'todo.sqlite')
46+
tsetclass = starclass.get_trainingset('keplerq9v3')
47+
prepared_todo = os.path.join(SHARED_INPUT_DIR, 'meta', 'keplerq9v3-tset.sqlite')
6448
new_todo = os.path.join(tsetclass.find_input_folder(), tsetclass._todo_name + '.sqlite')
49+
os.makedirs(os.path.dirname(new_todo), exist_ok=True)
6550
shutil.copyfile(prepared_todo, new_todo)
6651

6752
# Initialize the training-set in the temp folder,
@@ -72,6 +57,14 @@ def test_classifiers_train_test(monkeypatch, SHARED_INPUT_DIR, classifier):
7257
tsetclass = starclass.get_trainingset('testing')
7358
tset = tsetclass(tf=0.2, random_seed=42)
7459

60+
print(tset)
61+
print(tset.fake_metaclassifier)
62+
63+
# Pick out a task to use for testing:
64+
with starclass.TaskManager(tset.todo_file, load_into_memory=False, classes=tset.StellarClasses) as tm:
65+
task1 = tm.get_task(classifier=classifier, change_classifier=False, chunk=1)[0]
66+
print(task1)
67+
7568
# Initialize the classifier and run training and testing:
7669
with stcl(tset=tset, features_cache=None, data_dir=tmpdir) as cl:
7770
print(cl.data_dir)
@@ -132,27 +125,47 @@ def test_classifiers_train_test(monkeypatch, SHARED_INPUT_DIR, classifier):
132125
assert results1[key] == results2[key], "Non-identical results before and after saving/loading model"
133126

134127
#--------------------------------------------------------------------------------------------------
135-
@pytest.mark.parametrize('classifier', AVAILABLE_CLASSIFIERS)
136-
def test_run_training(PRIVATE_INPUT_DIR, classifier):
128+
@pytest.mark.parametrize('classifier', AVAILABLE_CLASSIFIERS) # FIXME: + ['meta']
129+
def test_run_training_and_starclass(monkeypatch, PRIVATE_INPUT_DIR, classifier):
130+
with tempfile.TemporaryDirectory(prefix='starclass-testing-') as tmpdir:
131+
if classifier == 'meta':
132+
# For the MetaClassifier, we need to manipulate the training-set
133+
# a little bit before we can train. We have to mimic that
134+
# all the other classifiers have already been trained and cross-validated
135+
# in order to fill up the training-set todo-file with probabilities
136+
# which the MetaClassifier uses for training.
137+
138+
# Change the environment variable to the temp. dir:
139+
monkeypatch.setenv("STARCLASS_TSETS", tmpdir)
137140

138-
tsetclass = starclass.get_trainingset('testing')
139-
tset = tsetclass(tf=0.2, random_seed=42)
141+
# Copy the pre-prepared todo-file to the training-set directory:
142+
tsetclass = starclass.get_trainingset('keplerq9v3')
143+
prepared_todo = os.path.join(PRIVATE_INPUT_DIR, 'meta', 'keplerq9v3-tset.sqlite')
144+
new_todo = os.path.join(tsetclass.find_input_folder(), tsetclass._todo_name + '.sqlite')
145+
os.makedirs(os.path.dirname(new_todo), exist_ok=True)
146+
shutil.copyfile(prepared_todo, new_todo)
147+
148+
tset = tsetclass(tf=0.2, random_seed=42)
149+
tset.fake_metaclassifier = True
150+
else:
151+
tsetclass = starclass.get_trainingset('testing')
152+
tset = tsetclass(tf=0.2, random_seed=42)
140153

141-
with tempfile.TemporaryDirectory(prefix='starclass-testing-') as tmpdir:
142154
logfile = os.path.join(tmpdir, 'training.log')
143155
todo_file = os.path.join(PRIVATE_INPUT_DIR, 'todo_run.sqlite')
144156

145157
# Train the classifier:
146158
out, err, exitcode = capture_run_cli('run_training.py', [
147159
'--classifier=' + classifier,
148-
'--trainingset=testing',
160+
'--trainingset=' + tset.key,
149161
'--level=L1',
150162
'--testfraction=0.2',
151163
'--log=' + logfile,
152164
'--log-level=info',
153165
'--output=' + tmpdir
154166
])
155167
assert exitcode == 0
168+
assert ' - INFO - Done.' in out
156169

157170
# Check that a log-file was indeed generated:
158171
assert os.path.isfile(logfile), "Log-file not generated"
@@ -174,7 +187,7 @@ def test_run_training(PRIVATE_INPUT_DIR, classifier):
174187
'--debug',
175188
'--overwrite',
176189
'--classifier=' + classifier,
177-
'--trainingset=testing',
190+
'--trainingset=' + tset.key,
178191
'--level=L1',
179192
'--datadir=' + tmpdir,
180193
todo_file
@@ -188,10 +201,10 @@ def test_run_training(PRIVATE_INPUT_DIR, classifier):
188201

189202
cursor.execute("SELECT * FROM starclass_settings;")
190203
row = cursor.fetchall()
191-
assert len(row) == 1, "Only one settings row should exist"
204+
assert len(row) == 1, "Exactly one settings row should exist"
192205
settings = row[0]
193206
print(dict(settings))
194-
assert settings['tset'] == 'testtset'
207+
assert settings['tset'] == tset.key
195208

196209
cursor.execute("SELECT * FROM starclass_diagnostics WHERE priority=17;")
197210
row = cursor.fetchone()
@@ -208,12 +221,12 @@ def test_run_training(PRIVATE_INPUT_DIR, classifier):
208221
print(dict(row))
209222
assert row['priority'] == 17
210223
assert row['classifier'] == classifier
211-
tset.StellarClasses[row['class']] # Will result in KeyError of not correct
224+
tset.StellarClasses[row['class']] # Will result in KeyError if not correct
212225
assert 0 <= row['prob'] <= 1, "Invalid probability"
213226

214227
cursor.execute("SELECT * FROM starclass_features_common;")
215228
results = cursor.fetchall()
216-
assert len(results) == 1
229+
assert len(results) == 1, "Exactly one features_common row should exist"
217230
row = dict(results[0])
218231
print(row)
219232
assert row['priority'] == 17
@@ -222,7 +235,7 @@ def test_run_training(PRIVATE_INPUT_DIR, classifier):
222235
if classifier != 'slosh':
223236
cursor.execute(f"SELECT * FROM starclass_features_{classifier:s};")
224237
results = cursor.fetchall()
225-
assert len(results) == 1
238+
assert len(results) == 1, f"Exactly one features_{classifier:s} row should exist"
226239
row = dict(results[0])
227240
print(row)
228241
assert row['priority'] == 17

0 commit comments

Comments
 (0)