Skip to content

Commit

Permalink
update test for uploading dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
shunk031 committed Jun 14, 2024
1 parent 672f30e commit 0ccf579
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 62 deletions.
8 changes: 4 additions & 4 deletions MSCOCO.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ class MsCocoConfig(ds.BuilderConfig):
TASKS: Tuple[str, ...] = (
"captions",
"instances",
"person_keypoints",
"person-keypoints",
)

def __init__(
Expand Down Expand Up @@ -894,7 +894,7 @@ def get_processor(self) -> MsCocoProcessor:
return CaptionsProcessor()
elif self.task == "instances":
return InstancesProcessor()
elif self.task == "person_keypoints":
elif self.task == "person-keypoints":
return PersonKeypointsProcessor()
else:
raise ValueError(f"Invalid task: {self.task}")
Expand Down Expand Up @@ -924,7 +924,7 @@ def dataset_configs(year: int, version: ds.Version) -> List[MsCocoConfig]:
),
MsCocoConfig(
year=year,
coco_task="person_keypoints",
coco_task="person-keypoints",
version=version,
),
# MsCocoConfig(
Expand All @@ -934,7 +934,7 @@ def dataset_configs(year: int, version: ds.Version) -> List[MsCocoConfig]:
# ),
# MsCocoConfig(
# year=year,
# coco_task=("captions", "person_keypoints"),
# coco_task=("captions", "person-keypoints"),
# version=version,
# ),
]
Expand Down
47 changes: 47 additions & 0 deletions tests/MSCOCO_caption_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import os

import datasets as ds
import pytest


@pytest.mark.skipif(
condition=bool(os.environ.get("CI", False)),
reason=(
"Because this loading script downloads a large dataset, "
"we will skip running it on CI."
),
)
@pytest.mark.parametrize(
argnames=(
"dataset_year",
"coco_task",
"expected_num_train",
"expected_num_validation",
),
argvalues=(
(2014, "captions", 82783, 40504),
(2017, "captions", 118287, 5000),
),
)
def test_load_caption_dataset(
dataset_path: str,
dataset_year: int,
coco_task: str,
expected_num_train: int,
expected_num_validation: int,
repo_id: str,
):
dataset = ds.load_dataset(
path=dataset_path,
year=dataset_year,
coco_task=coco_task,
)
assert isinstance(dataset, ds.DatasetDict)

assert dataset["train"].num_rows == expected_num_train
assert dataset["validation"].num_rows == expected_num_validation

dataset.push_to_hub(
repo_id=repo_id,
config_name=f"year={dataset_year}_task={coco_task}",
)
56 changes: 56 additions & 0 deletions tests/MSCOCO_instances_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os

import datasets as ds
import pytest


@pytest.mark.skipif(
condition=bool(os.environ.get("CI", False)),
reason=(
"Because this loading script downloads a large dataset, "
"we will skip running it on CI."
),
)
@pytest.mark.parametrize(
argnames="decode_rle,",
argvalues=(
True,
False,
),
)
@pytest.mark.parametrize(
argnames=(
"dataset_year",
"coco_task",
"expected_num_train",
"expected_num_validation",
),
argvalues=(
(2014, "instances", 82081, 40137),
(2017, "instances", 117266, 4952),
),
)
def test_load_instances_dataset(
dataset_path: str,
dataset_year: int,
coco_task: str,
decode_rle: bool,
expected_num_train: int,
expected_num_validation: int,
repo_id: str,
):
dataset = ds.load_dataset(
path=dataset_path,
year=dataset_year,
coco_task=coco_task,
decode_rle=decode_rle,
)
assert isinstance(dataset, ds.DatasetDict)

assert dataset["train"].num_rows == expected_num_train
assert dataset["validation"].num_rows == expected_num_validation

dataset.push_to_hub(
repo_id=repo_id,
config_name=f"year={dataset_year}_task={coco_task}_decode-rle={decode_rle}",
)
56 changes: 56 additions & 0 deletions tests/MSCOCO_keypoints_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os

import datasets as ds
import pytest


@pytest.mark.skipif(
condition=bool(os.environ.get("CI", False)),
reason=(
"Because this loading script downloads a large dataset, "
"we will skip running it on CI."
),
)
@pytest.mark.parametrize(
argnames="decode_rle,",
argvalues=(
True,
False,
),
)
@pytest.mark.parametrize(
argnames=(
"dataset_year",
"coco_task",
"expected_num_train",
"expected_num_validation",
),
argvalues=(
(2014, "person-keypoints", 45174, 21634),
(2017, "person-keypoints", 64115, 2693),
),
)
def test_load_keypoints_dataset(
dataset_path: str,
dataset_year: int,
coco_task: str,
decode_rle: bool,
expected_num_train: int,
expected_num_validation: int,
repo_id: str,
):
dataset = ds.load_dataset(
path=dataset_path,
year=dataset_year,
coco_task=coco_task,
decode_rle=decode_rle,
)
assert isinstance(dataset, ds.DatasetDict)

assert dataset["train"].num_rows == expected_num_train
assert dataset["validation"].num_rows == expected_num_validation

dataset.push_to_hub(
repo_id=repo_id,
config_name=f"year={dataset_year}_task={coco_task}_decode-rle={decode_rle}",
)
58 changes: 0 additions & 58 deletions tests/MSCOCO_test.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,6 @@
import os

import datasets as ds
import pytest

from MSCOCO import CATEGORIES, SUPER_CATEGORIES


@pytest.fixture
def dataset_path() -> str:
return "MSCOCO.py"


@pytest.mark.skipif(
condition=bool(os.environ.get("CI", False)),
reason=(
"Because this loading script downloads a large dataset, "
"we will skip running it on CI."
),
)
@pytest.mark.parametrize(
argnames="decode_rle,",
argvalues=(
True,
False,
),
)
@pytest.mark.parametrize(
argnames=(
"dataset_year",
"coco_task",
"expected_num_train",
"expected_num_validation",
),
argvalues=(
(2014, "captions", 82783, 40504),
(2017, "captions", 118287, 5000),
(2014, "instances", 82081, 40137),
(2017, "instances", 117266, 4952),
(2014, "person_keypoints", 45174, 21634),
(2017, "person_keypoints", 64115, 2693),
),
)
def test_load_dataset(
dataset_path: str,
dataset_year: int,
coco_task: str,
decode_rle: bool,
expected_num_train: int,
expected_num_validation: int,
):
dataset = ds.load_dataset(
path=dataset_path,
year=dataset_year,
coco_task=coco_task,
decode_rle=decode_rle,
)
assert dataset["train"].num_rows == expected_num_train
assert dataset["validation"].num_rows == expected_num_validation


def test_consts():
assert len(CATEGORIES) == 80
assert len(SUPER_CATEGORIES) == 12
21 changes: 21 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest


@pytest.fixture
def org_name() -> str:
return "shunk031"


@pytest.fixture
def dataset_name() -> str:
return "MSCOCO"


@pytest.fixture
def dataset_path(dataset_name: str) -> str:
return f"{dataset_name}.py"


@pytest.fixture
def repo_id(org_name: str, dataset_name: str) -> str:
return f"{org_name}/{dataset_name}"

0 comments on commit 0ccf579

Please sign in to comment.