Skip to content

Commit

Permalink
added unit test for model/benchmark retrieval method
Browse files Browse the repository at this point in the history
  • Loading branch information
shehadak committed Nov 2, 2023
1 parent 2307893 commit 03d8f1e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
8 changes: 3 additions & 5 deletions brainscore_language/submission/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,7 @@ def _get_ids(args_dict: Dict[str, Union[str, List]], key: str) -> Union[List, st


def get_models_and_benchmarks(args_dict: Dict[str, Union[str, List]]) -> Tuple[List[str], List[str]]:
"""
Identifies the entire set of models and benchmarks to be scored and saves them as environment
variables `BS_NEW_MODELS` and `BS_NEW_BENCHMARKS`.
"""
""" Identifies the entire set of models and benchmarks to be scored and prints them to stdout. """

new_models = _get_ids(args_dict, 'new_models')
new_benchmarks = _get_ids(args_dict, 'new_benchmarks')
Expand Down Expand Up @@ -101,7 +98,7 @@ def get_models_and_benchmarks(args_dict: Dict[str, Union[str, List]]) -> Tuple[L


def run_scoring(args_dict: Dict[str, Union[str, List]]):
""" prepares parameters for the `run_scoring_endpoint`. """
""" prepares parameters for and calls the `run_scoring_endpoint`. """
new_models = _get_ids(args_dict, 'new_models')
new_benchmarks = _get_ids(args_dict, 'new_benchmarks')

Expand Down Expand Up @@ -132,6 +129,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument('--new_benchmarks', type=str, nargs='*', default=None,
help='The identifiers of newly submitted benchmarks on which to score all models')
parser.add_argument('--fn', type=str, nargs='?', default='run_scoring',
choices=['run_scoring', 'get_models_and_benchmarks'],
help='The endpoint method to run. `run_scoring` to score `new_models` on `new_benchmarks`, or `get_models_and_benchmarks` to respond with a list of models and benchmarks to score.')
args, remaining_args = parser.parse_known_args()

Expand Down
14 changes: 12 additions & 2 deletions tests/test_submission/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
""" the mock import has to be before importing endpoints so that the database is properly mocked """
from .mock_config import test_database

from brainscore_core.submission import database_models
from brainscore_core.submission import database_models, RunScoringEndpoint
from brainscore_core.submission.database import connect_db
from brainscore_core.submission.database_models import clear_schema
from brainscore_language.submission.endpoints import run_scoring
from brainscore_language.submission.endpoints import run_scoring, get_models_and_benchmarks


logger = logging.getLogger(__name__)
Expand All @@ -32,6 +32,16 @@ def teardown_method(self):
logger.info('Clean database')
clear_schema()

def test_get_models_benchmarks(self):
new_models = ['randomembedding-100']
new_benchmarks = ['Pereira2018.243sentences-linear']
args_dict = {'jenkins_id': 62, 'user_id': 1, 'model_type': 'artificialsubject',
'public': True, 'competition': 'None', 'new_models': new_models,
'new_benchmarks': new_benchmarks, 'specified_only': True}
model_ids, benchmark_ids = get_models_and_benchmarks(args_dict)
assert model_ids == new_models
assert benchmark_ids == new_benchmarks

def test_successful_run(self):
args_dict = {'jenkins_id': 62, 'user_id': 1, 'model_type': 'artificialsubject',
'public': True, 'competition': 'None', 'new_models': ['randomembedding-100'],
Expand Down

0 comments on commit 03d8f1e

Please sign in to comment.