Skip to content

Commit 78c3554

Browse files
Uncertainty Baselines Teamcopybara-github
authored andcommitted
Allows calculation of Tracin values
PiperOrigin-RevId: 493218263
1 parent 4c6032f commit 78c3554

File tree

2 files changed

+175
-0
lines changed

2 files changed

+175
-0
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# coding=utf-8
2+
# Copyright 2022 The Uncertainty Baselines Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Configuration file for experiment with Waterbirds data and ResNet model."""
17+
18+
import ml_collections
19+
from configs import base_config # local file import from experimental.shoshin
20+
21+
22+
def get_signal_config():
23+
"""Get training config."""
24+
config = ml_collections.ConfigDict()
25+
config.checkpoint_selection = 'first'
26+
config.checkpoint_list = [
27+
'epoch-01-val_auc-0.76.ckpt',
28+
]
29+
config.checkpoint_name = ''
30+
config.checkpoint_number = 5
31+
config.included_layers = -2
32+
return config
33+
34+
35+
def get_config() -> ml_collections.ConfigDict:
36+
"""Get mlp config."""
37+
config = base_config.get_config()
38+
39+
# Consider landbirds on water and waterbirds on land as subgroups.
40+
config.data.subgroup_ids = () # ('0_1', '1_0')
41+
config.data.subgroup_proportions = () # (0.04, 0.012)
42+
config.data.initial_sample_proportion = 1
43+
44+
config.active_sampling.num_samples_per_round = 500
45+
config.num_rounds = 4
46+
47+
data = config.data
48+
data.name = 'waterbirds10k'
49+
data.num_classes = 2
50+
51+
model = config.model
52+
model.name = 'resnet'
53+
model.dropout_rate = 0.2
54+
55+
config.output_dir = ''
56+
config.generate_individual_table = True
57+
config.round_idx = 0
58+
59+
# To ensure that the last layers actually predict the outcome and not the bias
60+
config.train_bias = False
61+
62+
config.signal = get_signal_config()
63+
64+
return config
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# coding=utf-8
2+
# Copyright 2022 The Uncertainty Baselines Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
r"""Binary executable for generating tracin table.
17+
18+
This file serves as a binary to calculate tracin values and create a lookup table
19+
that maps from example ID to tracin label.
20+
21+
Usage:
22+
# pylint: disable=line-too-long
23+
24+
ml_python3 third_party/py/uncertainty_baselines/experimental/shoshin/generate_tracin_table.py \
25+
--adhoc_import_modules=uncertainty_baselines \
26+
-- \
27+
--xm_runlocal \
28+
--logtostderr \
29+
--config=third_party/py/uncertainty_baselines/experimental/shoshin/configs/waterbirds_resnet_tracin_config.py
30+
31+
# pylint: enable=line-too-long
32+
33+
Note: In output_dir, models trained on different splits of data must already
34+
exist and be present in directory.
35+
"""
36+
37+
import os
38+
39+
from absl import app
40+
from absl import flags
41+
from absl import logging
42+
from ml_collections import config_flags
43+
import data # local file import from experimental.shoshin
44+
import generate_bias_table_lib # local file import from experimental.shoshin
45+
import models # local file import from experimental.shoshin
46+
import sampling_policies # local file import from experimental.shoshin
47+
from configs import base_config # local file import from experimental.shoshin
48+
49+
50+
FLAGS = flags.FLAGS
51+
config_flags.DEFINE_config_file('config')
52+
53+
54+
def main(_) -> None:
55+
56+
config = FLAGS.config
57+
base_config.check_flags(config)
58+
ckpt_dir = os.path.join(config.output_dir,
59+
generate_bias_table_lib.CHECKPOINT_SUBDIR)
60+
model_params = models.ModelTrainingParameters(
61+
model_name=config.model.name,
62+
train_bias=config.train_bias,
63+
num_classes=config.data.num_classes,
64+
num_subgroups=0,
65+
num_epochs=config.training.num_epochs,
66+
learning_rate=config.optimizer.learning_rate,
67+
hidden_sizes=config.model.hidden_sizes,
68+
)
69+
70+
dataset_builder = data.get_dataset(config.data.name)
71+
if config.generate_individual_table:
72+
if config.round_idx == 0:
73+
dataloader = dataset_builder(config.data.num_splits,
74+
config.data.initial_sample_proportion,
75+
config.data.subgroup_ids,
76+
config.data.subgroup_proportions,)
77+
else:
78+
dataloader = dataset_builder(config.data.num_splits, 1,
79+
config.data.subgroup_ids,
80+
config.data.subgroup_proportions,)
81+
# Filter each split to only have examples from example_ids_table
82+
dataloader.train_splits = [
83+
dataloader.train_ds.filter(
84+
generate_bias_table_lib.filter_ids_fn(ids_tab)) for
85+
ids_tab in sampling_policies.convert_ids_to_table(config.ids_dir)]
86+
dataloader = data.apply_batch(dataloader, config.data.batch_size)
87+
model_params.num_subgroups = dataloader.num_subgroups
88+
model_checkpoints = generate_bias_table_lib.load_model_checkpoints(
89+
ckpt_dir, model_params, config.signal.checkpoint_list,
90+
config.signal.checkpoint_selection, config.signal.checkpoint_number,
91+
config.signal.checkpoint_name)
92+
93+
logging.info('%s checkpoints loaded', len(model_checkpoints))
94+
if config.signal.checkpoint_selection == 'name':
95+
table_name = config.signal.checkpoint_name
96+
else:
97+
table_name = config.signal.checkpoint_selection
98+
_ = generate_bias_table_lib.get_example_id_to_tracin_value_table(
99+
dataloader=dataloader,
100+
model_checkpoints=model_checkpoints,
101+
included_layers=config.signal.included_layers,
102+
save_dir=config.save_dir,
103+
save_table=True,
104+
table_name=table_name)
105+
else:
106+
# TODO(martinstrobel): Combine individual tracinvalues to a mean value
107+
raise NotImplementedError('Not implemented yet')
108+
109+
110+
if __name__ == '__main__':
111+
app.run(main)

0 commit comments

Comments
 (0)