diff --git a/psb2/psb2.py b/psb2/psb2.py index 937c1be..4b6c849 100644 --- a/psb2/psb2.py +++ b/psb2/psb2.py @@ -101,9 +101,13 @@ def fetch_examples(datasets_directory, problem_name, n_train, n_test, format='ps assert n_train < 1000000, "Cannot sample more than 1 million examples" assert n_test < 1000000, "Cannot sample more than 1 million examples" + # Check whether problem_name is a valid PSB2 problem + if problem_name.strip() not in get_problem_names(): + raise AttributeError('The provided problem_name (' + problem_name.strip() +') is not available in the list of PSB2 problems, which are the followings: ' + ', '.join(get_problem_names()) + '.') + # Load data - edge_data = fetch_and_possibly_cache_data(datasets_directory, problem_name, "edge") - random_data = fetch_and_possibly_cache_data(datasets_directory, problem_name, "random") + edge_data = fetch_and_possibly_cache_data(datasets_directory, problem_name.strip(), "edge") + random_data = fetch_and_possibly_cache_data(datasets_directory, problem_name.strip(), "random") # Seed RNG source random.seed(seed) @@ -124,4 +128,4 @@ def fetch_examples(datasets_directory, problem_name, n_train, n_test, format='ps def get_problem_names(): """Returns a list of strings of the problem names in PSB2.""" - return PROBLEMS + return [pro for pro in PROBLEMS]