diff --git a/test/test_datasets.py b/test/test_datasets.py index 22c14cbc08d..0b6b08924d9 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1056,6 +1056,50 @@ def test_not_found_or_corrupted(self): with pytest.raises(datasets_utils.lazy_importer.lmdb.Error): super().test_not_found_or_corrupted() + def test_class_name_verification(self): + err_msg = ( + "Unknown value '{}' for LSUN class. " + "The valid value is one of 'train', 'val' or 'test' or a list of categories " + "e.g. ['bedroom_train', 'bedroom_val', 'bridge_train', 'bridge_val', " + "'church_outdoor_train', 'church_outdoor_val', 'classroom_train', 'classroom_val', " + "'conference_room_train', 'conference_room_val', 'dining_room_train', 'dining_room_val', " + "'kitchen_train', 'kitchen_val', 'living_room_train', 'living_room_val', " + "'restaurant_train', 'restaurant_val', 'tower_train', 'tower_val']." + ) + + cases = [ + "bedroom", + "bedroom_train", + ] + for case in cases: + with pytest.raises( + ValueError, + match=re.escape(err_msg.format(case)), + ): + with self.create_dataset(classes=case): + pass + + for case in [ + ["bedroom_train", "bedroom"], + ["bedroom_train", "bedroommmmmmmm_val"], + ]: + with pytest.raises( + ValueError, + match=re.escape(err_msg.format(case[-1])), + ): + with self.create_dataset(classes=case): + pass + + for case in [[None], [1]]: + with pytest.raises( + TypeError, + match=re.escape( + f"Expected type str for elements in argument classes, but got type {type(case[0]).__name__}." + ), + ): + with self.create_dataset(classes=case): + pass + class KineticsTestCase(datasets_utils.VideoDatasetTestCase): DATASET_CLASS = datasets.Kinetics diff --git a/torchvision/datasets/lsun.py b/torchvision/datasets/lsun.py index 6f6c7a5eb63..89438514804 100644 --- a/torchvision/datasets/lsun.py +++ b/torchvision/datasets/lsun.py @@ -4,11 +4,11 @@ import string from collections.abc import Iterable from pathlib import Path -from typing import Any, Callable, cast, Optional, Union +from typing import Any, Callable, Optional, Union from PIL import Image -from .utils import iterable_to_str, verify_str_arg +from .utils import verify_str_arg from .vision import VisionDataset @@ -108,31 +108,45 @@ def _verify_classes(self, classes: Union[str, list[str]]) -> list[str]: ] dset_opts = ["train", "val", "test"] - try: - classes = cast(str, classes) - verify_str_arg(classes, "classes", dset_opts) + err_msg = ( + "Unknown value '{classes}' for LSUN class. " + "The valid value is one of 'train', 'val' or 'test' or a list of categories " + "e.g. ['bedroom_train', 'bedroom_val', 'bridge_train', 'bridge_val', " + "'church_outdoor_train', 'church_outdoor_val', 'classroom_train', 'classroom_val', " + "'conference_room_train', 'conference_room_val', 'dining_room_train', 'dining_room_val', " + "'kitchen_train', 'kitchen_val', 'living_room_train', 'living_room_val', " + "'restaurant_train', 'restaurant_val', 'tower_train', 'tower_val']." + ) + + if isinstance(classes, str): + if classes not in dset_opts: + raise ValueError(err_msg.format(classes=classes)) + # If classes is a string, it should be one of the dataset options + # and not a specific category. if classes == "test": classes = [classes] else: classes = [c + "_" + classes for c in categories] - except ValueError: - if not isinstance(classes, Iterable): - msg = "Expected type str or Iterable for argument classes, but got type {}." - raise ValueError(msg.format(type(classes))) - + elif isinstance(classes, Iterable): classes = list(classes) msg_fmtstr_type = "Expected type str for elements in argument classes, but got type {}." + for c in classes: - verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c))) + if not isinstance(c, str): + raise TypeError(msg_fmtstr_type.format(type(c).__name__)) + msg = err_msg.format(classes=c) + c_short = c.split("_") + if len(c_short) < 2: + raise ValueError(msg) category, dset_opt = "_".join(c_short[:-1]), c_short[-1] - msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}." - msg = msg_fmtstr.format(category, "LSUN class", iterable_to_str(categories)) verify_str_arg(category, valid_values=categories, custom_msg=msg) - - msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts)) verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg) + else: + raise TypeError( + f"Expected type str or Iterable for argument classes, but got type {type(classes).__name__}." + ) return classes