diff --git a/cumulus_library/upload.py b/cumulus_library/upload.py index 01f70785..a1f556a8 100644 --- a/cumulus_library/upload.py +++ b/cumulus_library/upload.py @@ -72,6 +72,13 @@ def upload_files(args: dict): "study export folder." ) file_paths = list(args["data_path"].glob("**/*.parquet")) + if args["target"]: + filtered_paths = [] + for path in file_paths: + if any(study in str(path) for study in args["target"]): + filtered_paths.append(path) + file_paths = filtered_paths + if not args["user"] or not args["id"]: sys.exit("user/id not provided, please pass --user and --id") try: diff --git a/tests/test_cli.py b/tests/test_cli.py index e911ba0f..4e5915a8 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -11,6 +11,7 @@ import requests_mock from cumulus_library import cli +from cumulus_library import upload @mock.patch("pyathena.connect") @@ -204,7 +205,7 @@ def test_cli_creates_studies( def test_cli_upload_studies(mock_glob, args, status, login_error, raises): mock_glob.side_effect = [ [Path(__file__)], - [Path(str(Path(__file__)) + "/test_data/count_synthea_patient.parquet")], + [Path(str(Path(__file__).parent) + "/test_data/count_synthea_patient.parquet")], ] with raises: with requests_mock.Mocker() as r: @@ -217,3 +218,36 @@ def test_cli_upload_studies(mock_glob, args, status, login_error, raises): ) r.post("https://presigned.url.org", status_code=status) cli.main(cli_args=args) + + +@pytest.mark.parametrize( + "args,calls", + [ + (["upload", "--user", "user", "--id", "id", "./foo"], 2), + (["upload", "--user", "user", "--id", "id", "./foo", "-t", "test_data"], 1), + (["upload", "--user", "user", "--id", "id", "./foo", "-t", "not_found"], 0), + ], +) +@mock.patch.dict( + os.environ, + clear=True, +) +@mock.patch("pathlib.Path.glob") +@mock.patch("cumulus_library.upload.upload_data") +def test_cli_upload_filter(mock_upload_data, mock_glob, args, calls): + mock_glob.side_effect = [ + [ + Path( + str(Path(__file__).parent) + "/test_data/count_synthea_patient.parquet" + ), + Path( + str(Path(__file__).parent) + "/other_data/count_synthea_patient.parquet" + ), + ], + ] + cli.main(cli_args=args) + if len(mock_upload_data.call_args_list) == 1: + target = args[args.index("-t") + 1] + # filepath is in the third argument position in the upload data arg list + assert target in str(mock_upload_data.call_args[0][2]) + assert mock_upload_data.call_count == calls