
Commit 332eb78

Uncertainty Baselines Team authored and copybara-github committed

CelebA dataset

- Implements the dataset generation for CelebA to work with the new active sampling framework.

PiperOrigin-RevId: 476085835
1 parent 2f9f4e9 commit 332eb78

File tree: 2 files changed, +244 −19 lines

experimental/shoshin/configs/celeb_a_resnet_config.py
Lines changed: 3 additions & 0 deletions

@@ -23,6 +23,9 @@ def get_config() -> ml_collections.ConfigDict:
   """Get mlp config."""
   config = base_config.get_config()
 
+  config.data.subgroup_ids = ('Blond_Hair',)  # ('Blond_Hair')
+  config.data.subgroup_proportions = (0.01,)  # (0.04, 0.012)
+
   data = config.data
   data.name = 'celeb_a'
   data.num_classes = 2
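For orientation: subgroup_proportions is the fraction of the generated training sample drawn from each entry in subgroup_ids, with the remainder filled from examples outside those subgroups. A quick sketch of the arithmetic, using the 300,000-example cap that _generate_examples in data.py (below) applies; the variable names here are illustrative, not part of the change:

    dataset_size = 300_000          # training-sample cap used by the builder
    subgroup_proportions = (0.01,)  # value set in the config above

    # Examples the builder tries to draw from the 'Blond_Hair' subgroup:
    subgroup_sample_size = int(dataset_size * subgroup_proportions[0])    # 3000

    # Everything else comes from outside the subgroup:
    remaining_size = int(dataset_size * (1.0 - sum(subgroup_proportions)))  # 297000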

experimental/shoshin/data.py
Lines changed: 241 additions & 19 deletions
@@ -27,6 +27,7 @@
 import os
 from typing import Any, Dict, Iterator, Optional, Tuple, List, Union
 
+import pandas as pd
 import tensorflow as tf
 import tensorflow_datasets as tfds
 
@@ -470,7 +471,7 @@ def get_waterbirds_dataset(
     to their respective combined datasets.
   """
   split_size_in_pct = int(100 * initial_sample_proportion / num_splits)
-  reduced_datset_sz = int(100 * initial_sample_proportion)
+  reduced_dataset_sz = int(100 * initial_sample_proportion)
   builder_kwargs = {
       'subgroup_ids': subgroup_ids,
       'subgroup_proportions': subgroup_proportions
@@ -479,7 +480,7 @@ def get_waterbirds_dataset(
       'waterbirds_dataset',
       split=[
           f'validation[{k}%:{k+split_size_in_pct}%]'
-          for k in range(0, reduced_datset_sz, split_size_in_pct)
+          for k in range(0, reduced_dataset_sz, split_size_in_pct)
       ],
       data_dir=DATA_DIR,
       builder_kwargs=builder_kwargs,
@@ -490,7 +491,7 @@ def get_waterbirds_dataset(
       'waterbirds_dataset',
       split=[
           f'train[{k}%:{k+split_size_in_pct}%]'
-          for k in range(0, reduced_datset_sz, split_size_in_pct)
+          for k in range(0, reduced_dataset_sz, split_size_in_pct)
       ],
       data_dir=DATA_DIR,
       builder_kwargs=builder_kwargs,
@@ -528,8 +529,233 @@ def get_waterbirds_dataset(
       train_sample_ds=train_sample,
       eval_ds=eval_datasets)
 
+IMG_ALIGNED_DATA = ('https://drive.google.com/uc?export=download&'
+                    'id=0B7EVK8r0v71pZjFTYXZWM3FlRnM')
+EVAL_LIST = ('https://drive.google.com/uc?export=download&'
+             'id=0B7EVK8r0v71pY0NSMzRuSXJEVkk')
+# Landmark coordinates: left_eye, right_eye etc.
+LANDMARKS_DATA = ('https://drive.google.com/uc?export=download&'
+                  'id=0B7EVK8r0v71pd0FJY3Blby1HUTQ')
+
+# Attributes in the image (Eyeglasses, Mustache etc).
+ATTR_DATA = ('https://drive.google.com/uc?export=download&'
+             'id=0B7EVK8r0v71pblRyaVFSWGxPY0U')
+
+LANDMARK_HEADINGS = ('lefteye_x lefteye_y righteye_x righteye_y '
+                     'nose_x nose_y leftmouth_x leftmouth_y rightmouth_x '
+                     'rightmouth_y').split()
+ATTR_HEADINGS = (
+    '5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs '
+    'Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair '
+    'Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair '
+    'Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache '
+    'Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline '
+    'Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings '
+    'Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young'
+).split()
+
+_CITATION = """\
+@inproceedings{conf/iccv/LiuLWT15,
+  added-at = {2018-10-09T00:00:00.000+0200},
+  author = {Liu, Ziwei and Luo, Ping and Wang, Xiaogang and Tang, Xiaoou},
+  biburl = {https://www.bibsonomy.org/bibtex/250e4959be61db325d2f02c1d8cd7bfbb/dblp},
+  booktitle = {ICCV},
+  crossref = {conf/iccv/2015},
+  ee = {http://doi.ieeecomputersociety.org/10.1109/ICCV.2015.425},
+  interhash = {3f735aaa11957e73914bbe2ca9d5e702},
+  intrahash = {50e4959be61db325d2f02c1d8cd7bfbb},
+  isbn = {978-1-4673-8391-2},
+  keywords = {dblp},
+  pages = {3730-3738},
+  publisher = {IEEE Computer Society},
+  timestamp = {2018-10-11T11:43:28.000+0200},
+  title = {Deep Learning Face Attributes in the Wild.},
+  url = {http://dblp.uni-trier.de/db/conf/iccv/iccv2015.html#LiuLWT15},
+  year = 2015
+}
+"""
+
+_DESCRIPTION = """\
+CelebFaces Attributes Dataset (CelebA) is a large-scale face attributes dataset \
+with more than 200K celebrity images, each with 40 attribute annotations. The \
+images in this dataset cover large pose variations and background clutter. \
+CelebA has large diversities, large quantities, and rich annotations, including
+- 10,177 number of identities,
+- 202,599 number of face images, and
+- 5 landmark locations, 40 binary attributes annotations per image.
+The dataset can be employed as the training and test sets for the following \
+computer vision tasks: face attribute recognition, face detection, and landmark \
+(or facial part) localization.
+Note: CelebA dataset may contain potential bias. The fairness indicators
+[example](https://www.tensorflow.org/responsible_ai/fairness_indicators/tutorials/Fairness_Indicators_TFCO_CelebA_Case_Study)
+goes into detail about several considerations to keep in mind while using the
+CelebA dataset.
+"""
+
+
+class LocalCelebADataset(tfds.core.GeneratorBasedBuilder):
+  """CelebA dataset. Aligned and cropped. With metadata."""
+
+  VERSION = tfds.core.Version('2.0.1')
+  SUPPORTED_VERSIONS = [
+      tfds.core.Version('2.0.0'),
+  ]
+  RELEASE_NOTES = {
+      '2.0.1': 'New split API (https://tensorflow.org/datasets/splits)',
+  }
+
+  def __init__(self,
+               subgroup_ids: List[str],
+               subgroup_proportions: Optional[List[float]] = None,
+               label_attr: Optional[str] = 'Male',
+               **kwargs):
+    super(LocalCelebADataset, self).__init__(**kwargs)
+    self.subgroup_ids = subgroup_ids
+    self.label_attr = label_attr
+    if subgroup_proportions:
+      self.subgroup_proportions = subgroup_proportions
+    else:
+      self.subgroup_proportions = [1.] * len(subgroup_ids)
+
+  def _info(self):
+    return tfds.core.DatasetInfo(
+        builder=self,
+        features=tfds.features.FeaturesDict({
+            'example_id':
+                tfds.features.Text(),
+            'subgroup_id':
+                tfds.features.Text(),
+            'subgroup_label':
+                tfds.features.ClassLabel(num_classes=2),
+            'feature':
+                tfds.features.Image(
+                    shape=(218, 178, 3), encoding_format='jpeg'),
+            'label':
+                tfds.features.ClassLabel(num_classes=2),
+            'image_filename':
+                tfds.features.Text(),
+        }),
+        supervised_keys=('feature', 'label', 'example_id'),
+    )
+
+  def _split_generators(self, dl_manager):
+    downloaded_dirs = dl_manager.download({
+        'img_align_celeba': IMG_ALIGNED_DATA,
+        'list_eval_partition': EVAL_LIST,
+        'list_attr_celeba': ATTR_DATA,
+        'landmarks_celeba': LANDMARKS_DATA,
+    })
+
+    # Load all images in memory (~1 GiB)
+    # Use split to convert: `img_align_celeba/000005.jpg` -> `000005.jpg`
+    all_images = {
+        os.path.split(k)[-1]: img for k, img in dl_manager.iter_archive(
+            downloaded_dirs['img_align_celeba'])
+    }
+    return [
+        tfds.core.SplitGenerator(
+            name=tfds.Split.TRAIN,
+            gen_kwargs={
+                'file_id': 0,
+                'downloaded_dirs': downloaded_dirs,
+                'downloaded_images': all_images,
+                'is_training': True,
+            }),
+        tfds.core.SplitGenerator(
+            name=tfds.Split.VALIDATION,
+            gen_kwargs={
+                'file_id': 1,
+                'downloaded_dirs': downloaded_dirs,
+                'downloaded_images': all_images,
+                'is_training': False,
+            }),
+        tfds.core.SplitGenerator(
+            name=tfds.Split.TEST,
+            gen_kwargs={
+                'file_id': 2,
+                'downloaded_dirs': downloaded_dirs,
+                'downloaded_images': all_images,
+                'is_training': False,
+            })
+    ]
+
+  def _process_celeba_config_file(self, file_path):
+    """Unpack the celeba config file.
+
+    The file starts with the number of lines, and a header.
+    Afterwards, there is a configuration for each file: one per line.
+
+    Args:
+      file_path: Path to the file with the configuration.
+
+    Returns:
+      keys: names of the attributes
+      values: map from the file name to the list of attribute values for
+        this file.
+    """
+    with tf.io.gfile.GFile(file_path) as f:
+      data_raw = f.read()
+    lines = data_raw.split('\n')
+
+    keys = lines[1].strip().split()
+    values = {}
+    # Go over each line (skip the last one, as it is empty).
+    for line in lines[2:-1]:
+      row_values = line.strip().split()
+      # Each row starts with the 'file_name' and then space-separated values.
+      values[row_values[0]] = [int(v) for v in row_values[1:]]
+    return keys, values
+
+  def _generate_examples(self, file_id, downloaded_dirs, downloaded_images,
+                         is_training):
+    """Yields examples."""
+
+    attr_path = downloaded_dirs['list_attr_celeba']
 
-@register_dataset('celeb_a')
+    attributes = self._process_celeba_config_file(attr_path)
+    dataset = pd.DataFrame.from_dict(
+        attributes[1], orient='index', columns=attributes[0])
+
+    if is_training:
+      dataset_size = 300000
+      sampled_datasets = []
+      remaining_proportion = 1.
+      remaining_dataset = dataset.copy()
+      for idx, subgroup_id in enumerate(self.subgroup_ids):
+        subgroup_dataset = dataset[dataset[subgroup_id] == 1]
+        subgroup_sample_size = int(dataset_size *
+                                   self.subgroup_proportions[idx])
+        subgroup_dataset = subgroup_dataset.sample(min(len(subgroup_dataset),
+                                                       subgroup_sample_size))
+        sampled_datasets.append(subgroup_dataset)
+        remaining_proportion -= self.subgroup_proportions[idx]
+        remaining_dataset = remaining_dataset[remaining_dataset[subgroup_id] ==
+                                              -1]
+
+      remaining_sample_size = int(dataset_size * remaining_proportion)
+      remaining_dataset = remaining_dataset.sample(min(len(remaining_dataset),
+                                                       remaining_sample_size))
+      sampled_datasets.append(remaining_dataset)
+
+      dataset = pd.concat(sampled_datasets)
+      dataset = dataset.sample(min(len(dataset), dataset_size))
+
+    for file_name in dataset.index:
+      subgroup_id = self.subgroup_ids[0] if dataset.loc[file_name][
+          self.subgroup_ids[0]] == 1 else 'Not_' + self.subgroup_ids[0]
+      subgroup_label = 1 if subgroup_id in self.subgroup_ids else 0
+      label = 1 if dataset.loc[file_name][self.label_attr] == 1 else 0
+      record = {
+          'example_id': file_name,
+          'subgroup_id': subgroup_id,
+          'subgroup_label': subgroup_label,
+          'feature': downloaded_images[file_name],
+          'label': label,
+          'image_filename': file_name
+      }
+      yield file_name, record
+
+
+@register_dataset('local_celeb_a')
 def get_celeba_dataset(
     num_splits: int, initial_sample_proportion: float,
     subgroup_ids: List[str], subgroup_proportions: List[float],
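The core of the new builder is the subgroup-aware sampling in _generate_examples: each requested subgroup is drawn at its configured proportion of the budget, and the remainder is filled from rows matching none of the subgroups (CelebA attributes are coded +1/-1). A self-contained toy version of that loop, using a hypothetical six-row attribute table in place of the real list_attr_celeba file:

    import pandas as pd

    # Hypothetical attribute table; the real one has 202,599 rows, 40 columns.
    dataset = pd.DataFrame(
        {'Blond_Hair': [1, 1, -1, -1, -1, -1],
         'Male': [1, -1, 1, -1, 1, -1]},
        index=[f'{i:06d}.jpg' for i in range(1, 7)])

    subgroup_ids = ['Blond_Hair']
    subgroup_proportions = [0.5]
    dataset_size = 4

    sampled_datasets = []
    remaining_proportion = 1.
    remaining_dataset = dataset.copy()
    for idx, subgroup_id in enumerate(subgroup_ids):
      # Sample the subgroup at its configured share of the budget.
      subgroup = dataset[dataset[subgroup_id] == 1]
      size = int(dataset_size * subgroup_proportions[idx])
      sampled_datasets.append(subgroup.sample(min(len(subgroup), size)))
      remaining_proportion -= subgroup_proportions[idx]
      # Drop subgroup members from the pool used for the remainder.
      remaining_dataset = remaining_dataset[remaining_dataset[subgroup_id] == -1]

    remaining_size = int(dataset_size * remaining_proportion)
    sampled_datasets.append(
        remaining_dataset.sample(min(len(remaining_dataset), remaining_size)))
    print(pd.concat(sampled_datasets))  # 2 Blond_Hair rows + 2 others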
@@ -549,47 +775,44 @@ def get_celeba_dataset(
   combined training dataset, and a dictionary mapping evaluation dataset names
   to their respective combined datasets.
   """
-  del subgroup_proportions, subgroup_ids
   read_config = tfds.ReadConfig()
   read_config.add_tfds_id = True  # Set `True` to return the 'tfds_id' key
+
   split_size_in_pct = int(100 * initial_sample_proportion / num_splits)
   reduced_dataset_sz = int(100 * initial_sample_proportion)
+  builder_kwargs = {
+      'subgroup_ids': subgroup_ids,
+      'subgroup_proportions': subgroup_proportions
+  }
   train_splits = tfds.load(
-      'celeb_a',
+      'local_celeb_a_dataset',
       read_config=read_config,
       split=[
           f'train[:{k}%]+train[{k+split_size_in_pct}%:]'
           for k in range(0, reduced_dataset_sz, split_size_in_pct)
       ],
+      builder_kwargs=builder_kwargs,
       data_dir=DATA_DIR,
       try_gcs=False,
-      as_supervised=True
   )
   val_splits = tfds.load(
-      'celeb_a',
+      'local_celeb_a_dataset',
       read_config=read_config,
       split=[
           f'validation[{k}%:{k+split_size_in_pct}%]'
           for k in range(0, reduced_dataset_sz, split_size_in_pct)
       ],
+      builder_kwargs=builder_kwargs,
       data_dir=DATA_DIR,
       try_gcs=False,
-      as_supervised=True
   )
-  train_sample = tfds.load(
-      'celeb_a',
-      split='train_sample',
-      data_dir=DATA_DIR,
-      try_gcs=False,
-      as_supervised=True,
-      with_info=False)
 
   test_ds = tfds.load(
-      'celeb_a',
+      'local_celeb_a_dataset',
       split='test',
+      builder_kwargs=builder_kwargs,
       data_dir=DATA_DIR,
       try_gcs=False,
-      as_supervised=True,
       with_info=False)
 
   train_ds = gather_data_splits(list(range(num_splits)), train_splits)
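Note the mechanism in this hunk: the public celeb_a TFDS name is swapped for the locally registered builder, and builder_kwargs is how tfds.load forwards the subgroup settings into LocalCelebADataset.__init__. Dropping as_supervised=True also keeps the full feature dict (including subgroup_id and subgroup_label) rather than a (feature, label) tuple. A minimal sketch of the call shape, under the assumption that the builder above is registered; the split slice and data_dir are placeholders:

    import tensorflow_datasets as tfds

    builder_kwargs = {
        'subgroup_ids': ['Blond_Hair'],
        'subgroup_proportions': [0.01],
    }
    # tfds.load passes builder_kwargs to the builder's constructor, so the
    # subgroup sampling above runs during dataset generation.
    val_ds = tfds.load(
        'local_celeb_a_dataset',
        split='validation[0%:5%]',    # placeholder slice
        builder_kwargs=builder_kwargs,
        data_dir='/tmp/tfds',         # placeholder for DATA_DIR
        try_gcs=False)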
@@ -602,5 +825,4 @@ def get_celeba_dataset(
       train_splits,
       val_splits,
       train_ds,
-      train_sample_ds=train_sample,
       eval_ds=eval_datasets)
