From 5d42ab4e082c3d4f8862c71d2a121478de18900f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=BD=98=E4=BD=B3=E6=87=BF=2000102273?=
Date: Tue, 7 Sep 2021 15:35:05 +0800
Subject: [PATCH] support 1dcnn distilling

---
 README.md | 4 +
 doc/CNN-1D-tiny-Distillation.md | 113 ++++++++++++++++++
 examples/cnn1d_iscx_session_all_train.py | 36 ++++++
 .../cnn1d_tiny_iscx_session_all_distill.py | 37 ++++++
 examples/cnn1d_tiny_iscx_session_all_train.py | 36 ++++++
 src/model_optimizer/pruner/config_schema.json | 3 +-
 .../pruner/dataset/__init__.py | 5 +-
 src/model_optimizer/pruner/dataset/cifar10.py | 1 +
 .../pruner/dataset/dataset_base.py | 4 +-
 .../pruner/dataset/imagenet.py | 1 +
 .../pruner/dataset/iscx_session_all.py | 85 +++++++++++++
 src/model_optimizer/pruner/dataset/mnist.py | 1 +
 .../pruner/distill/distiller.py | 5 +-
 .../pruner/learner/__init__.py | 6 +
 .../pruner/learner/cnn1d_iscx_session_all.py | 68 +++++++++++
 .../pruner/learner/learner_base.py | 6 +-
 src/model_optimizer/pruner/models/__init__.py | 13 +-
 src/model_optimizer/pruner/models/cnn1d.py | 68 +++++++++++
 .../scheduler/distill/cnn1d_tiny_0.3.yaml | 6 +
 tests/test_model.py | 2 +-
 tools/convert_softmax_model_to_logits.py | 31 +++++
 21 files changed, 519 insertions(+), 12 deletions(-)
 create mode 100644 doc/CNN-1D-tiny-Distillation.md
 create mode 100644 examples/cnn1d_iscx_session_all_train.py
 create mode 100644 examples/cnn1d_tiny_iscx_session_all_distill.py
 create mode 100644 examples/cnn1d_tiny_iscx_session_all_train.py
 create mode 100644 src/model_optimizer/pruner/dataset/iscx_session_all.py
 create mode 100644 src/model_optimizer/pruner/learner/cnn1d_iscx_session_all.py
 create mode 100644 src/model_optimizer/pruner/models/cnn1d.py
 create mode 100644 src/model_optimizer/pruner/scheduler/distill/cnn1d_tiny_0.3.yaml
 create mode 100644 tools/convert_softmax_model_to_logits.py

diff --git a/README.md b/README.md
index ca8ffca..2634af9 100644
--- a/README.md
+++ b/README.md
@@ -65,6 +65,10 @@ The details are shown in the table below, and the code can refer to examples\res
 | + pruned + distill | 76.39 | 6954152 ( 72.8% pruned) | 1075M | 27M|
 | + pruned + distill + quantization(TF-Lite) | 75.938 | - | - | 7.1M|
 
+We also implement a 1D-CNN distillation example, which shows that distillation is also effective for encrypted
+traffic classification. You can find detailed instructions [here](doc/CNN-1D-tiny-Distillation.md). Following them,
+you can build your own dataset and model to train and distill with the Adlik model optimizer.
+
 ## 1. Pruning and quantization principle
 
 ### 1.1 Filter pruning
diff --git a/doc/CNN-1D-tiny-Distillation.md b/doc/CNN-1D-tiny-Distillation.md
new file mode 100644
index 0000000..60e3dfa
--- /dev/null
+++ b/doc/CNN-1D-tiny-Distillation.md
@@ -0,0 +1,113 @@
+# Tiny 1D-CNN Knowledge Distillation
+
+The following uses a 1D-CNN on the 12-class session-all dataset as the teacher model to illustrate how to use the
+model optimizer to improve the performance of a tiny 1D-CNN by knowledge distillation.
+
+The 1D-CNN model is from Wang's paper [Wang, W.; Zhu, M.; Wang, J.; Zeng, X.; Yang, Z. End-to-end encrypted traffic
+classification with one-dimensional convolution neural networks.]. The tiny 1D-CNN model is a slim version of the
+1D-CNN model mentioned above. Using the 1D-CNN model as the teacher to distill the tiny 1D-CNN model improves its
+accuracy by 5.66%.
+
+The details are shown in the table below, and the code can be found in examples/cnn1d_tiny_iscx_session_all_distill.py.
+
+| Model | Accuracy | Params | Model Size |
+| -------------------- | ------ | ------- | ---- |
+| cnn1d | 92.67% | 5832588 | 23M |
+| cnn1d_tiny | 87.62% | 134988 | 546K |
+| cnn1d_tiny + distill | 93.28% | 134988 | 546K |
+
+## 1 Create custom dataset
+Using the [ISCX dataset](https://www.unb.ca/cic/datasets/vpn.html), you can get the processed 12-classes-session-all
+dataset from [Wang's GitHub](https://github.com/echowei/DeepTraffic/blob/master/2.encrypted_traffic_classification/3.PerprocessResults/12class.zip).
+We name this dataset iscx_session_all. It contains 35501 training samples (an array of shape (35501, 28, 28)) and
+3945 test samples.
+
+Now that you have the dataset, you can implement your custom dataset by extending
+model_optimizer.pruner.dataset.dataset_base.DatasetBase and implementing the methods below (a condensed sketch of
+such a subclass is shown at the end of Section 4):
+
+1. \__init__, required, where you do all dataset initialization
+2. parse_fn, required, the map function of the dataset
+3. parse_fn_distill, required, the map function of the dataset used in distillation
+4. build, optional, the process of building the dataset. If your dataset is not in tfrecord format, you must
+implement this function.
+
+In this custom dataset, we reshape the samples from (None, 28, 28, 1) to (None, 1, 784, 1) for the following 1D-CNN
+models.
+
+After that, all you need to do is register the dataset name in the following files:
+1. src/model_optimizer/pruner/config_schema.json: add the dataset name to the "enum" list
+2. src/model_optimizer/pruner/dataset/\__init__.py: add the dataset name to the supported dataset list and add the
+dataset instance in the if-else clause.
+
+## 2 Create custom model
+Create your own model using the Keras functional API in model_optimizer.pruner.models.
+
+After that, all you need to do is register the model name and initialize the model in the following file:
+1. src/model_optimizer/pruner/models/\__init__.py: add the model name to the supported model list and add the model
+instance in the if-else clause.
+
+## 3 Create custom learner
+Implement your own learner by extending model_optimizer.pruner.learner.learner_base.LearnerBase and implementing:
+1. \__init__, required, where you can define your own learning rate callback
+2. get_optimizer, required, where you can define your own optimizer
+3. get_losses, required, where you can define your own loss function
+4. get_metrics, required, where you can define your own metrics
+
+After that, all you need to do is register the model name and dataset name and initialize the learner in the
+following file:
+1. src/model_optimizer/pruner/learner/\__init__.py
+
+## 4 Create the training process of the teacher model, and train it
+Enter the examples directory and create cnn1d_iscx_session_all_train.py for the cnn1d model.
+
+> Note
+>
+> > The "model_name" and "dataset" in the request must be the same as those you defined before.
+
+Execute:
+
+```shell
+cd examples
+python3 cnn1d_iscx_session_all_train.py
+```
+
+After execution, the default checkpoint files will be generated in ./models_ckpt/cnn1d, and the inference
+checkpoint file will be generated in ./models_eval_ckpt/cnn1d. You can also modify the checkpoint_path
+and checkpoint_eval_path in cnn1d_iscx_session_all_train.py to change the generated file paths.
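+
+To make the custom dataset in Section 1 concrete, here is a condensed, illustrative sketch of such a subclass. The
+class name MyDataset and the helper _load_numpy_arrays are hypothetical; the attributes and required methods mirror
+the iscx_session_all implementation added in this patch.
+
+```python
+import tensorflow as tf
+from .dataset_base import DatasetBase
+
+
+class MyDataset(DatasetBase):
+    def __init__(self, config, is_training, num_shards=1, shard_index=0):
+        super().__init__(config, is_training, num_shards, shard_index)
+        if not is_training:
+            self.batch_size = self.batch_size_val
+        self.buffer_size = 5000
+        self.num_samples_of_train = 35501
+        self.num_samples_of_val = 3945
+        self.data_shape = (1, 784, 1)  # used by the learner to build the model input
+
+    def parse_fn(self, *content):
+        # Map function for normal training: plain (data, label) pairs.
+        data, label = content
+        return data, label
+
+    def parse_fn_distill(self, *content):
+        # Map function for distillation: the distiller expects named inputs.
+        data, label = self.parse_fn(*content)
+        return {"image": data, "label": label}, {}
+
+    def build(self, is_distill=False):
+        # Required because this data is not stored in tfrecord format.
+        # Sharding for multi-worker training is omitted for brevity.
+        x_data, y_data = self._load_numpy_arrays()  # hypothetical loading helper
+        dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data))
+        if self.is_training:
+            dataset = dataset.shuffle(self.buffer_size).repeat()
+        map_fn = self.parse_fn_distill if is_distill else self.parse_fn
+        dataset = dataset.map(map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+        return self.build_batch(dataset)
+```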
+
+## 5 Convert the teacher model to logits output
+Enter the tools directory and execute:
+
+```shell
+cd tools
+python3 convert_softmax_model_to_logits.py
+```
+
+After execution, the default checkpoint file of the logits model will be generated at
+examples/models_eval_ckpt/cnn1d/checkpoint-60-logits.h5.
+
+## 6 Create the distilling process and distill the cnn1d_tiny model
+Create the configuration file in src/model_optimizer/pruner/scheduler/distill, like "cnn1d_tiny_0.3.yaml", where the
+distillation parameters are configured.
+
+Enter the examples directory and create cnn1d_tiny_iscx_session_all_distill.py for the cnn1d_tiny model. In the
+distilling process, the teacher is cnn1d and the student is cnn1d_tiny.
+
+> Note
+>
+> > The "model_name" and "dataset" in the request must be the same as those you defined before.
+
+```shell
+python3 cnn1d_tiny_iscx_session_all_distill.py
+```
+
+After execution, the default checkpoint file will be generated in ./models_ckpt/cnn1d_tiny_distill, and the inference
+checkpoint file will be generated in ./models_eval_ckpt/cnn1d_tiny_distill. You can also modify the checkpoint_path
+and checkpoint_eval_path in cnn1d_tiny_iscx_session_all_distill.py to change the generated file paths.
+
+> Note
+>
+> > i. The model in the checkpoint_path is not the pure cnn1d_tiny model. It is a hybrid of cnn1d_tiny (student) and
+> > cnn1d (teacher).
+> >
+> > ii. The model in the checkpoint_eval_path is the distilled model, i.e. the pure cnn1d_tiny model.
diff --git a/examples/cnn1d_iscx_session_all_train.py b/examples/cnn1d_iscx_session_all_train.py
new file mode 100644
index 0000000..e6ae8dc
--- /dev/null
+++ b/examples/cnn1d_iscx_session_all_train.py
@@ -0,0 +1,36 @@
+# Copyright 2019 ZTE corporation. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Train a cnn1d model on the iscx_session_all dataset
+"""
+import os
+# If you did not execute the setup.py, uncomment the following four lines
+# import sys
+# from os.path import abspath, join, dirname
+# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
+# print(sys.path)
+
+from model_optimizer import prune_model  # noqa: E402
+
+
+def _main():
+    base_dir = os.path.dirname(__file__)
+    request = {
+        "dataset": "iscx_session_all",
+        "model_name": "cnn1d",
+        "data_dir": "/data/12class/SessionAllLayers",
+        "batch_size": 500,
+        "batch_size_val": 100,
+        "learning_rate": 1e-3,
+        "epochs": 60,
+        "checkpoint_path": os.path.join(base_dir, "./models_ckpt/cnn1d"),
+        "checkpoint_save_period": 1,  # save a checkpoint every epoch
+        "checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/cnn1d"),
+        "scheduler": "train"
+    }
+    prune_model(request)
+
+
+if __name__ == "__main__":
+    _main()
diff --git a/examples/cnn1d_tiny_iscx_session_all_distill.py b/examples/cnn1d_tiny_iscx_session_all_distill.py
new file mode 100644
index 0000000..553aee3
--- /dev/null
+++ b/examples/cnn1d_tiny_iscx_session_all_distill.py
@@ -0,0 +1,37 @@
+# Copyright 2019 ZTE corporation. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Distill a cnn1d_tiny model from a trained cnn1d model on the iscx_session_all dataset
+"""
+import os
+# If you did not execute the setup.py, uncomment the following four lines
+# import sys
+# from os.path import abspath, join, dirname
+# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
+# print(sys.path)
+
+from model_optimizer import prune_model  # noqa: E402
+
+
+def _main():
+    base_dir = os.path.dirname(__file__)
+    request = {
+        "dataset": "iscx_session_all",
+        "model_name": "cnn1d_tiny",
+        "data_dir": "/data/12class/SessionAllLayers",
+        "batch_size": 500,
+        "batch_size_val": 100,
+        "learning_rate": 1e-3,
+        "epochs": 200,
+        "checkpoint_path": os.path.join(base_dir, "./models_ckpt/cnn1d_tiny_distill"),
+        "checkpoint_save_period": 10,  # save a checkpoint every 10 epochs
+        "checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/cnn1d_tiny_distill"),
+        "scheduler": "distill",
+        "scheduler_file_name": "cnn1d_tiny_0.3.yaml"
+    }
+    prune_model(request)
+
+
+if __name__ == "__main__":
+    _main()
diff --git a/examples/cnn1d_tiny_iscx_session_all_train.py b/examples/cnn1d_tiny_iscx_session_all_train.py
new file mode 100644
index 0000000..9b8d6e8
--- /dev/null
+++ b/examples/cnn1d_tiny_iscx_session_all_train.py
@@ -0,0 +1,36 @@
+# Copyright 2019 ZTE corporation. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Train a cnn1d_tiny model on the iscx_session_all dataset
+"""
+import os
+# If you did not execute the setup.py, uncomment the following four lines
+# import sys
+# from os.path import abspath, join, dirname
+# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
+# print(sys.path)
+
+from model_optimizer import prune_model  # noqa: E402
+
+
+def _main():
+    base_dir = os.path.dirname(__file__)
+    request = {
+        "dataset": "iscx_session_all",
+        "model_name": "cnn1d_tiny",
+        "data_dir": "/data/12class/SessionAllLayers",
+        "batch_size": 500,
+        "batch_size_val": 100,
+        "learning_rate": 1e-3,
+        "epochs": 60,
+        "checkpoint_path": os.path.join(base_dir, "./models_ckpt/cnn1d_tiny"),
+        "checkpoint_save_period": 10,  # save a checkpoint every 10 epochs
+        "checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/cnn1d_tiny"),
+        "scheduler": "train"
+    }
+    prune_model(request)
+
+
+if __name__ == "__main__":
+    _main()
diff --git a/src/model_optimizer/pruner/config_schema.json b/src/model_optimizer/pruner/config_schema.json
index e9247d9..e8e39d1 100644
--- a/src/model_optimizer/pruner/config_schema.json
+++ b/src/model_optimizer/pruner/config_schema.json
@@ -6,7 +6,8 @@
       "enum": [
         "mnist",
         "cifar10",
-        "imagenet"
+        "imagenet",
+        "iscx_session_all"
       ],
       "description": "dataset name"
     },
diff --git a/src/model_optimizer/pruner/dataset/__init__.py b/src/model_optimizer/pruner/dataset/__init__.py
index 3fdd8a7..ca0025a 100644
--- a/src/model_optimizer/pruner/dataset/__init__.py
+++ b/src/model_optimizer/pruner/dataset/__init__.py
@@ -16,7 +16,7 @@ def get_dataset(config, is_training, num_shards=1, shard_index=0):
     :return: class of Dataset
     """
     dataset_name = config.get_attribute('dataset')
-    if dataset_name not in ['mnist', 'cifar10', 'imagenet']:
+    if dataset_name not in ['mnist', 'cifar10', 'imagenet', 'iscx_session_all']:
         raise Exception('Not support dataset %s' % dataset_name)
     if dataset_name == 'mnist':
         from .mnist import MnistDataset
@@ -27,5 +27,8 @@
     elif dataset_name == 'imagenet':
         from .imagenet import ImagenetDataset
         return ImagenetDataset(config, is_training, num_shards, shard_index)
+    elif dataset_name == 'iscx_session_all':
+        from .iscx_session_all import ISCXDataset
+        return ISCXDataset(config, is_training, num_shards, shard_index)
     else:
         raise Exception('Not support dataset {}'.format(dataset_name))
diff --git a/src/model_optimizer/pruner/dataset/cifar10.py b/src/model_optimizer/pruner/dataset/cifar10.py
index 556c48a..5731136 100644
--- a/src/model_optimizer/pruner/dataset/cifar10.py
+++ b/src/model_optimizer/pruner/dataset/cifar10.py
@@ -31,6 +31,7 @@ def __init__(self, config, is_training):
         self.buffer_size = 10000
         self.num_samples_of_train = 50000
         self.num_samples_of_val = 10000
+        self.data_shape = (32, 32, 3)
 
     # pylint: disable=no-value-for-parameter,unexpected-keyword-arg
     def parse_fn(self, example_serialized):
diff --git a/src/model_optimizer/pruner/dataset/dataset_base.py b/src/model_optimizer/pruner/dataset/dataset_base.py
index a05dff4..d21cf7e 100644
--- a/src/model_optimizer/pruner/dataset/dataset_base.py
+++ b/src/model_optimizer/pruner/dataset/dataset_base.py
@@ -78,9 +78,9 @@
             dataset = dataset.map(self.parse_fn_distill, num_parallel_calls=tf.data.experimental.AUTOTUNE)
         else:
             dataset = dataset.map(self.parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-        return self.__build_batch(dataset)
+        return self.build_batch(dataset)
 
-    def __build_batch(self, dataset):
+    def build_batch(self, dataset):
         """
         Make an batch from tf.data.Dataset.
         :param dataset: tf.data.Dataset object
diff --git a/src/model_optimizer/pruner/dataset/imagenet.py b/src/model_optimizer/pruner/dataset/imagenet.py
index 637070d..226da49 100644
--- a/src/model_optimizer/pruner/dataset/imagenet.py
+++ b/src/model_optimizer/pruner/dataset/imagenet.py
@@ -32,6 +32,7 @@ def __init__(self, config, is_training, num_shards=1, shard_index=0):
         self.buffer_size = 10000
         self.num_samples_of_train = 1281167
         self.num_samples_of_val = 50000
+        self.data_shape = (224, 224, 3)
 
     # pylint: disable=no-value-for-parameter,unexpected-keyword-arg
     def parse_fn(self, example_serialized):
diff --git a/src/model_optimizer/pruner/dataset/iscx_session_all.py b/src/model_optimizer/pruner/dataset/iscx_session_all.py
new file mode 100644
index 0000000..6b5a4fc
--- /dev/null
+++ b/src/model_optimizer/pruner/dataset/iscx_session_all.py
@@ -0,0 +1,85 @@
+# Copyright 2019 ZTE corporation. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+ISCX 12-class session-all dataset
+https://github.com/echowei/DeepTraffic/blob/master/2.encrypted_traffic_classification/3.PerprocessResults/12class.zip
+"""
+import os
+import gzip
+import tensorflow as tf
+import numpy as np
+from .dataset_base import DatasetBase
+
+
+class ISCXDataset(DatasetBase):
+    """
+    ISCX session all layer dataset
+    """
+    def __init__(self, config, is_training, num_shards=1, shard_index=0):
+        """
+        Constructor function.
+        :param config: Config object
+        :param is_training: whether to construct the training subset
+        :return:
+        """
+        super().__init__(config, is_training, num_shards, shard_index)
+        if not is_training:
+            self.batch_size = self.batch_size_val
+        self.buffer_size = 5000
+        self.num_samples_of_train = 35501
+        self.num_samples_of_val = 3945
+        self.data_shape = (1, 784, 1)
+
+    # pylint: disable=R0201
+    # pylint: disable=no-value-for-parameter,unexpected-keyword-arg
+    def parse_fn(self, *content):
+        """
+        Parse a dataset item into a (data, label) pair
+        :param content: item content of the dataset
+        :return: data, label
+        """
+        data, label = content
+        return data, label
+
+    def parse_fn_distill(self, *content):
+        """
+        Parse dataset for distillation
+        :param content: item content of the dataset
+        :return: ({image, label}, {})
+        """
+        image, label = self.parse_fn(*content)
+        inputs = {"image": image, "label": label}
+        targets = {}
+        return inputs, targets
+
+    def build(self, is_distill=False):
+        """
+        Build dataset
+        :param is_distill: is distilling or not
+        :return: batch of a dataset
+        """
+        if self.is_training:
+            x_path = os.path.join(self.data_dir, 'train-images-idx3-ubyte.gz')
+            y_path = os.path.join(self.data_dir, 'train-labels-idx1-ubyte.gz')
+        else:
+            x_path = os.path.join(self.data_dir, 't10k-images-idx3-ubyte.gz')
+            y_path = os.path.join(self.data_dir, 't10k-labels-idx1-ubyte.gz')
+
+        with gzip.open(y_path, 'rb') as lbpath:
+            y_data = np.frombuffer(lbpath.read(), np.uint8, offset=8)
+
+        with gzip.open(x_path, 'rb') as imgpath:
+            x_data = np.frombuffer(
+                imgpath.read(), np.uint8, offset=16).reshape(len(y_data), 1, 784)
+
+        dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data))
+
+        if self.num_shards != 1:
+            dataset = dataset.shard(num_shards=self.num_shards, index=self.shard_index)
+        if self.is_training:
+            dataset = dataset.shuffle(buffer_size=self.buffer_size).repeat()
+        if is_distill:
+            dataset = dataset.map(self.parse_fn_distill, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+        else:
+            dataset = dataset.map(self.parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+        return self.build_batch(dataset)
diff --git a/src/model_optimizer/pruner/dataset/mnist.py b/src/model_optimizer/pruner/dataset/mnist.py
index d251c1f..435d4ef 100644
--- a/src/model_optimizer/pruner/dataset/mnist.py
+++ b/src/model_optimizer/pruner/dataset/mnist.py
@@ -31,6 +31,7 @@ def __init__(self, config, is_training):
         self.buffer_size = 10000
         self.num_samples_of_train = 60000
         self.num_samples_of_val = 10000
+        self.data_shape = (28, 28, 1)
 
     # pylint: disable=R0201
     # pylint: disable=no-value-for-parameter,unexpected-keyword-arg
diff --git a/src/model_optimizer/pruner/distill/distiller.py b/src/model_optimizer/pruner/distill/distiller.py
index e969539..fd89aa6 100644
--- a/src/model_optimizer/pruner/distill/distiller.py
+++ b/src/model_optimizer/pruner/distill/distiller.py
@@ -9,11 +9,12 @@
 from .distill_loss import DistillLossLayer
 
 
-def get_distiller(student_model, scheduler_config, teacher_model_load_func=None):
+def get_distiller(student_model, scheduler_config, input_shape, teacher_model_load_func=None):
     """
     Get distiller model
     :param student_model: student model function
     :param scheduler_config: scheduler config object
+    :param input_shape: the input shape of the data
     :param teacher_model_load_func: func to load teacher model
     :return: keras model of distiller
     """
@@ -22,7 +23,7 @@
     if "model_load_func" in scheduler_config['distill']:
         teacher_model_load_func = scheduler_config['distill']["model_load_func"]
 
-    input_img = tf.keras.layers.Input(shape=(224, 224, 3), name='image')
+    input_img = tf.keras.layers.Input(shape=input_shape, name='image')
     input_lbl = tf.keras.layers.Input((), name="label", dtype='int32')
     student = student_model
     logits = student(input_img)
diff --git a/src/model_optimizer/pruner/learner/__init__.py b/src/model_optimizer/pruner/learner/__init__.py
index d0f542b..a795d2e 100644
--- a/src/model_optimizer/pruner/learner/__init__.py
+++ b/src/model_optimizer/pruner/learner/__init__.py
@@ -32,5 +32,11 @@ def get_learner(config):
     elif model_name == 'vgg_m_16' and dataset_name == 'cifar10':
         from .vgg_m_16_cifar10 import Learner
         return Learner(config)
+    elif model_name == 'cnn1d' and dataset_name == 'iscx_session_all':
+        from .cnn1d_iscx_session_all import Learner
+        return Learner(config)
+    elif model_name == 'cnn1d_tiny' and dataset_name == 'iscx_session_all':
+        from .cnn1d_iscx_session_all import Learner
+        return Learner(config)
     else:
         raise Exception('Not support learner: {}_{}'.format(model_name, dataset_name))
diff --git a/src/model_optimizer/pruner/learner/cnn1d_iscx_session_all.py b/src/model_optimizer/pruner/learner/cnn1d_iscx_session_all.py
new file mode 100644
index 0000000..e781061
--- /dev/null
+++ b/src/model_optimizer/pruner/learner/cnn1d_iscx_session_all.py
@@ -0,0 +1,68 @@
+# Copyright 2019 ZTE corporation. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+CNN1d on iscx_session_all learner definition
+"""
+import os
+import tensorflow as tf
+import horovod.tensorflow.keras as hvd
+from .learner_base import LearnerBase
+from .utils import cosine_multiplier
+
+
+class Learner(LearnerBase):
+    """
+    CNN1d on iscx_session_all learner
+    """
+    def __init__(self, config):
+        super().__init__(config)
+        self.callbacks = [
+            # Horovod: broadcast initial variable states from rank 0 to all other processes.
+            # This is necessary to ensure consistent initialization of all workers when
+            # training is started with random weights or restored from a checkpoint.
+            hvd.callbacks.BroadcastGlobalVariablesCallback(0),
+            # Horovod: average metrics among workers at the end of every epoch.
+            #
+            # Note: This callback must be in the list before the ReduceLROnPlateau,
+            # TensorBoard or other metrics-based callbacks.
+            hvd.callbacks.MetricAverageCallback(),
+            hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=5, verbose=0),
+            hvd.callbacks.LearningRateScheduleCallback(start_epoch=5, end_epoch=50, multiplier=1.),
+            hvd.callbacks.LearningRateScheduleCallback(
+                start_epoch=50, multiplier=lambda epoch: cosine_multiplier(epoch, total_epoch=self.epochs))
+        ]
+        # Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
+        if hvd.rank() == 0:
+            self.callbacks.append(tf.keras.callbacks.ModelCheckpoint(os.path.join(self.checkpoint_path,
+                                                                                  './checkpoint-{epoch}.h5'),
+                                                                     period=self.checkpoint_save_period))
+
+    def get_optimizer(self):
+        """
+        Model compile optimizer
+        :return: Return model compile optimizer
+        """
+        opt = tf.keras.optimizers.SGD(self.learning_rate * hvd.size(), momentum=0.9)
+        opt = hvd.DistributedOptimizer(opt)
+        return opt
+
+    def get_losses(self, is_training=True):
+        """
+        Model compile losses
+        :param is_training: is training or not
+        :return: Return model compile losses
+        """
+        return 'sparse_categorical_crossentropy'
+
+    def get_metrics(self, is_training=True):
+        """
+        Model compile metrics
+        :param is_training: is training or not
+        :return: Return model compile metrics
+        """
+        if (self.config.get_attribute('scheduler') == 'distill' or self.config.get_attribute('is_distill', False)) \
+                and is_training:
+            return None
+        return ['sparse_categorical_accuracy']
diff --git a/src/model_optimizer/pruner/learner/learner_base.py b/src/model_optimizer/pruner/learner/learner_base.py
index dca5c13..de5bfc3 100644
--- a/src/model_optimizer/pruner/learner/learner_base.py
+++ b/src/model_optimizer/pruner/learner/learner_base.py
@@ -41,8 +41,10 @@ def __init__(self, config):
         if gpus:
             tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
         self.verbose = 1 if hvd.rank() == 0 else 0
-        origin_train_model = get_model(config, is_training=True)
-        origin_eval_model = get_model(config, is_training=False)
+        input_shape = get_dataset(self.config, is_training=True, num_shards=hvd.size(), shard_index=hvd.rank())\
+            .data_shape
+        origin_train_model = get_model(config, input_shape, is_training=True)
+        origin_eval_model = get_model(config, input_shape, is_training=False)
         self.models_train.append(origin_train_model)
         self.models_eval.append(origin_eval_model)
         train_model = tf.keras.models.clone_model(origin_train_model)
diff --git a/src/model_optimizer/pruner/models/__init__.py b/src/model_optimizer/pruner/models/__init__.py
index ccc2f24..98d83dc 100644
--- a/src/model_optimizer/pruner/models/__init__.py
+++ b/src/model_optimizer/pruner/models/__init__.py
@@ -9,7 +9,8 @@
 # pylint: disable=too-many-return-statements
-def get_model(config, is_training=True):
+def get_model(config, input_shape, is_training=True):
     """
     Get model
     :param config: Config object
+    :param input_shape: the input shape of the data
@@ -19,8 +19,9 @@
     model_name = config.get_attribute('model_name')
     scheduler_config = get_scheduler(config)
     if model_name not in ['lenet', 'resnet_18', 'vgg_m_16', 'resnet_50', 'resnet_101',
-                          'mobilenet_v1', 'mobilenet_v2']:
+                          'mobilenet_v1', 'mobilenet_v2', 'cnn1d', 'cnn1d_tiny']:
         raise Exception('Not support model %s' % model_name)
+
     if (config.get_attribute('scheduler') == 'distill' or config.get_attribute('is_distill')) and is_training:
         classifier_activation = None
     else:
@@ -48,8 +49,14 @@
         from .mobilenet_v2 import mobilenet_v2_1
         student_model = mobilenet_v2_1(is_training=is_training, name=model_name,
                                        classifier_activation=classifier_activation)
+    elif model_name == 'cnn1d':
+        from .cnn1d import cnn1d
+        student_model = cnn1d(is_training=is_training, name=model_name, classifier_activation=classifier_activation)
+    elif model_name == 'cnn1d_tiny':
+        from .cnn1d import cnn1d_tiny
+        student_model = cnn1d_tiny(is_training=is_training, name=model_name,
+                                   classifier_activation=classifier_activation)
     if (config.get_attribute('scheduler') == 'distill' or
            config.get_attribute('is_distill')) and is_training:
-        distill_model = get_distiller(student_model, scheduler_config)
+        distill_model = get_distiller(student_model, scheduler_config, input_shape)
     else:
         distill_model = student_model
     return distill_model
diff --git a/src/model_optimizer/pruner/models/cnn1d.py b/src/model_optimizer/pruner/models/cnn1d.py
new file mode 100644
index 0000000..470685d
--- /dev/null
+++ b/src/model_optimizer/pruner/models/cnn1d.py
@@ -0,0 +1,68 @@
+"""
+1D-CNN model
+"""
+import tensorflow as tf
+
+
+def cnn1d(is_training=True, name='cnn1d', classifier_activation='softmax'):
+    """
+    This implements the 1D-CNN by Wei Wang
+    [Wang, W.; Zhu, M.; Wang, J.; Zeng, X.; Yang, Z. End-to-end encrypted traffic classification with one-dimensional
+    convolution neural networks.]
+    :param is_training: if training or not
+    :param name: the model name
+    :param classifier_activation: classifier_activation can only be None or "softmax"
+    :return: cnn1d model
+    """
+    input_ = tf.keras.layers.Input(shape=(1, 784, 1), name='input')
+    x = tf.keras.layers.Conv2D(filters=32,
+                               kernel_size=(1, 25),
+                               padding='same',
+                               activation='relu',
+                               name='conv2d_1')(input_)
+    x = tf.keras.layers.MaxPool2D(pool_size=(1, 3), strides=(1, 3), padding='same', name='pool_1')(x)
+    x = tf.keras.layers.Conv2D(filters=64,
+                               kernel_size=(1, 25),
+                               padding='same',
+                               activation='relu',
+                               name='conv2d_2')(x)
+    x = tf.keras.layers.MaxPool2D(pool_size=(1, 3), strides=(1, 3), padding='same', name='pool_2')(x)
+    x = tf.keras.layers.Flatten(name='flatten')(x)
+    x = tf.keras.layers.Dense(1024, activation='relu',
+                              name='dense_1')(x)
+    if is_training:
+        x = tf.keras.layers.Dropout(0.5, name='dropout')(x)
+    if classifier_activation == 'softmax':
+        output_ = tf.keras.layers.Dense(12, activation='softmax', name='dense_2')(x)
+    else:
+        output_ = tf.keras.layers.Dense(12, activation=None, name='dense_2')(x)
+    model = tf.keras.Model(input_, output_, name=name)
+    return model
+
+
+def cnn1d_tiny(is_training=True, name='cnn1d_tiny', classifier_activation='softmax'):
+    """
+    This implements a tiny 1D-CNN which is much smaller than the 1D-CNN by Wei Wang
+    :param is_training: if training or not
+    :param name: the model name
+    :param classifier_activation: classifier_activation can only be None or "softmax"
+    :return: cnn1d_tiny model
+    """
+    input_ = tf.keras.layers.Input(shape=(1, 784, 1), name='input')
+    x = tf.keras.layers.Conv2D(filters=16,
+                               kernel_size=(1, 25),
+                               padding='same',
+                               activation='relu',
+                               name='conv2d_1')(input_)
+    x = tf.keras.layers.MaxPool2D(pool_size=(1, 3), strides=(1, 3), padding='same', name='pool_1')(x)
+    x = tf.keras.layers.Flatten(name='flatten')(x)
+    x = tf.keras.layers.Dense(32, activation='relu',
+                              name='dense_1')(x)
+    if is_training:
+        x = tf.keras.layers.Dropout(0.1, name='dropout')(x)  # a rate of 0.1 or 0.2 gives about 86% accuracy
+    if classifier_activation == 'softmax':
+        output_ = tf.keras.layers.Dense(12, activation='softmax', name='dense_2')(x)
+    else:
+        output_ = tf.keras.layers.Dense(12, activation=None, name='dense_2')(x)
+    model = tf.keras.Model(input_, output_, name=name)
+    return model
diff --git a/src/model_optimizer/pruner/scheduler/distill/cnn1d_tiny_0.3.yaml b/src/model_optimizer/pruner/scheduler/distill/cnn1d_tiny_0.3.yaml
new file mode 100644
index 0000000..555c106
--- /dev/null
+++ b/src/model_optimizer/pruner/scheduler/distill/cnn1d_tiny_0.3.yaml
@@ -0,0 +1,6 @@
+version: 1
+distill:
+  alpha: 0.3
+  temperature: 5
+  student_name: "cnn1d_tiny"
"/root/work/adlik_public_traffic/model_optimizer/examples/models_eval_ckpt/cnn1d/checkpoint-60-logits.h5" \ No newline at end of file diff --git a/tests/test_model.py b/tests/test_model.py index 62f7dc8..ff44cd1 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -29,7 +29,7 @@ def test_get_model_distill(): } config = prune_conf_from_obj(request) - train_model = get_model(config, is_training=True) + train_model = get_model(config, (244, 244, 3), is_training=True) for layer in train_model.layers: if layer.name == "DistillLoss": assert False diff --git a/tools/convert_softmax_model_to_logits.py b/tools/convert_softmax_model_to_logits.py new file mode 100644 index 0000000..906f1a4 --- /dev/null +++ b/tools/convert_softmax_model_to_logits.py @@ -0,0 +1,31 @@ +# Copyright 2019 ZTE corporation. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Convert softmax model to logits model used for predict +""" + +import argparse +import os +import tensorflow as tf + +if __name__ == '__main__': + base_dir = os.path.dirname(__file__) + parser = argparse.ArgumentParser() + + parser.add_argument( + '--scr_path', + default=os.path.join(base_dir, "../examples/models_eval_ckpt/cnn1d", "checkpoint-60.h5"), + help='path of the model whose output is softmax') + parser.add_argument( + '--dest_path', + default=os.path.join(base_dir, "../examples/models_eval_ckpt/cnn1d", "checkpoint-60-logits.h5"), + help='path of the model whose output is logits') + + args = parser.parse_args() + loaded_model = tf.keras.models.load_model(args.scr_path) + loaded_model_json = loaded_model.to_json() + model_json = loaded_model_json.replace("softmax", "linear") + model_logits = tf.keras.models.model_from_json(model_json) + model_logits.set_weights(loaded_model.get_weights()) + model_logits.save(args.dest_path)