From 822ca3f0b5b27a21114831237aa3a480dbd8ee95 Mon Sep 17 00:00:00 2001
From: Shivank Garg <128126577+shivank21@users.noreply.github.com>
Date: Wed, 3 Jan 2024 12:58:57 +0530
Subject: [PATCH] Added Files for Paper Implementation
---
README.md | 122 +++++++++++-
Scripts/train_demo.sh | 30 +++
Scripts/train_demo_multigpu.sh | 32 ++++
dataset.py | 95 ++++++++++
inference.py | 137 ++++++++++++++
plot.py | 228 +++++++++++++++++++++++
requirements.sh | 8 +
score.py | 72 +++++++
train.py | 331 +++++++++++++++++++++++++++++++++
9 files changed, 1054 insertions(+), 1 deletion(-)
create mode 100644 Scripts/train_demo.sh
create mode 100644 Scripts/train_demo_multigpu.sh
create mode 100644 dataset.py
create mode 100644 inference.py
create mode 100644 plot.py
create mode 100644 requirements.sh
create mode 100644 score.py
create mode 100644 train.py
diff --git a/README.md b/README.md
index 34cdcff..97d2445 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,121 @@
-# confidence-is-all-you-need
\ No newline at end of file
+## Confidence is All You Need for MI Attacks
+
+This directory contains code to reproduce our paper:
+**"Confidence is all you need for MI Attacks"**
+https://arxiv.org/abs/2311.15373
+by Abhishek Sinha, Himanshi Tibrewal, Mansi Gupta, Nikhar Waghela, Shivank Garg
+
+Our work builds upon:
+
+**"Membership Inference Attacks From First Principles"**
+https://arxiv.org/abs/2112.03570
+by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramèr.
+
+### INSTALLING DEPENDENCIES
+To install the basic dependencies needed to run this repository, run
+
+> bash requirements.sh
+
+We train our models with JAX + Objax, so you will also need to follow the Objax installation instructions:
+https://github.com/google/objax
+https://objax.readthedocs.io/en/latest/installation_setup.html
+
+### RUNNING THE CODE
+
+#### 1. Train the models
+
+The first step in our attack is to train shadow models. As a baseline that
+should give most of the gains in our attack, you should start by training 16
+shadow models with the command
+
+> bash Scripts/train_demo.sh
+
+or if you have multiple GPUs on your machine and want to train these models in
+parallel, then modify and run
+
+> bash Scripts/train_demo_multigpu.sh
+
+This will train 16 CIFAR-10 wide ResNet (WRN28-2) models to roughly 91% accuracy each, and
+will write a set of output files under the directory exp/cifar10 with the structure:
+
+```
+exp/cifar10/
+- experiment-N_16
+-- hparams.json
+-- keep.npy
+-- ckpt/
+--- 0000000100.npz
+-- tb/
+```
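+
+For a quick sanity check (a minimal sketch, assuming the first of the 16 runs has
+finished and lives at the hypothetical path below), you can load a model's
+hparams.json and keep.npy to see which examples it was trained on:
+
+```
+import json
+import numpy as np
+
+# Hypothetical path to the first of the 16 shadow-model runs.
+exp_dir = "exp/cifar10/experiment-0_16"
+
+hparams = json.load(open(f"{exp_dir}/hparams.json"))
+keep = np.load(f"{exp_dir}/keep.npy")  # boolean membership mask over the training set
+
+print(hparams["arch"])                             # e.g. wrn28-2
+print(keep.shape, keep.dtype)                      # (50000,) bool
+print("fraction used for training:", keep.mean())  # ~0.5, since pkeep defaults to 0.5
+```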
+
+#### 2. Perform inference
+
+Once the models are trained, the next step is to perform inference and save the
+output features for every training example under every model.
+
+> python3 inference.py --logdir=exp/cifar10/
+
+This will add a new set of files to each experiment directory:
+
+```
+exp/cifar10/
+- experiment-N_16
+-- logits/
+--- 0000000100.npy
+```
+
+where this new file stores the model's output logits for each training example; with the
+default augmentations (the original image and its horizontal mirror) the saved array has
+shape (50000, 1, 2, 10).
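+
+As a quick check (a minimal sketch; the run directory and checkpoint name are whatever
+your own training and inference produced), you can load one of these files and inspect it:
+
+```
+import numpy as np
+
+# Hypothetical path to the logits saved for one shadow model.
+logits = np.load("exp/cifar10/experiment-0_16/logits/0000000100.npy")
+
+print(logits.shape)     # (n_examples, 1, n_augmentations, n_classes), e.g. (50000, 1, 2, 10)
+print(logits[0, 0, 0])  # logits of the un-augmented first training example
+```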
+
+#### 3. Compute membership inference scores
+
+Finally we take the output features and compute our membership inference scores
+(by default, the log of the maximum softmax confidence) for each example under each model.
+
+> python3 score.py exp/cifar10/
+
+We evaluated several candidate scoring functions across our experiments. The score
+calculations are implemented in score.py, where the commented-out lines list the
+alternative scores we explored. Notably, using the maximum softmax confidence (the
+argmax class), which does not require knowledge of the true labels, produced results
+comparable to those reported in the LiRA (likelihood ratio attack) paper cited above.
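+
+As a rough sketch of the scoring variants we compared (illustrative only, not a
+line-for-line copy of score.py; it assumes `logits` is one of the saved arrays of shape
+(n_examples, 1, n_augs, n_classes) and `labels` holds the true class indices):
+
+```
+import numpy as np
+
+def candidate_scores(logits, labels):
+    # Numerically stable softmax over the class axis.
+    p = logits - logits.max(axis=3, keepdims=True)
+    p = np.exp(p.astype(np.float64))
+    p /= p.sum(axis=3, keepdims=True)
+
+    n = p.shape[0]
+    p_true = p[np.arange(n), :, :, labels[:n]]  # probability of the true class
+    p_max = p.max(axis=3)                       # probability of the argmax class (label-free)
+
+    p_rest = p.copy()
+    p_rest[np.arange(n), :, :, labels[:n]] = 0
+    p_wrong = p_rest.sum(axis=3)                # total mass on the wrong classes
+
+    return {
+        "lira_logit":     np.log(p_true.mean(1) + 1e-45) - np.log(p_wrong.mean(1) + 1e-45),
+        "confidence":     p_true.mean(1),
+        "log_confidence": np.log(p_true.mean(1) + 1e-45),
+        "argmax":         p_max.mean(1),
+        "log_argmax":     np.log(p_max.mean(1) + 1e-45),  # the label-free score used by default
+    }
+```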
+
+This in turn adds a scores/ directory to each experiment:
+
+```
+exp/cifar10/
+- experiment-N_16
+-- scores/
+--- 0000000100.npy
+```
+
+where each file stores one score per example per augmentation (shape (50000, 2) with the
+default mirror augmentation).
+
+### PLOTTING THE RESULTS
+
+Finally, we can run the plotting code to generate the ROC curves (the figure is saved to /tmp/fprtpr.png):
+
+> python3 plot.py
+
+### RESULTS (AUC as the metric)
+
+| | Loss Value (Baseline) | Confidence Values | log (Confidence Values) | Argmax | log (Argmax) |
+| :-----: | :-------------------: | :---------------: | :---------------------: | :----: | :----------: |
+| Attack Ours (Online) | 0.5753 | 0.5668 | 0.575 | 0.5464 | 0.5447 |
+| Attack Ours (Online, Fixed Variance) | 0.5879 | 0.593 | 0.6009 | 0.5622 | 0.5602 |
+| Attack Ours (Offline) | 0.5181 | 0.492 | 0.4721 | 0.478 | 0.4756 |
+| Attack Ours (Offline, Fixed Variance) | 0.5184 | 0.4928 | 0.4804 | 0.4834 | 0.4815 |
+| Attack Global Threshold | 0.5448 | 0.5439 | 0.5469 | 0.5376 | 0.5377 |
+
+Here each row is an attack variant and each column is a scoring function (with the loss
+value as the baseline score). The global-threshold attack is the baseline attack, and our
+online, online-with-fixed-variance, offline, and offline-with-fixed-variance variants are
+the other four rows. Note that because we only train a few shadow models, the
+fixed-variance variants perform best.
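+
+For intuition, below is a minimal per-example sketch of the parametric test behind the
+online attack. plot.py vectorises this over all examples and holds out a fixed set of
+target models rather than the leave-one-out split assumed here; `scores` is the
+(n_models, n_examples, n_augs) array built by load_data and `keep` is the matching
+boolean membership matrix.
+
+```
+import numpy as np
+import scipy.stats
+
+def online_lr_score(scores, keep, target_model, example):
+    # Shadow-model scores for this example, excluding the target model itself.
+    others = np.arange(scores.shape[0]) != target_model
+    s = scores[others, example, :]
+    member = keep[others, example]
+
+    dat_in, dat_out = s[member], s[~member]
+    obs = scores[target_model, example]
+
+    # Gaussian log-densities under the "member" and "non-member" distributions.
+    pr_in = -scipy.stats.norm.logpdf(obs, np.median(dat_in, 0), np.std(dat_in, 0) + 1e-30)
+    pr_out = -scipy.stats.norm.logpdf(obs, np.median(dat_out, 0), np.std(dat_out, 0) + 1e-30)
+
+    # Lower values indicate membership (plot.py negates scores before the ROC sweep).
+    return (pr_in - pr_out).mean()
+```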
+
+### Citation
+
+You can cite this paper with
+
+```
+@article{sinha2023confidence,
+  title={Confidence is All You Need for MI Attacks},
+  author={Sinha, Abhishek and Tibrewal, Himanshi and Gupta, Mansi and Waghela, Nikhar and Garg, Shivank},
+  journal={arXiv preprint arXiv:2311.15373},
+  year={2023}
+}
+```
\ No newline at end of file
diff --git a/Scripts/train_demo.sh b/Scripts/train_demo.sh
new file mode 100644
index 0000000..6672ddf
--- /dev/null
+++ b/Scripts/train_demo.sh
@@ -0,0 +1,30 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+mkdir -p logs  # the redirections below fail if the logs/ directory does not exist
+CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0
+CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1
+CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2
+CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3
+CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4
+CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5
+CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6
+CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7
+CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8
+CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9
+CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10
+CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11
+CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12
+CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13
+CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14
+CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15
\ No newline at end of file
diff --git a/Scripts/train_demo_multigpu.sh b/Scripts/train_demo_multigpu.sh
new file mode 100644
index 0000000..f20334b
--- /dev/null
+++ b/Scripts/train_demo_multigpu.sh
@@ -0,0 +1,32 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+mkdir -p logs  # the redirections below fail if the logs/ directory does not exist
+CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0 &
+CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1 &
+CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2 &
+CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3 &
+CUDA_VISIBLE_DEVICES='4' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4 &
+CUDA_VISIBLE_DEVICES='5' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5 &
+CUDA_VISIBLE_DEVICES='6' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6 &
+CUDA_VISIBLE_DEVICES='7' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7 &
+wait;
+CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8 &
+CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9 &
+CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10 &
+CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11 &
+CUDA_VISIBLE_DEVICES='4' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12 &
+CUDA_VISIBLE_DEVICES='5' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13 &
+CUDA_VISIBLE_DEVICES='6' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14 &
+CUDA_VISIBLE_DEVICES='7' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15 &
+wait;
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..b4c966a
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,95 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Tuple, List
+
+import numpy as np
+import tensorflow as tf
+
+
+def record_parse(serialized_example: str, image_shape: Tuple[int, int, int]):
+ features = tf.io.parse_single_example(serialized_example,
+ features={'image': tf.io.FixedLenFeature([], tf.string),
+ 'label': tf.io.FixedLenFeature([], tf.int64)})
+    image = tf.image.decode_image(features['image'])
+    image.set_shape(image_shape)  # set_shape mutates in place and returns None, so don't chain it
+ image = tf.cast(image, tf.float32) * (2.0 / 255) - 1.0
+ return dict(image=image, label=features['label'])
+
+
+class DataSet:
+ """Wrapper for tf.data.Dataset to permit extensions."""
+
+ def __init__(self, data: tf.data.Dataset,
+ image_shape: Tuple[int, int, int],
+ augment_fn: Optional[Callable] = None,
+ parse_fn: Optional[Callable] = record_parse):
+ self.data = data
+ self.parse_fn = parse_fn
+ self.augment_fn = augment_fn
+ self.image_shape = image_shape
+
+ @classmethod
+ def from_arrays(cls, images: np.ndarray, labels: np.ndarray, augment_fn: Optional[Callable] = None):
+ return cls(tf.data.Dataset.from_tensor_slices(dict(image=images, label=labels)), images.shape[1:],
+ augment_fn=augment_fn, parse_fn=None)
+
+ @classmethod
+ def from_files(cls, filenames: List[str],
+ image_shape: Tuple[int, int, int],
+ augment_fn: Optional[Callable],
+ parse_fn: Optional[Callable] = record_parse):
+ filenames_in = filenames
+ filenames = sorted(sum([tf.io.gfile.glob(x) for x in filenames], []))
+ if not filenames:
+ raise ValueError('Empty dataset, files not found:', filenames_in)
+ return cls(tf.data.TFRecordDataset(filenames), image_shape, augment_fn=augment_fn, parse_fn=parse_fn)
+
+ @classmethod
+ def from_tfds(cls, dataset: tf.data.Dataset, image_shape: Tuple[int, int, int],
+ augment_fn: Optional[Callable] = None):
+ return cls(dataset.map(lambda x: dict(image=tf.cast(x['image'], tf.float32) / 127.5 - 1, label=x['label'])),
+ image_shape, augment_fn=augment_fn, parse_fn=None)
+
+ def __iter__(self):
+ return iter(self.data)
+
+ def __getattr__(self, item):
+ if item in self.__dict__:
+ return self.__dict__[item]
+
+ def call_and_update(*args, **kwargs):
+ v = getattr(self.__dict__['data'], item)(*args, **kwargs)
+ if isinstance(v, tf.data.Dataset):
+ return self.__class__(v, self.image_shape, augment_fn=self.augment_fn, parse_fn=self.parse_fn)
+ return v
+
+ return call_and_update
+
+ def augment(self, para_augment: int = 4):
+ if self.augment_fn:
+ return self.map(self.augment_fn, para_augment)
+ return self
+
+ def nchw(self):
+ return self.map(lambda x: dict(image=tf.transpose(x['image'], [0, 3, 1, 2]), label=x['label']))
+
+ def one_hot(self, nclass: int):
+ return self.map(lambda x: dict(image=x['image'], label=tf.one_hot(x['label'], nclass)))
+
+ def parse(self, para_parse: int = 2):
+ if not self.parse_fn:
+ return self
+ if self.image_shape:
+ return self.map(lambda x: self.parse_fn(x, self.image_shape), para_parse)
+ return self.map(self.parse_fn, para_parse)
\ No newline at end of file
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..75dcfd4
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,137 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: skip-file
+# pyformat: disable
+
+import json
+import os
+import re
+
+import numpy as np
+import objax
+import tensorflow as tf # For data augmentation.
+from absl import app
+from absl import flags
+
+from train import MemModule
+from train import network
+
+FLAGS = flags.FLAGS
+
+
+def main(argv):
+ """
+ Perform inference of the saved model in order to generate the
+ output logits, using a particular set of augmentations.
+ """
+ del argv
+ tf.config.experimental.set_visible_devices([], "GPU")
+
+ def load(arch):
+ return MemModule(network(arch), nclass=100 if FLAGS.dataset == 'cifar100' else 10,
+ mnist=FLAGS.dataset == 'mnist',
+ arch=arch,
+ lr=.1,
+ batch=0,
+ epochs=0,
+ weight_decay=0)
+
+ def cache_load(arch):
+ thing = []
+ def fn():
+ if len(thing) == 0:
+ thing.append(load(arch))
+ return thing[0]
+ return fn
+
+ xs_all = np.load(os.path.join(FLAGS.logdir,"x_train.npy"))[:FLAGS.dataset_size]
+ ys_all = np.load(os.path.join(FLAGS.logdir,"y_train.npy"))[:FLAGS.dataset_size]
+
+
+ def get_loss(model, xbatch, ybatch, shift, reflect=True, stride=1):
+
+ outs = []
+ for aug in [xbatch, xbatch[:,:,::-1,:]][:reflect+1]:
+ aug_pad = tf.pad(aug, [[0] * 2, [shift] * 2, [shift] * 2, [0] * 2], mode='REFLECT').numpy()
+ for dx in range(0, 2*shift+1, stride):
+ for dy in range(0, 2*shift+1, stride):
+ this_x = aug_pad[:, dx:dx+32, dy:dy+32, :].transpose((0,3,1,2))
+
+ logits = model.model(this_x, training=True)
+ outs.append(logits)
+
+ print(np.array(outs).shape)
+ return np.array(outs).transpose((1, 0, 2))
+
+ N = 5000
+
+ def features(model, xbatch, ybatch):
+ return get_loss(model, xbatch, ybatch,
+ shift=0, reflect=True, stride=1)
+
+ for path in sorted(os.listdir(os.path.join(FLAGS.logdir))):
+ if re.search(FLAGS.regex, path) is None:
+ print("Skipping from regex")
+ continue
+
+ hparams = json.load(open(os.path.join(FLAGS.logdir, path, "hparams.json")))
+ arch = hparams['arch']
+ model = cache_load(arch)()
+
+ logdir = os.path.join(FLAGS.logdir, path)
+
+ checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=10, makedir=True)
+ max_epoch, last_ckpt = checkpoint.restore(model.vars())
+ if max_epoch == 0: continue
+
+ if not os.path.exists(os.path.join(FLAGS.logdir, path, "logits")):
+ os.mkdir(os.path.join(FLAGS.logdir, path, "logits"))
+ if FLAGS.from_epoch is not None:
+ first = FLAGS.from_epoch
+ else:
+ first = max_epoch-1
+
+ for epoch in range(first,max_epoch+1):
+ if not os.path.exists(os.path.join(FLAGS.logdir, path, "ckpt", "%010d.npz"%epoch)):
+ # no checkpoint saved here
+ continue
+
+ if os.path.exists(os.path.join(FLAGS.logdir, path, "logits", "%010d.npy"%epoch)):
+ print("Skipping already generated file", epoch)
+ continue
+
+ try:
+ start_epoch, last_ckpt = checkpoint.restore(model.vars(), epoch)
+ except:
+ print("Fail to load", epoch)
+ continue
+
+ stats = []
+
+ for i in range(0,len(xs_all),N):
+ stats.extend(features(model, xs_all[i:i+N],
+ ys_all[i:i+N]))
+ # This will be shape N, augs, nclass
+
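+        # The extra axis below makes the saved array (n_examples, 1, n_augs, n_classes);
+        # score.py softmaxes over that last axis.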
+ np.save(os.path.join(FLAGS.logdir, path, "logits", "%010d"%epoch),
+ np.array(stats)[:,None,:,:])
+
+if __name__ == '__main__':
+ flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
+ flags.DEFINE_string('logdir', 'experiments/', 'Directory where to save checkpoints and tensorboard data.')
+ flags.DEFINE_string('regex', '.*experiment.*', 'keep files when matching')
+ flags.DEFINE_integer('dataset_size', 50000, 'size of dataset.')
+ flags.DEFINE_integer('from_epoch', None, 'which epoch to load from.')
+ app.run(main)
\ No newline at end of file
diff --git a/plot.py b/plot.py
new file mode 100644
index 0000000..25a7832
--- /dev/null
+++ b/plot.py
@@ -0,0 +1,228 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: skip-file
+# pyformat: disable
+
+import os
+import scipy.stats
+
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.metrics import auc, roc_curve
+import functools
+
+# Look at me being proactive!
+import matplotlib
+matplotlib.rcParams['pdf.fonttype'] = 42
+matplotlib.rcParams['ps.fonttype'] = 42
+
+
+def sweep(score, x):
+ """
+ Compute a ROC curve and then return the FPR, TPR, AUC, and ACC.
+ """
+ fpr, tpr, _ = roc_curve(x, -score)
+ acc = np.max(1-(fpr+(1-tpr))/2)
+ return fpr, tpr, auc(fpr, tpr), acc
+
+def load_data(p):
+ """
+ Load our saved scores and then put them into a big matrix.
+ """
+ global scores, keep
+ scores = []
+ keep = []
+
+ for root,ds,_ in os.walk(p):
+ for f in ds:
+ if not f.startswith("experiment"): continue
+ if not os.path.exists(os.path.join(root,f,"scores")): continue
+ last_epoch = sorted(os.listdir(os.path.join(root,f,"scores")))
+ if len(last_epoch) == 0: continue
+ scores.append(np.load(os.path.join(root,f,"scores",last_epoch[-1])))
+ keep.append(np.load(os.path.join(root,f,"keep.npy")))
+
+ scores = np.array(scores)
+ keep = np.array(keep)[:,:scores.shape[1]]
+
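+    # scores: (n_models, n_examples, n_augs); keep: matching (n_models, n_examples) boolean mask.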
+ return scores, keep
+
+def generate_ours(keep, scores, check_keep, check_scores, in_size=100000, out_size=100000,
+ fix_variance=False):
+ """
+    Fit two predictive models using keep and scores in order to predict
+ if the examples in check_scores were training data or not, using the
+ ground truth answer from check_keep.
+ """
+ dat_in = []
+ dat_out = []
+
+ for j in range(scores.shape[1]):
+ dat_in.append(scores[keep[:,j],j,:])
+ dat_out.append(scores[~keep[:,j],j,:])
+
+ in_size = min(min(map(len,dat_in)), in_size)
+ out_size = min(min(map(len,dat_out)), out_size)
+
+ dat_in = np.array([x[:in_size] for x in dat_in])
+ dat_out = np.array([x[:out_size] for x in dat_out])
+
+ mean_in = np.median(dat_in, 1)
+ mean_out = np.median(dat_out, 1)
+
+ if fix_variance:
+ std_in = np.std(dat_in)
+ std_out = np.std(dat_in)
+ else:
+ std_in = np.std(dat_in, 1)
+ std_out = np.std(dat_out, 1)
+
+ prediction = []
+ answers = []
+ for ans, sc in zip(check_keep, check_scores):
+ pr_in = -scipy.stats.norm.logpdf(sc, mean_in, std_in+1e-30)
+ pr_out = -scipy.stats.norm.logpdf(sc, mean_out, std_out+1e-30)
+ score = pr_in-pr_out
+
+ prediction.extend(score.mean(1))
+ answers.extend(ans)
+
+ return prediction, answers
+
+def generate_ours_offline(keep, scores, check_keep, check_scores, in_size=100000, out_size=100000,
+ fix_variance=False):
+ """
+ Fit a single predictive model using keep and scores in order to predict
+ if the examples in check_scores were training data or not, using the
+ ground truth answer from check_keep.
+ """
+ dat_in = []
+ dat_out = []
+
+ for j in range(scores.shape[1]):
+ dat_in.append(scores[keep[:, j], j, :])
+ dat_out.append(scores[~keep[:, j], j, :])
+
+ out_size = min(min(map(len,dat_out)), out_size)
+
+ dat_out = np.array([x[:out_size] for x in dat_out])
+
+ mean_out = np.median(dat_out, 1)
+
+ if fix_variance:
+ std_out = np.std(dat_out)
+ else:
+ std_out = np.std(dat_out, 1)
+
+ prediction = []
+ answers = []
+ for ans, sc in zip(check_keep, check_scores):
+ score = scipy.stats.norm.logpdf(sc, mean_out, std_out+1e-30)
+
+ prediction.extend(score.mean(1))
+ answers.extend(ans)
+ return prediction, answers
+
+
+def generate_global(keep, scores, check_keep, check_scores):
+ """
+ Use a simple global threshold sweep to predict if the examples in
+ check_scores were training data or not, using the ground truth answer from
+ check_keep.
+ """
+ prediction = []
+ answers = []
+ for ans, sc in zip(check_keep, check_scores):
+ prediction.extend(-sc.mean(1))
+ answers.extend(ans)
+
+ return prediction, answers
+
+def do_plot(fn, keep, scores, ntest, legend='', metric='auc', sweep_fn=sweep, **plot_kwargs):
+ """
+ Generate the ROC curves by using ntest models as test models and the rest to train.
+ """
+
+ prediction, answers = fn(keep[:-ntest],
+ scores[:-ntest],
+ keep[-ntest:],
+ scores[-ntest:])
+
+ fpr, tpr, auc, acc = sweep_fn(np.array(prediction), np.array(answers, dtype=bool))
+
+ low = tpr[np.where(fpr<.001)[0][-1]]
+
+ print('Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f'%(legend, auc,acc, low))
+
+ metric_text = ''
+ if metric == 'auc':
+ metric_text = 'auc=%.3f'%auc
+ elif metric == 'acc':
+ metric_text = 'acc=%.3f'%acc
+
+ plt.plot(fpr, tpr, label=legend+metric_text, **plot_kwargs)
+ return (acc,auc)
+
+
+def fig_fpr_tpr():
+
+ plt.figure(figsize=(4,3))
+
+ do_plot(generate_ours,
+ keep, scores, 1,
+ "Ours (online)\n",
+ metric='auc'
+ )
+
+ do_plot(functools.partial(generate_ours, fix_variance=True),
+ keep, scores, 1,
+ "Ours (online, fixed variance)\n",
+ metric='auc'
+ )
+
+ do_plot(functools.partial(generate_ours_offline),
+ keep, scores, 1,
+ "Ours (offline)\n",
+ metric='auc'
+ )
+
+ do_plot(functools.partial(generate_ours_offline, fix_variance=True),
+ keep, scores, 1,
+ "Ours (offline, fixed variance)\n",
+ metric='auc'
+ )
+
+ do_plot(generate_global,
+ keep, scores, 1,
+ "Global threshold\n",
+ metric='auc'
+ )
+
+ plt.semilogx()
+ plt.semilogy()
+ plt.xlim(1e-5,1)
+ plt.ylim(1e-5,1)
+ plt.xlabel("False Positive Rate")
+ plt.ylabel("True Positive Rate")
+ plt.plot([0, 1], [0, 1], ls='--', color='gray')
+ plt.subplots_adjust(bottom=.18, left=.18, top=.96, right=.96)
+ plt.legend(fontsize=8)
+ plt.savefig("/tmp/fprtpr.png")
+ plt.show()
+
+
+if __name__ == '__main__':
+ load_data("exp/cifar10/")
+ fig_fpr_tpr()
\ No newline at end of file
diff --git a/requirements.sh b/requirements.sh
new file mode 100644
index 0000000..0a82664
--- /dev/null
+++ b/requirements.sh
@@ -0,0 +1,8 @@
+pip install --upgrade pip
+pip install scipy
+pip install --upgrade objax
+pip install scikit-learn
+pip install numpy
+pip install matplotlib
+pip install "tensorflow[and-cuda]"  # for GPU users
+# pip install tensorflow            # for CPU users
\ No newline at end of file
diff --git a/score.py b/score.py
new file mode 100644
index 0000000..d52326e
--- /dev/null
+++ b/score.py
@@ -0,0 +1,72 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import numpy as np
+import os
+import multiprocessing as mp
+
+
+def load_one(base):
+ """
+    Load the saved logits for one experiment and convert them to scored predictions.
+ """
+ root = os.path.join(logdir,base,'logits')
+ if not os.path.exists(root): return None
+
+ if not os.path.exists(os.path.join(logdir,base,'scores')):
+ os.mkdir(os.path.join(logdir,base,'scores'))
+
+ for f in os.listdir(root):
+ try:
+ opredictions = np.load(os.path.join(root,f))
+ except:
+ print("Fail")
+ continue
+
+ ## Be exceptionally careful.
+ ## Numerically stable everything, as described in the paper.
+ predictions = opredictions - np.max(opredictions, axis=3, keepdims=True)
+ predictions = np.array(np.exp(predictions), dtype=np.float64)
+ predictions = predictions/np.sum(predictions,axis=3,keepdims=True)
+
+ COUNT = predictions.shape[0]
+        # predictions has shape (num_examples, 1, num_augmentations, num_classes)
+ y_true = predictions[np.arange(COUNT),:,:,labels[:COUNT]]
+ print(y_true.shape)
+
+ print('mean acc',np.mean(predictions[:,0,0,:].argmax(1)==labels[:COUNT]))
+
+ max_confidence=np.max(predictions,axis=3)
+
+ predictions[np.arange(COUNT),:,:,labels[:COUNT]] = 0
+ y_wrong = np.sum(predictions, axis=3)
+
+        # Alternative scores explored in our experiments (see the results table in the README):
+        #logit = (np.log(y_true.mean((1))+1e-45) - np.log(y_wrong.mean((1))+1e-45))  # LiRA logit scaling (baseline)
+        #logit = (np.log(y_true.mean(1))+1e-45)       # log(confidence of the true class)
+        #logit = (y_true.mean(1))+1e-45               # confidence of the true class
+        #logit = (max_confidence.mean(1))+1e-45       # argmax (maximum confidence), label-free
+        logit = np.log(max_confidence.mean(1)) + 1e-45  # log(argmax), label-free; used by default
+
+ np.save(os.path.join(logdir, base, 'scores', f), logit)
+
+
+def load_stats():
+ with mp.Pool(8) as p:
+ p.map(load_one, [x for x in os.listdir(logdir) if 'exp' in x])
+
+
+logdir = sys.argv[1]
+labels = np.load(os.path.join(logdir,"y_train.npy"))
+load_stats()
\ No newline at end of file
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..016aad2
--- /dev/null
+++ b/train.py
@@ -0,0 +1,331 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: skip-file
+# pyformat: disable
+
+import functools
+import os
+import shutil
+from typing import Callable
+import json
+
+import jax
+import jax.numpy as jn
+import numpy as np
+import tensorflow as tf # For data augmentation.
+import tensorflow_datasets as tfds
+from absl import app, flags
+
+import objax
+from objax.jaxboard import SummaryWriter, Summary
+from objax.util import EasyDict
+from objax.zoo import convnet, wide_resnet
+
+from dataset import DataSet
+
+FLAGS = flags.FLAGS
+
+def augment(x, shift: int, mirror=True):
+ """
+ Augmentation function used in training the model.
+ """
+ y = x['image']
+ if mirror:
+ y = tf.image.random_flip_left_right(y)
+ y = tf.pad(y, [[shift] * 2, [shift] * 2, [0] * 2], mode='REFLECT')
+ y = tf.image.random_crop(y, tf.shape(x['image']))
+ return dict(image=y, label=x['label'])
+
+
+class TrainLoop(objax.Module):
+ """
+ Training loop for general machine learning models.
+ Based on the training loop from the objax CIFAR10 example code.
+ """
+ predict: Callable
+ train_op: Callable
+
+ def __init__(self, nclass: int, **kwargs):
+ self.nclass = nclass
+ self.params = EasyDict(kwargs)
+
+ def train_step(self, summary: Summary, data: dict, progress: np.ndarray):
+ kv = self.train_op(progress, data['image'].numpy(), data['label'].numpy())
+ for k, v in kv.items():
+ if jn.isnan(v):
+ raise ValueError('NaN, try reducing learning rate', k)
+ if summary is not None:
+ summary.scalar(k, float(v))
+
+ def train(self, num_train_epochs: int, train_size: int, train: DataSet, test: DataSet, logdir: str, save_steps=100, patience=None):
+ """
+ Completely standard training. Nothing interesting to see here.
+ """
+ checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=20, makedir=True)
+ start_epoch, last_ckpt = checkpoint.restore(self.vars())
+ train_iter = iter(train)
+ progress = np.zeros(jax.local_device_count(), 'f') # for multi-GPU
+
+ best_acc = 0
+ best_acc_epoch = -1
+
+ with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
+ for epoch in range(start_epoch, num_train_epochs):
+ # Train
+ summary = Summary()
+ loop = range(0, train_size, self.params.batch)
+ for step in loop:
+ progress[:] = (step + (epoch * train_size)) / (num_train_epochs * train_size)
+ self.train_step(summary, next(train_iter), progress)
+
+ # Eval
+ accuracy, total = 0, 0
+ if epoch%FLAGS.eval_steps == 0 and test is not None:
+ for data in test:
+ total += data['image'].shape[0]
+ preds = np.argmax(self.predict(data['image'].numpy()), axis=1)
+ accuracy += (preds == data['label'].numpy()).sum()
+ accuracy /= total
+ summary.scalar('eval/accuracy', 100 * accuracy)
+ tensorboard.write(summary, step=(epoch + 1) * train_size)
+ print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1, summary['losses/xe'](),
+ summary['eval/accuracy']()))
+
+ if summary['eval/accuracy']() > best_acc:
+ best_acc = summary['eval/accuracy']()
+ best_acc_epoch = epoch
+ elif patience is not None and epoch > best_acc_epoch + patience:
+ print("early stopping!")
+ checkpoint.save(self.vars(), epoch + 1)
+ return
+
+ else:
+ print('Epoch %04d Loss %.2f Accuracy --' % (epoch + 1, summary['losses/xe']()))
+
+ if epoch%save_steps == save_steps-1:
+ checkpoint.save(self.vars(), epoch + 1)
+
+
+# We inherit from the training loop and define predict and train_op.
+class MemModule(TrainLoop):
+ def __init__(self, model: Callable, nclass: int, mnist=False, **kwargs):
+ """
+ Completely standard training. Nothing interesting to see here.
+ """
+ super().__init__(nclass, **kwargs)
+ self.model = model(1 if mnist else 3, nclass)
+ self.opt = objax.optimizer.Momentum(self.model.vars())
+ self.model_ema = objax.optimizer.ExponentialMovingAverageModule(self.model, momentum=0.999, debias=True)
+
+ @objax.Function.with_vars(self.model.vars())
+ def loss(x, label):
+ logit = self.model(x, training=True)
+ loss_wd = 0.5 * sum((v.value ** 2).sum() for k, v in self.model.vars().items() if k.endswith('.w'))
+ loss_xe = objax.functional.loss.cross_entropy_logits(logit, label).mean()
+ return loss_xe + loss_wd * self.params.weight_decay, {'losses/xe': loss_xe, 'losses/wd': loss_wd}
+
+ gv = objax.GradValues(loss, self.model.vars())
+ self.gv = gv
+
+ @objax.Function.with_vars(self.vars())
+ def train_op(progress, x, y):
+ g, v = gv(x, y)
+ lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
+ lr = lr * jn.clip(progress*100,0,1)
+ self.opt(lr, g)
+ self.model_ema.update_ema()
+ return {'monitors/lr': lr, **v[1]}
+
+ self.predict = objax.Jit(objax.nn.Sequential([objax.ForceArgs(self.model_ema, training=False)]))
+
+ self.train_op = objax.Jit(train_op)
+
+
+def network(arch: str):
+ if arch == 'cnn32-3-max':
+ return functools.partial(convnet.ConvNet, scales=3, filters=32, filters_max=1024,
+ pooling=objax.functional.max_pool_2d)
+ elif arch == 'cnn32-3-mean':
+ return functools.partial(convnet.ConvNet, scales=3, filters=32, filters_max=1024,
+ pooling=objax.functional.average_pool_2d)
+ elif arch == 'cnn64-3-max':
+ return functools.partial(convnet.ConvNet, scales=3, filters=64, filters_max=1024,
+ pooling=objax.functional.max_pool_2d)
+ elif arch == 'cnn64-3-mean':
+ return functools.partial(convnet.ConvNet, scales=3, filters=64, filters_max=1024,
+ pooling=objax.functional.average_pool_2d)
+ elif arch == 'wrn28-1':
+ return functools.partial(wide_resnet.WideResNet, depth=28, width=1)
+ elif arch == 'wrn28-2':
+ return functools.partial(wide_resnet.WideResNet, depth=28, width=2)
+ elif arch == 'wrn28-10':
+ return functools.partial(wide_resnet.WideResNet, depth=28, width=10)
+ raise ValueError('Architecture not recognized', arch)
+
+def get_data(seed):
+ """
+ This is the function to generate subsets of the data for training models.
+
+ First, we get the training dataset either from the numpy cache
+ or otherwise we load it from tensorflow datasets.
+
+ Then, we compute the subset. This works in one of two ways.
+
+ 1. If we have a seed, then we just randomly choose examples based on
+ a prng with that seed, keeping FLAGS.pkeep fraction of the data.
+
+ 2. Otherwise, if we have an experiment ID, then we do something fancier.
+ If we run each experiment independently then even after a lot of trials
+ there will still probably be some examples that were always included
+ or always excluded. So instead, with experiment IDs, we guarantee that
+ after FLAGS.num_experiments are done, each example is seen exactly half
+ of the time in train, and half of the time not in train.
+
+ """
+ DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
+
+ if os.path.exists(os.path.join(FLAGS.logdir, "x_train.npy")):
+ inputs = np.load(os.path.join(FLAGS.logdir, "x_train.npy"))
+ labels = np.load(os.path.join(FLAGS.logdir, "y_train.npy"))
+ else:
+ print("First time, creating dataset")
+ data = tfds.as_numpy(tfds.load(name=FLAGS.dataset, batch_size=-1, data_dir=DATA_DIR))
+ inputs = data['train']['image']
+ labels = data['train']['label']
+
+ inputs = (inputs/127.5)-1
+ np.save(os.path.join(FLAGS.logdir, "x_train.npy"),inputs)
+ np.save(os.path.join(FLAGS.logdir, "y_train.npy"),labels)
+
+ nclass = np.max(labels)+1
+
+ np.random.seed(seed)
+ if FLAGS.num_experiments is not None:
+ np.random.seed(0)
+ keep = np.random.uniform(0,1,size=(FLAGS.num_experiments, FLAGS.dataset_size))
+ order = keep.argsort(0)
+ keep = order < int(FLAGS.pkeep * FLAGS.num_experiments)
+ keep = np.array(keep[FLAGS.expid], dtype=bool)
+ else:
+ keep = np.random.uniform(0, 1, size=FLAGS.dataset_size) <= FLAGS.pkeep
+
+ if FLAGS.only_subset is not None:
+ keep[FLAGS.only_subset:] = 0
+
+ xs = inputs[keep]
+ ys = labels[keep]
+
+ if FLAGS.augment == 'weak':
+ aug = lambda x: augment(x, 4)
+ elif FLAGS.augment == 'mirror':
+ aug = lambda x: augment(x, 0)
+ elif FLAGS.augment == 'none':
+ aug = lambda x: augment(x, 0, mirror=False)
+ else:
+ raise
+
+ train = DataSet.from_arrays(xs, ys,
+ augment_fn=aug)
+ test = DataSet.from_tfds(tfds.load(name=FLAGS.dataset, split='test', data_dir=DATA_DIR), xs.shape[1:])
+ train = train.cache().shuffle(8192).repeat().parse().augment().batch(FLAGS.batch)
+ train = train.nchw().one_hot(nclass).prefetch(16)
+ test = test.cache().parse().batch(FLAGS.batch).nchw().prefetch(16)
+
+ return train, test, xs, ys, keep, nclass
+
+def main(argv):
+ del argv
+ tf.config.experimental.set_visible_devices([], "GPU")
+
+ seed = FLAGS.seed
+ if seed is None:
+ import time
+ seed = np.random.randint(0, 1000000000)
+ seed ^= int(time.time())
+
+ args = EasyDict(arch=FLAGS.arch,
+ lr=FLAGS.lr,
+ batch=FLAGS.batch,
+ weight_decay=FLAGS.weight_decay,
+ augment=FLAGS.augment,
+ seed=seed)
+
+
+ if FLAGS.tunename:
+ logdir = '_'.join(sorted('%s=%s' % k for k in args.items()))
+ elif FLAGS.expid is not None:
+ logdir = "experiment-%d_%d"%(FLAGS.expid,FLAGS.num_experiments)
+ else:
+ logdir = "experiment-"+str(seed)
+ logdir = os.path.join(FLAGS.logdir, logdir)
+
+ if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%FLAGS.epochs)):
+ print(f"run {FLAGS.expid} already completed.")
+ return
+ else:
+ if os.path.exists(logdir):
+ print(f"deleting run {FLAGS.expid} that did not complete.")
+ shutil.rmtree(logdir)
+
+ print(f"starting run {FLAGS.expid}.")
+ if not os.path.exists(logdir):
+ os.makedirs(logdir)
+
+ train, test, xs, ys, keep, nclass = get_data(seed)
+
+ # Define the network and train_it
+ tm = MemModule(network(FLAGS.arch), nclass=nclass,
+ mnist=FLAGS.dataset == 'mnist',
+ epochs=FLAGS.epochs,
+ expid=FLAGS.expid,
+ num_experiments=FLAGS.num_experiments,
+ pkeep=FLAGS.pkeep,
+ save_steps=FLAGS.save_steps,
+ only_subset=FLAGS.only_subset,
+ **args
+ )
+
+ r = {}
+ r.update(tm.params)
+
+ open(os.path.join(logdir,'hparams.json'),"w").write(json.dumps(tm.params))
+ np.save(os.path.join(logdir,'keep.npy'), keep)
+
+ tm.train(FLAGS.epochs, len(xs), train, test, logdir,
+ save_steps=FLAGS.save_steps, patience=FLAGS.patience)
+
+
+
+if __name__ == '__main__':
+ flags.DEFINE_string('arch', 'cnn32-3-mean', 'Model architecture.')
+ flags.DEFINE_float('lr', 0.1, 'Learning rate.')
+ flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
+ flags.DEFINE_float('weight_decay', 0.0005, 'Weight decay ratio.')
+ flags.DEFINE_integer('batch', 256, 'Batch size')
+ flags.DEFINE_integer('epochs', 501, 'Training duration in number of epochs.')
+ flags.DEFINE_string('logdir', 'experiments', 'Directory where to save checkpoints and tensorboard data.')
+ flags.DEFINE_integer('seed', None, 'Training seed.')
+ flags.DEFINE_float('pkeep', .5, 'Probability to keep examples.')
+ flags.DEFINE_integer('expid', None, 'Experiment ID')
+ flags.DEFINE_integer('num_experiments', None, 'Number of experiments')
+ flags.DEFINE_string('augment', 'weak', 'Strong or weak augmentation')
+ flags.DEFINE_integer('only_subset', None, 'Only train on a subset of images.')
+ flags.DEFINE_integer('dataset_size', 50000, 'number of examples to keep.')
+ flags.DEFINE_integer('eval_steps', 1, 'how often to get eval accuracy.')
+    flags.DEFINE_integer('abort_after_epoch', None, 'stop training early at an epoch')
+    flags.DEFINE_integer('save_steps', 10, 'how often to save the model.')
+ flags.DEFINE_integer('patience', None, 'Early stopping after this many epochs without progress')
+ flags.DEFINE_bool('tunename', False, 'Use tune name?')
+ app.run(main)