Skip to content

fix: fix lsun dataset error message. #9152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,50 @@ def test_not_found_or_corrupted(self):
with pytest.raises(datasets_utils.lazy_importer.lmdb.Error):
super().test_not_found_or_corrupted()

def test_class_name_verification(self):
err_msg = (
"Unknown value '{}' for LSUN class. "
"The valid value is one of 'train', 'val' or 'test' or a list of categories "
"e.g. ['bedroom_train', 'bedroom_val', 'bridge_train', 'bridge_val', "
"'church_outdoor_train', 'church_outdoor_val', 'classroom_train', 'classroom_val', "
"'conference_room_train', 'conference_room_val', 'dining_room_train', 'dining_room_val', "
"'kitchen_train', 'kitchen_val', 'living_room_train', 'living_room_val', "
"'restaurant_train', 'restaurant_val', 'tower_train', 'tower_val']."
)

cases = [
"bedroom",
"bedroom_train",
]
for case in cases:
with pytest.raises(
ValueError,
match=re.escape(err_msg.format(case)),
):
with self.create_dataset(classes=case):
pass

for case in [
["bedroom_train", "bedroom"],
["bedroom_train", "bedroommmmmmmm_val"],
]:
with pytest.raises(
ValueError,
match=re.escape(err_msg.format(case[-1])),
):
with self.create_dataset(classes=case):
pass

for case in [[None], [1]]:
with pytest.raises(
TypeError,
match=re.escape(
f"Expected type str for elements in argument classes, but got type {type(case[0]).__name__}."
),
):
with self.create_dataset(classes=case):
pass


class KineticsTestCase(datasets_utils.VideoDatasetTestCase):
DATASET_CLASS = datasets.Kinetics
Expand Down
44 changes: 29 additions & 15 deletions torchvision/datasets/lsun.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import string
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Callable, cast, Optional, Union
from typing import Any, Callable, Optional, Union

from PIL import Image

from .utils import iterable_to_str, verify_str_arg
from .utils import verify_str_arg
from .vision import VisionDataset


Expand Down Expand Up @@ -108,31 +108,45 @@ def _verify_classes(self, classes: Union[str, list[str]]) -> list[str]:
]
dset_opts = ["train", "val", "test"]

try:
classes = cast(str, classes)
verify_str_arg(classes, "classes", dset_opts)
err_msg = (
"Unknown value '{classes}' for LSUN class. "
"The valid value is one of 'train', 'val' or 'test' or a list of categories "
"e.g. ['bedroom_train', 'bedroom_val', 'bridge_train', 'bridge_val', "
"'church_outdoor_train', 'church_outdoor_val', 'classroom_train', 'classroom_val', "
"'conference_room_train', 'conference_room_val', 'dining_room_train', 'dining_room_val', "
"'kitchen_train', 'kitchen_val', 'living_room_train', 'living_room_val', "
"'restaurant_train', 'restaurant_val', 'tower_train', 'tower_val']."
)

if isinstance(classes, str):
if classes not in dset_opts:
raise ValueError(err_msg.format(classes=classes))
# If classes is a string, it should be one of the dataset options
# and not a specific category.
if classes == "test":
classes = [classes]
else:
classes = [c + "_" + classes for c in categories]
except ValueError:
if not isinstance(classes, Iterable):
msg = "Expected type str or Iterable for argument classes, but got type {}."
raise ValueError(msg.format(type(classes)))

elif isinstance(classes, Iterable):
classes = list(classes)
msg_fmtstr_type = "Expected type str for elements in argument classes, but got type {}."

for c in classes:
verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c)))
if not isinstance(c, str):
raise TypeError(msg_fmtstr_type.format(type(c).__name__))
msg = err_msg.format(classes=c)

c_short = c.split("_")
if len(c_short) < 2:
raise ValueError(msg)
category, dset_opt = "_".join(c_short[:-1]), c_short[-1]

msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
msg = msg_fmtstr.format(category, "LSUN class", iterable_to_str(categories))
verify_str_arg(category, valid_values=categories, custom_msg=msg)

msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
else:
raise TypeError(
f"Expected type str or Iterable for argument classes, but got type {type(classes).__name__}."
)

return classes

Expand Down