support 1dcnn distilling #25

Closed · wants to merge 1 commit
4 changes: 4 additions & 0 deletions README.md
@@ -65,6 +65,10 @@ The details are shown in the table below, and the code can refer to examples\res
| + pruned + distill | 76.39 | 6954152 ( 72.8% pruned) | 1075M | 27M|
| + pruned + distill + quantization(TF-Lite) | 75.938 | - | - | 7.1M|

We also implement 1D-CNN knowledge distillation, which shows that distillation is also effective for encrypted traffic
classification. You can find detailed instructions [here](doc/CNN-1D-tiny-Distillation.md). Following them, you can
build your own dataset and model to train and distill under the Adlik model optimizer.

## 1. Pruning and quantization principle

### 1.1 Filter pruning
113 changes: 113 additions & 0 deletions doc/CNN-1D-tiny-Distillation.md
@@ -0,0 +1,113 @@
# Tiny 1D-CNN Knowledge Distillation

The following uses a 1D-CNN trained on the 12-class session-all dataset as the teacher model to illustrate how to use
the model optimizer to improve the performance of a tiny 1D-CNN by knowledge distillation.

The 1D-CNN model is from Wang's paper [Wang, W.; Zhu, M.; Wang, J.; Zeng, X.; Yang, Z. End-to-end encrypted traffic
classification with one-dimensional convolution neural networks]. The tiny 1D-CNN model is a slim version of that
1D-CNN model. Using the 1D-CNN model as the teacher to distill the tiny 1D-CNN model improves accuracy by 5.66%.

The details are shown in the table below; for the code, see examples/cnn1d_tiny_iscx_session_all_distill.py.

| Model | Accuracy | Params | Model Size |
| -------------------- | -------- | ------- | ---------- |
| cnn1d | 92.67% | 5832588 | 23M |
| cnn1d_tiny | 87.62% | 134988 | 546K |
| cnn1d_tiny + distill | 93.28% | 134988 | 546K |
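
For reference, knowledge distillation trains the student on a blend of the hard labels and the teacher's softened
predictions. The sketch below shows the classic formulation (Hinton et al.); it illustrates the general technique and
is not necessarily the exact loss the model optimizer implements.

```python
import tensorflow as tf


def distillation_loss(labels, student_logits, teacher_logits,
                      alpha=0.5, temperature=1.0):
    # Soft loss: match the teacher's temperature-softened distribution.
    soft_targets = tf.nn.softmax(teacher_logits / temperature)
    soft_loss = tf.keras.losses.categorical_crossentropy(
        soft_targets, tf.nn.softmax(student_logits / temperature))
    # Hard loss: ordinary cross-entropy against the ground-truth labels.
    hard_loss = tf.keras.losses.sparse_categorical_crossentropy(
        labels, tf.nn.softmax(student_logits))
    # The T^2 factor keeps gradient magnitudes comparable across temperatures.
    return alpha * temperature ** 2 * soft_loss + (1.0 - alpha) * hard_loss
```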


## Create custom dataset
Using the [ISCX dataset](https://www.unb.ca/cic/datasets/vpn.html), you can get the processed 12-class session-all dataset
from [Wang's GitHub](https://github.com/echowei/DeepTraffic/blob/master/2.encrypted_traffic_classification/3.PerprocessResults/12class.zip).
We name this dataset iscx_session_all. It contains 35501 training samples of shape (35501, 28, 28) and
3945 test samples.

Now that you have the dataset, you can implement your custom dataset by extending
model_optimizer.pruner.dataset.dataset_base.DatasetBase and implementing:

1. \__init__, required, where you do all dataset initialization
2. parse_fn, required, the map function of the dataset
3. parse_fn_distill, required, the map function of the dataset used in distillation
4. build, optional, the process of building the dataset. If your dataset is not in tfrecord format, you must
implement this function.

Here in the custom dataset, we reshape the samples from (None, 28, 28, 1) to (None, 1, 784, 1) for the following 1D-CNN
models. A minimal skeleton is sketched below.
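
The skeleton is modeled on the ISCXDataset class added in this pull request; names like MyDataset and load_arrays are
hypothetical placeholders.

```python
import tensorflow as tf
from .dataset_base import DatasetBase


class MyDataset(DatasetBase):
    """Custom dataset skeleton."""
    def __init__(self, config, is_training, num_shards=1, shard_index=0):
        super().__init__(config, is_training, num_shards, shard_index)
        if not is_training:
            self.batch_size = self.batch_size_val
        self.buffer_size = 5000
        self.num_samples_of_train = 35501
        self.num_samples_of_val = 3945
        self.data_shape = (1, 784, 1)  # (28, 28) samples flattened for the 1D-CNN

    def parse_fn(self, *content):
        data, label = content
        return data, label

    def parse_fn_distill(self, *content):
        # Wrap each sample so the distillation scheduler can feed the
        # teacher and the student from the same inputs.
        image, label = self.parse_fn(*content)
        return {"image": image, "label": label}, {}

    def build(self, is_distill=False):
        # Required here because the data is not in tfrecord format.
        x_data, y_data = load_arrays(self.data_dir)  # hypothetical loader
        dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data))
        if self.is_training:
            dataset = dataset.shuffle(buffer_size=self.buffer_size).repeat()
        map_fn = self.parse_fn_distill if is_distill else self.parse_fn
        dataset = dataset.map(map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        return self.build_batch(dataset)
```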

After that, all you need to do is register the dataset name in the following files:
1. src/model_optimizer/pruner/config_schema.json: add the dataset name to the "enum" list
2. src/model_optimizer/pruner/dataset/\__init__.py: add the dataset name in Line 19 and add the dataset instance in the
if-else clause.

## Create custom model
Create your own model using the Keras functional API in model_optimizer.pruner.models. An illustrative sketch follows.
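
This sketch only illustrates the functional API; the layer configuration below is an assumption, not the actual
cnn1d_tiny definition.

```python
import tensorflow as tf


def get_model():
    # Input matches the dataset's (1, 784, 1) sample shape.
    inputs = tf.keras.Input(shape=(1, 784, 1), name="image")
    x = tf.keras.layers.Reshape((784, 1))(inputs)
    x = tf.keras.layers.Conv1D(16, 25, activation="relu", padding="same")(x)
    x = tf.keras.layers.MaxPool1D(3)(x)
    x = tf.keras.layers.Conv1D(32, 25, activation="relu", padding="same")(x)
    x = tf.keras.layers.GlobalAveragePooling1D()(x)
    outputs = tf.keras.layers.Dense(12, activation="softmax")(x)  # 12 classes
    return tf.keras.Model(inputs=inputs, outputs=outputs, name="cnn1d_tiny")
```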

After that, all you need to do is register the model name and initialize the model in the following file:
1. src/model_optimizer/pruner/models/\__init__.py: add the model name in Line 21 and add the model instance in the
if-else clause.

## Create custom learner
Implement your own learner by extending model_optimizer.pruner.learner.learner_base.LearnerBase and implementing the
following (a minimal sketch follows this list):
1. \__init__, required, where you can define your own learning rate callback
2. get_optimizer, required, where you can define your own optimizer
3. get_losses, required, where you can define your own loss function
4. get_metrics, required, where you can define your own metrics
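
A minimal sketch, under the assumption that these methods simply return Keras objects; check
learner_base.LearnerBase for the exact signatures.

```python
import tensorflow as tf
from .learner_base import LearnerBase


class MyLearner(LearnerBase):
    """Custom learner skeleton."""
    def __init__(self, config):
        super().__init__(config)  # assumed signature; see LearnerBase
        # A learning-rate schedule callback could be registered here.

    def get_optimizer(self):
        return tf.keras.optimizers.Adam(learning_rate=1e-3)

    def get_losses(self):
        return tf.keras.losses.SparseCategoricalCrossentropy()

    def get_metrics(self):
        return [tf.keras.metrics.SparseCategoricalAccuracy()]
```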

After that, register the model name and dataset name and initialize the learner in the following file:
1. src/model_optimizer/pruner/learner/\__init__.py

## Create the training process of the teacher model, and train the teacher model
Enter the examples directory and create cnn1d_iscx_session_all_train.py for the cnn1d model.

> Note
>
> > The "model_name" and "dataset" in the request must be the same as you defined before.

Execute:

```shell
cd examples
python3 cnn1d_iscx_session_all_train.py
```

After execution, the default checkpoint files will be generated in ./models_ckpt/cnn1d, and the inference
checkpoint file will be generated in ./models_eval_ckpt/cnn1d. You can also modify checkpoint_path
and checkpoint_eval_path in cnn1d_iscx_session_all_train.py to change the output paths.

## Convert the teacher model to logits output
Enter the tools directory and execute:
```shell
cd tools
python3 convert_softmax_model_to_logits.py
```

After execution, the default checkpoint file of the logits model will be generated at
examples/models_eval_ckpt/cnn1d/checkpoint-60-logits.h5.
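
Conceptually, the conversion swaps the model's final softmax activation for a linear one so that the saved model emits
raw logits, which distillation needs for temperature scaling. A rough sketch of the idea, assuming the softmax is the
activation of the last layer and that the trained checkpoint is named checkpoint-60.h5 (both assumptions; the actual
logic lives in tools/convert_softmax_model_to_logits.py):

```python
import tensorflow as tf

# Hypothetical paths; adjust to your checkpoint layout.
model = tf.keras.models.load_model("examples/models_eval_ckpt/cnn1d/checkpoint-60.h5")
model.layers[-1].activation = tf.keras.activations.linear  # softmax -> logits
# Saving re-serializes the layer config, so the reloaded model outputs logits.
model.save("examples/models_eval_ckpt/cnn1d/checkpoint-60-logits.h5")
```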

## Create the distilling process and distill the cnn1d_tiny model
Create the configuration file in src/model_optimizer/pruner/scheduler/distill, like "cnn1d_tiny_0.3.yaml", where the
distillation parameters are configured.

Enter the examples directory and create cnn1d_tiny_iscx_session_all_distill.py for the cnn1d_tiny model. In the
distillation process, the teacher is cnn1d and the student is cnn1d_tiny.

> Note
>
> > The "model_name" and "dataset" in the request must be the same as you defined before.

```shell
python3 cnn1d_tiny_iscx_session_all_distill.py
```

After execution, the default checkpoint files will be generated in ./models_ckpt/cnn1d_tiny_distill, and the inference
checkpoint file will be generated in ./models_eval_ckpt/cnn1d_tiny_distill. You can also modify checkpoint_path and
checkpoint_eval_path in cnn1d_tiny_iscx_session_all_distill.py to change the output paths.

> Note
>
> > i. The model in checkpoint_path is not the pure cnn1d_tiny model. It is a hybrid of cnn1d_tiny (student) and
> > cnn1d (teacher).
> >
> > ii. The model in checkpoint_eval_path is the distilled model, i.e. the pure cnn1d_tiny model.
36 changes: 36 additions & 0 deletions examples/cnn1d_iscx_session_all_train.py
@@ -0,0 +1,36 @@
# Copyright 2019 ZTE corporation. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Train a cnn1d model on iscx_session_all dataset
"""
import os
# If you did not execute the setup.py, uncomment the following four lines
# import sys
# from os.path import abspath, join, dirname
# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
# print(sys.path)

from model_optimizer import prune_model # noqa: E402


def _main():
base_dir = os.path.dirname(__file__)
request = {
"dataset": "iscx_session_all",
"model_name": "cnn1d",
"data_dir": "/data/12class/SessionAllLayers",
"batch_size": 500,
"batch_size_val": 100,
"learning_rate": 1e-3,
"epochs": 60,
"checkpoint_path": os.path.join(base_dir, "./models_ckpt/cnn1d"),
"checkpoint_save_period": 1, # save a checkpoint every epoch
"checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/cnn1d"),
"scheduler": "train"
}
prune_model(request)


if __name__ == "__main__":
_main()
37 changes: 37 additions & 0 deletions examples/cnn1d_tiny_iscx_session_all_distill.py
@@ -0,0 +1,37 @@
# Copyright 2019 ZTE corporation. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Distill a cnn1d_tiny model from a trained cnn1d model on the iscx_session_all dataset
"""
import os
# If you did not execute the setup.py, uncomment the following four lines
# import sys
# from os.path import abspath, join, dirname
# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
# print(sys.path)

from model_optimizer import prune_model # noqa: E402


def _main():
base_dir = os.path.dirname(__file__)
request = {
"dataset": "iscx_session_all",
"model_name": "cnn1d_tiny",
"data_dir": "/data/12class/SessionAllLayers",
"batch_size": 500,
"batch_size_val": 100,
"learning_rate": 1e-3,
"epochs": 200,
"checkpoint_path": os.path.join(base_dir, "./models_ckpt/cnn1d_tiny_distill"),
"checkpoint_save_period": 10, # save a checkpoint every epoch
"checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/cnn1d_tiny_distill"),
"scheduler": "distill",
"scheduler_file_name": "cnn1d_tiny_0.3.yaml"
}
prune_model(request)


if __name__ == "__main__":
_main()
36 changes: 36 additions & 0 deletions examples/cnn1d_tiny_iscx_session_all_train.py
@@ -0,0 +1,36 @@
# Copyright 2019 ZTE corporation. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Train a cnn1d_tiny model on iscx_session_all dataset
"""
import os
# If you did not execute the setup.py, uncomment the following four lines
# import sys
# from os.path import abspath, join, dirname
# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
# print(sys.path)

from model_optimizer import prune_model # noqa: E402


def _main():
base_dir = os.path.dirname(__file__)
request = {
"dataset": "iscx_session_all",
"model_name": "cnn1d_tiny",
"data_dir": "/data/12class/SessionAllLayers",
"batch_size": 500,
"batch_size_val": 100,
"learning_rate": 1e-3,
"epochs": 60,
"checkpoint_path": os.path.join(base_dir, "./models_ckpt/cnn1d_tiny"),
"checkpoint_save_period": 10, # save a checkpoint every epoch
"checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/cnn1d_tiny"),
"scheduler": "train"
}
prune_model(request)


if __name__ == "__main__":
_main()
3 changes: 2 additions & 1 deletion src/model_optimizer/pruner/config_schema.json
@@ -6,7 +6,8 @@
"enum": [
"mnist",
"cifar10",
"imagenet"
"imagenet",
"iscx_session_all"
],
"description": "dataset name"
},
5 changes: 4 additions & 1 deletion src/model_optimizer/pruner/dataset/__init__.py
@@ -16,7 +16,7 @@ def get_dataset(config, is_training, num_shards=1, shard_index=0):
:return: class of Dataset
"""
dataset_name = config.get_attribute('dataset')
if dataset_name not in ['mnist', 'cifar10', 'imagenet']:
if dataset_name not in ['mnist', 'cifar10', 'imagenet', 'iscx_session_all']:
raise Exception('Not support dataset %s' % dataset_name)
if dataset_name == 'mnist':
from .mnist import MnistDataset
@@ -27,5 +27,8 @@ def get_dataset(config, is_training, num_shards=1, shard_index=0):
elif dataset_name == 'imagenet':
from .imagenet import ImagenetDataset
return ImagenetDataset(config, is_training, num_shards, shard_index)
elif dataset_name == 'iscx_session_all':
from .iscx_session_all import ISCXDataset
return ISCXDataset(config, is_training, num_shards, shard_index)
else:
raise Exception('Not support dataset {}'.format(dataset_name))
1 change: 1 addition & 0 deletions src/model_optimizer/pruner/dataset/cifar10.py
@@ -31,6 +31,7 @@ def __init__(self, config, is_training):
self.buffer_size = 10000
self.num_samples_of_train = 50000
self.num_samples_of_val = 10000
self.data_shape = (32, 32, 3)

# pylint: disable=no-value-for-parameter,unexpected-keyword-arg
def parse_fn(self, example_serialized):
4 changes: 2 additions & 2 deletions src/model_optimizer/pruner/dataset/dataset_base.py
@@ -78,9 +78,9 @@ def build(self, is_distill=False):
dataset = dataset.map(self.parse_fn_distill, num_parallel_calls=tf.data.experimental.AUTOTUNE)
else:
dataset = dataset.map(self.parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return self.__build_batch(dataset)
return self.build_batch(dataset)

def __build_batch(self, dataset):
def build_batch(self, dataset):
"""
Make a batch from tf.data.Dataset.
:param dataset: tf.data.Dataset object
1 change: 1 addition & 0 deletions src/model_optimizer/pruner/dataset/imagenet.py
@@ -32,6 +32,7 @@ def __init__(self, config, is_training, num_shards=1, shard_index=0):
self.buffer_size = 10000
self.num_samples_of_train = 1281167
self.num_samples_of_val = 50000
self.data_shape = (224, 224, 3)

# pylint: disable=no-value-for-parameter,unexpected-keyword-arg
def parse_fn(self, example_serialized):
85 changes: 85 additions & 0 deletions src/model_optimizer/pruner/dataset/iscx_session_all.py
@@ -0,0 +1,85 @@
# Copyright 2019 ZTE corporation. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
ISCX 12 class session all dataset
https://github.com/echowei/DeepTraffic/blob/master/2.encrypted_traffic_classification/3.PerprocessResults/12class.zip

"""
import os
import gzip
import tensorflow as tf
import numpy as np
from .dataset_base import DatasetBase


class ISCXDataset(DatasetBase):
"""
ISCX session all layer dataset
"""
def __init__(self, config, is_training, num_shards=1, shard_index=0):
"""
Constructor function.
:param config: Config object
:param is_training: whether to construct the training subset
:return:
"""
super().__init__(config, is_training, num_shards, shard_index)
        if not is_training:
            self.batch_size = self.batch_size_val
self.buffer_size = 5000
self.num_samples_of_train = 35501
self.num_samples_of_val = 3945
self.data_shape = (1, 784, 1)

# pylint: disable=R0201
# pylint: disable=no-value-for-parameter,unexpected-keyword-arg
    def parse_fn(self, *content):
        """Parse one (data, label) item of the dataset."""
        data, label = content
        return data, label

def parse_fn_distill(self, *content):
"""
Parse dataset for distillation
:param content: item content of the dataset
:return: {image, label},{}
"""
image, label = self.parse_fn(*content)
inputs = {"image": image, "label": label}
targets = {}
return inputs, targets

def build(self, is_distill=False):
"""
Build dataset
:param is_distill: is distilling or not
:return: batch of a dataset
"""
if self.is_training:
x_path = os.path.join(self.data_dir, 'train-images-idx3-ubyte.gz')
y_path = os.path.join(self.data_dir, 'train-labels-idx1-ubyte.gz')
else:
x_path = os.path.join(self.data_dir, 't10k-images-idx3-ubyte.gz')
y_path = os.path.join(self.data_dir, 't10k-labels-idx1-ubyte.gz')

with gzip.open(y_path, 'rb') as lbpath:
y_data = np.frombuffer(lbpath.read(), np.uint8, offset=8)

with gzip.open(x_path, 'rb') as imgpath:
x_data = np.frombuffer(
imgpath.read(), np.uint8, offset=16).reshape(len(y_data), 1, 784)

dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data))

if self.num_shards != 1:
dataset = dataset.shard(num_shards=self.num_shards, index=self.shard_index)
if self.is_training:
dataset = dataset.shuffle(buffer_size=self.buffer_size).repeat()
if is_distill:
dataset = dataset.map(self.parse_fn_distill, num_parallel_calls=tf.data.experimental.AUTOTUNE)
else:
dataset = dataset.map(self.parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return self.build_batch(dataset)
1 change: 1 addition & 0 deletions src/model_optimizer/pruner/dataset/mnist.py
@@ -31,6 +31,7 @@ def __init__(self, config, is_training):
self.buffer_size = 10000
self.num_samples_of_train = 60000
self.num_samples_of_val = 10000
self.data_shape = (28, 28, 1)

# pylint: disable=R0201
# pylint: disable=no-value-for-parameter,unexpected-keyword-arg