Skip to content

Commit

Permalink
support 1dcnn distilling
Browse files Browse the repository at this point in the history
Signed-off-by: priscilla-pan <[email protected]>
  • Loading branch information
priscilla-pan committed Sep 16, 2021
1 parent 385fcf2 commit 1c98bfd
Show file tree
Hide file tree
Showing 27 changed files with 546 additions and 19 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ The details are shown in the table below, and the code can refer to examples\res
| + pruned + distill | 76.39 | 6954152 ( 72.8% pruned) | 1075M | 27M|
| + pruned + distill + quantization(TF-Lite) | 75.938 | - | - | 7.1M|

We also impletement a 1D-CNN distillation which shows distillation is also effective on Encrypted Traffic Classification.
You can get detailed instructions from [here](doc/CNN-1D-tiny-Distillation.md). Following this instruction, you can build
your own dataset and model to train and distill under adlik model optimizer.

## 1. Pruning and quantization principle

### 1.1 Filter pruning
Expand Down
122 changes: 122 additions & 0 deletions doc/CNN-1D-tiny-Distillation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Tiny 1D-CNN Knowledge Distillation

The following uses 1D-CNN on the 12 classes session all dataset as teacher model to illustrate how to use the model
optimizer to improve the preformance of tiny 1D-CNN by knowledge distillation.

The 1D-CNN model is from Wang's paper[Wang, W.; Zhu, M.; Wang, J.; Zeng, X.; Yang, Z. End-to-end encrypted traffic
classification with one-dimensional convolution neural networks.] The tiny 1D-CNN model is a slim version of the
1D-CNN model mentioned before. Using 1D-CNN model as the teacher to ditstill tiny 1D-CNN model, performance can be
improved by 5.66%.

The details are shown in the table below, and the code can refer to examples\cnn1d_tiny_iscx_session_all_distill.py.

| Model | Accuracy | Params | Model Size |
| --------- | -------- | -------------------- | ---------------------------- |
| cnn1d | 92.67% | 5832588 | 23M|
| cnn1d_tiny | 87.62% | 134988 | 546K|
| cnn1d_tiny+ distill | 93.28% | 134988 | 546K|

## 1 Create custom dataset

Using [ISCX dataset](https://www.unb.ca/cic/datasets/vpn.html), you can get the processed 12-classes-session-all dataset
from [wang's github](https://github.com/echowei/DeepTraffic/blob/master/2.encrypted_traffic_classification/3.PerprocessResults/12class.zip).
We name the dataset as iscx_session_all. In the iscx_session_all, there are 35501 training samples, the shape is
(35501, 28, 28), 3945 testing samples.

Now that you have the dataset, you can implement your custom dateset by extending model_optimizer.prunner.dataset.
dataset_base.DatasetBase and implementing:

1. \__init__, required, where you can do all dataset initialization
2. parse_fn, required, where is the map function of the dataset
3. parse_fn_distill, required, where is the map function of the dataset used in distillation
4. build, optional, where is the process of building the dataset. If your dataset is not in tfrecord format, you must
implement this function.

Here in the custom dataset, we reshape the samples from (None, 28, 28, 1) to (None, 1, 784, 1) for the following 1D-CNN
models.

After that, all you need is put the dataset name in the following files:

1. src/model_optimizer/prunner/config_schema.json the "enum" list
2. src/model_optimizer/prunner/dataset/\__init__.py. Add the dataset name in Line 19 and add the dataset instance in the
if-else clause.

## Create custom model

Create your own model using The Keras functional API in model_optimizer.prunner.models.

After that, all you need is put the model name and initialize the model in the following files:

1. src/model_optimizer/prunner/models/\__init__.py. Add the model name in Line 21 and add the model instance in the
if-else clause.

## Create custom learner

Implement your own learner by extending model_optimizer.prunner.learner.learner_base.LearnerBase and implementing:

1. \__init__, required, where you can define your own learning rate callback
2. get_optimizer, required, where you can define your own optimizer
3. get_losses, required, where you can define your own loss function
4. get_metrics, required, where you can define your own metrics

After that, all you need is put the model name and dataset name and initialize the learner in the following files:

1. src/model_optimizer/prunner/learner/\__init__.py

## Create the training process of the teacher model, and train the teacher model

Enter the examples directory, create cnn1d_iscx_session_all_train.py for cnn1d model.

> Note
>
> > the "model_name" and "dataset" in the request must be the same as you defined before
Execute:

```shell
cd examples
python3 cnn1d_iscx_session_all_train.py
```

After execution, the default checkpoint file will be generated in ./models_ckpt/cnn1d, and the inference
checkpoint file will be generated in ./models_eval_ckpt/cnn1d. You can also modify the checkpoint_path
and checkpoint_eval_path of the cnn1d_iscx_session_all_train.py file to change the generated file path.

## Convert the teacher model to logits output

Enter the tools directory and execute:

```shell
cd tools
python3 convert_softmax_model_to_logits.py
```

After execution, the default checkpoint file of logits model will be generated in examples/models_eval_ckpt/cnn1d/
checkpoint-60-logits.h5

## Create the distilling process and distill the cnn1d_tiny model

Create the configuration file in the src/model_optimizer/pruner/scheduler/distill,like "cnn1d_tiny_0.3.yaml" where the
distillation parameters is configured.

Enter the examples directory, create cnn1d_tiny_iscx_session_all_distill.py for cnn1d_tiny model. In the distilling
process, the teacher is cnn1d, the student is cnn1d_tiny.

> Note
>
> > the "model_name" and "dataset" in the request must be the same as you defined before
```shell
python3 cnn1d_tiny_iscx_session_all_distill.py
```

After execution, the default checkpoint file will be generated in ./models_ckpt/cnn1d_tiny_distill, and the inference
checkpoint file will be generated in ./models_eval_ckpt/cnn1d_tiny_distill. You can also modify the checkpoint_path and
checkpoint_eval_path of the cnn1d_tiny_iscx_session_all_distill.py file to change the generated file path.

> Note
>
> > i. The model in the checkpoint_path is not the pure cnn1d_tiny model. It's the hybird of cnn1d_tiny(student) and
> > cnn1d(teacher)
> >
> > ii. The model in the checkpoint_eval_path is the distilled model, i.e. pure cnn1d_tiny model
36 changes: 36 additions & 0 deletions examples/cnn1d_iscx_session_all_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2019 ZTE corporation. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Train a cnn1d model on iscx_session_all dataset
"""
import os
# If you did not execute the setup.py, uncomment the following four lines
# import sys
# from os.path import abspath, join, dirname
# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
# print(sys.path)

from model_optimizer import prune_model # noqa: E402


def _main():
base_dir = os.path.dirname(__file__)
request = {
"dataset": "iscx_session_all",
"model_name": "cnn1d",
"data_dir": "/data/12class/SessionAllLayers",
"batch_size": 500,
"batch_size_val": 100,
"learning_rate": 1e-3,
"epochs": 60,
"checkpoint_path": os.path.join(base_dir, "./models_ckpt/cnn1d"),
"checkpoint_save_period": 1, # save a checkpoint every epoch
"checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/cnn1d"),
"scheduler": "train"
}
prune_model(request)


if __name__ == "__main__":
_main()
37 changes: 37 additions & 0 deletions examples/cnn1d_tiny_iscx_session_all_distill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2019 ZTE corporation. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Distill a cnn1d_tiny model from a trained cnn1d model on the iscx_session_all dataset
"""
import os
# If you did not execute the setup.py, uncomment the following four lines
# import sys
# from os.path import abspath, join, dirname
# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
# print(sys.path)

from model_optimizer import prune_model # noqa: E402


def _main():
base_dir = os.path.dirname(__file__)
request = {
"dataset": "iscx_session_all",
"model_name": "cnn1d_tiny",
"data_dir": "/data/12class/SessionAllLayers",
"batch_size": 500,
"batch_size_val": 100,
"learning_rate": 1e-3,
"epochs": 200,
"checkpoint_path": os.path.join(base_dir, "./models_ckpt/cnn1d_tiny_distill"),
"checkpoint_save_period": 10, # save a checkpoint every epoch
"checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/cnn1d_tiny_distill"),
"scheduler": "distill",
"scheduler_file_name": "cnn1d_tiny_0.3.yaml"
}
prune_model(request)


if __name__ == "__main__":
_main()
36 changes: 36 additions & 0 deletions examples/cnn1d_tiny_iscx_session_all_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2019 ZTE corporation. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Train a cnn1d_tiny model on iscx_session_all dataset
"""
import os
# If you did not execute the setup.py, uncomment the following four lines
# import sys
# from os.path import abspath, join, dirname
# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
# print(sys.path)

from model_optimizer import prune_model # noqa: E402


def _main():
base_dir = os.path.dirname(__file__)
request = {
"dataset": "iscx_session_all",
"model_name": "cnn1d_tiny",
"data_dir": "/data/12class/SessionAllLayers",
"batch_size": 500,
"batch_size_val": 100,
"learning_rate": 1e-3,
"epochs": 60,
"checkpoint_path": os.path.join(base_dir, "./models_ckpt/cnn1d_tiny"),
"checkpoint_save_period": 10, # save a checkpoint every epoch
"checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/cnn1d_tiny"),
"scheduler": "train"
}
prune_model(request)


if __name__ == "__main__":
_main()
2 changes: 1 addition & 1 deletion src/model_optimizer/pruner/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def create_config_from_obj(obj) -> object:
:return:
"""
schema_path = os.path.join(os.path.dirname(__file__), 'config_schema.json')
with open(schema_path) as schema_file:
with open(schema_path, encoding='utf-8') as schema_file:
body_schema = json.load(schema_file)

jsonschema.validate(obj, body_schema)
Expand Down
3 changes: 2 additions & 1 deletion src/model_optimizer/pruner/config_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
"enum": [
"mnist",
"cifar10",
"imagenet"
"imagenet",
"iscx_session_all"
],
"description": "dataset name"
},
Expand Down
2 changes: 1 addition & 1 deletion src/model_optimizer/pruner/core/pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_network(model):
for i, layer in enumerate(model.layers):
digraph.add_node(i, name=layer.name, type=str(type(layer)))
for i, layer in enumerate(model.layers):
for j in range(0, len(model.layers)):
for j, _ in enumerate(model.layers):
_inputs = model.layers[j].input
if isinstance(_inputs, list):
for _input in _inputs:
Expand Down
5 changes: 4 additions & 1 deletion src/model_optimizer/pruner/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def get_dataset(config, is_training, num_shards=1, shard_index=0):
:return: class of Dataset
"""
dataset_name = config.get_attribute('dataset')
if dataset_name not in ['mnist', 'cifar10', 'imagenet']:
if dataset_name not in ['mnist', 'cifar10', 'imagenet', 'iscx_session_all']:
raise Exception('Not support dataset %s' % dataset_name)
if dataset_name == 'mnist':
from .mnist import MnistDataset
Expand All @@ -27,5 +27,8 @@ def get_dataset(config, is_training, num_shards=1, shard_index=0):
elif dataset_name == 'imagenet':
from .imagenet import ImagenetDataset
return ImagenetDataset(config, is_training, num_shards, shard_index)
elif dataset_name == 'iscx_session_all':
from .iscx_session_all import ISCXDataset
return ISCXDataset(config, is_training, num_shards, shard_index)
else:
raise Exception('Not support dataset {}'.format(dataset_name))
1 change: 1 addition & 0 deletions src/model_optimizer/pruner/dataset/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(self, config, is_training):
self.buffer_size = 10000
self.num_samples_of_train = 50000
self.num_samples_of_val = 10000
self.data_shape = (32, 32, 3)

# pylint: disable=no-value-for-parameter,unexpected-keyword-arg
def parse_fn(self, example_serialized):
Expand Down
4 changes: 2 additions & 2 deletions src/model_optimizer/pruner/dataset/dataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def build(self, is_distill=False):
dataset = dataset.map(self.parse_fn_distill, num_parallel_calls=tf.data.experimental.AUTOTUNE)
else:
dataset = dataset.map(self.parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return self.__build_batch(dataset)
return self.build_batch(dataset)

def __build_batch(self, dataset):
def build_batch(self, dataset):
"""
Make an batch from tf.data.Dataset.
:param dataset: tf.data.Dataset object
Expand Down
1 change: 1 addition & 0 deletions src/model_optimizer/pruner/dataset/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(self, config, is_training, num_shards=1, shard_index=0):
self.buffer_size = 10000
self.num_samples_of_train = 1281167
self.num_samples_of_val = 50000
self.data_shape = (224, 224, 3)

# pylint: disable=no-value-for-parameter,unexpected-keyword-arg
def parse_fn(self, example_serialized):
Expand Down
Loading

0 comments on commit 1c98bfd

Please sign in to comment.