Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions psb2/psb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,13 @@ def fetch_examples(datasets_directory, problem_name, n_train, n_test, format='ps
assert n_train < 1000000, "Cannot sample more than 1 million examples"
assert n_test < 1000000, "Cannot sample more than 1 million examples"

# Check whether problem_name is a valid PSB2 problem
if problem_name.strip() not in get_problem_names():
raise AttributeError('The provided problem_name (' + problem_name.strip() +') is not available in the list of PSB2 problems, which are the followings: ' + ', '.join(get_problem_names()) + '.')

# Load data
edge_data = fetch_and_possibly_cache_data(datasets_directory, problem_name, "edge")
random_data = fetch_and_possibly_cache_data(datasets_directory, problem_name, "random")
edge_data = fetch_and_possibly_cache_data(datasets_directory, problem_name.strip(), "edge")
random_data = fetch_and_possibly_cache_data(datasets_directory, problem_name.strip(), "random")

# Seed RNG source
random.seed(seed)
Expand All @@ -124,4 +128,4 @@ def fetch_examples(datasets_directory, problem_name, n_train, n_test, format='ps

def get_problem_names():
"""Returns a list of strings of the problem names in PSB2."""
return PROBLEMS
return [pro for pro in PROBLEMS]