diff --git a/README.md b/README.md index 7eea204..4c0d3c5 100644 --- a/README.md +++ b/README.md @@ -3,9 +3,7 @@ If you used the master branch before Sep. 26 2017 and its corresponding pretrain The old master branch in now under old_master, you can still run the code and download the pretrained model, but the pretrained model for that old master is not compatible to the current master! The main differences between new and old master branch are in this two commits: [9d4c24e](https://github.com/ruotianluo/pytorch-faster-rcnn/commit/9d4c24e83c3e4ec33751e50d5e4d8b1dd793dfaa), [c899ce7](https://github.com/ruotianluo/pytorch-faster-rcnn/commit/c899ce70dae62e3db1a5805eda96df88e4b59ca6) -The change is related to this [issue](https://github.com/ruotianluo/pytorch-faster-rcnn/issues/6); master now matches all the details (of resnet101) in [tf-faster-rcnn](https://github.com/endernewton/tf-faster-rcnn) so that we can now convert pretrained tf model to pytorch model (only resnet101 is supported). - -(vgg16 conversion is not among the top of my to-do list. May reach it some day.) +The change is related to this [issue](https://github.com/ruotianluo/pytorch-faster-rcnn/issues/6); master now matches all the details in [tf-faster-rcnn](https://github.com/endernewton/tf-faster-rcnn) so that we can now convert pretrained tf model to pytorch model. # pytorch-faster-rcnn A pytorch implementation of faster RCNN detection framework based on Xinlei Chen's [tf-faster-rcnn](https://github.com/endernewton/tf-faster-rcnn). Xinlei Chen's repository is based on the python Caffe implementation of faster RCNN available [here](https://github.com/rbgirshick/py-faster-rcnn). @@ -16,9 +14,9 @@ A pytorch implementation of faster RCNN detection framework based on Xinlei Chen The current code supports **VGG16**, **Resnet V1** and ~~**Mobilenet V1**~~ models. We mainly tested it on plain VGG16 and Resnet101 architecture. As the baseline, we report numbers using a single model on a single convolution layer, so no multi-scale, no multi-stage bounding box regression, no skip-connection, no extra input is used. The only data augmentation technique is left-right flipping during training following the original Faster RCNN. All models are released. With VGG16 (``conv5_3``): - - Train on VOC 2007 trainval and test on VOC 2007 test, **70.48**(crop and resize), **69.95**(roi pooling) (**71.2** for tf-faster-rcnn). - - Train on VOC 2007+2012 trainval and test on VOC 2007 test ([R-FCN](https://github.com/daijifeng001/R-FCN) schedule), **74.83**(crop and resize) **74.59**(roi pooling)(**75.3** for tf-faster-rcnn). - - Train on COCO 2014 [trainval35k](https://github.com/rbgirshick/py-faster-rcnn/tree/master/models) and test on [minival](https://github.com/rbgirshick/py-faster-rcnn/tree/master/models) (900k/1190k) **27.0** (**29.5** for tf-faster-rcnn). ((350k/490k) **24.6**(crop and resize) **21.8**(roi pooling)). + - Train on VOC 2007 trainval and test on VOC 2007 test, **70.48**(from scratch) **70.90**(converted) (**71.2** for tf-faster-rcnn). + - Train on VOC 2007+2012 trainval and test on VOC 2007 test ([R-FCN](https://github.com/daijifeng001/R-FCN) schedule), **74.83**(from scratch) **75.07**(converted) (**75.3** for tf-faster-rcnn). + - Train on COCO 2014 [trainval35k](https://github.com/rbgirshick/py-faster-rcnn/tree/master/models) and test on [minival](https://github.com/rbgirshick/py-faster-rcnn/tree/master/models) (900k/1190k) **27.0**(from scratch) **29.0**(converted) (**29.5** for tf-faster-rcnn). With Resnet101 (last ``conv4``): - Train on VOC 2007 trainval and test on VOC 2007 test, **74.84**(from scratch) **75.08**(converted) (**75.2** for tf-faster-rcnn). @@ -27,8 +25,8 @@ With Resnet101 (last ``conv4``): More Results: - Train Mobilenet (1.0, 224) on COCO 2014 trainval35k and test on minival (900k/1190k), ~~**21.9**~~. - - Train Resnet50 on COCO 2014 trainval35k and test on minival (900k/1190k), ~~**31.6**~~. - - Train Resnet152 on COCO 2014 trainval35k and test on minival (900k/1190k), ~~**35.2**~~. + - Train Resnet50 on COCO 2014 trainval35k and test on minival (900k/1190k), **31.4**(converted) (**31.6** for tf-faster-rcnn). + - Train Resnet152 on COCO 2014 trainval35k and test on minival (900k/1190k), **34.9**(converted) (**35.2** for tf-faster-rcnn). Approximate *baseline* [setup](https://github.com/endernewton/tf-faster-rcnn/blob/master/experiments/cfgs/res101-lg.yml) from [FPN](https://arxiv.org/abs/1612.03144) (this repo does not contain training code for FPN yet): - Train Resnet50 on COCO 2014 trainval35k and test on minival (900k/1190k), ~~**33.4**~~. @@ -36,13 +34,12 @@ Approximate *baseline* [setup](https://github.com/endernewton/tf-faster-rcnn/blo - Train Resnet152 on COCO 2014 trainval35k and test on minival (1000k/1390k), ~~**37.2**~~. **Note**: - - Compared to tf-faster-rcnn, we use roi pooling instead of crop_and_resize; we don't know how this affects result compared to tf-faster-rcnn. - - ~~Due to the randomness in GPU training with Tensorflow espeicially for VOC, the best numbers are reported (with 2-3 attempts) here. According to my experience, for COCO you can almost always get a very close number (within ~0.2%) despite the randomness.~~ + - Due to the randomness in GPU training espeicially for VOC, the best numbers are reported (with 2-3 attempts) here. According to Xinlei's experience, for COCO you can almost always get a very close number (within ~0.2%) despite the randomness. - **All** the numbers are obtained with a different testing scheme without selecting region proposals using non-maximal suppression (TEST.MODE top), the default and original testing scheme (TEST.MODE nms) will likely result in slightly worse performance (see [report](https://arxiv.org/pdf/1702.02138.pdf), for COCO it drops 0.X AP). - Since we keep the small proposals (\< 16 pixels width/height), our performance is especially good for small objects. - For other minor modifications, please check the [report](https://arxiv.org/pdf/1702.02138.pdf). Notable ones include ~~using ``crop_and_resize``~~, and excluding ground truth boxes in RoIs during training. - - For COCO, we find the performance improving with more iterations (VGG16 350k/490k: 26.9, 600k/790k: 28.3, 900k/1190k: 29.5), and potentially better performance can be achieved with even more iterations. - - For Resnets, we fix the first block (total 4) when fine-tuning the network, and only use ~~``crop_and_resize``~~ roi pooling to resize the RoIs (7x7) without max-pool (~~which I find useless especially for COCO~~). The final feature maps are average-pooled for classification and regression. All batch normalization parameters are fixed. Weight decay is set to Renset101 default 1e-4. Learning rate for biases is not doubled. + - For COCO, we find the performance improving with more iterations (VGG16 350k/490k: 26.9, 600k/790k: 28.3, 900k/1190k: 29.5) (number from tf-faster-rcnn), and potentially better performance can be achieved with even more iterations. + - For Resnets, we fix the first block (total 4) when fine-tuning the network, and only use ``crop_and_resize`` to resize the RoIs (7x7) without max-pool (which Xinlei find useless especially for COCO). The final feature maps are average-pooled for classification and regression. All batch normalization parameters are fixed. Weight decay is set to Renset101 default 1e-4. Learning rate for biases is not doubled. - For approximate [FPN](https://arxiv.org/abs/1612.03144) baseline setup we simply resize the image with 800 pixels, add 32^2 anchors, and take 1000 proposals during testing. - Check out ~~[here](http://ladoga.graphics.cs.cmu.edu/xinleic/tf-faster-rcnn/)/[here](http://gs11655.sp.cs.cmu.edu/xinleic/tf-faster-rcnn/)/~~[here](https://drive.google.com/open?id=0B7fNdx_jAqhtWERtcnZOanZGSG8) for the latest models, including longer COCO VGG16 models and Resnet ones. @@ -126,11 +123,12 @@ If you find it useful, the ``data/cache`` folder created on my side is also shar - Google drive [here](https://drive.google.com/open?id=0B7fNdx_jAqhtNE10TDZDbFRuU0E). **(Optional)** -Instead of downloading my pretrained or converted model, you can also convert from tf-faster-rcnn model (Only support res101 currently). +Instead of downloading my pretrained or converted model, you can also convert from tf-faster-rcnn model. You can download the tensorflow pretrained model from [tf-faster-rcnn](https://github.com/endernewton/tf-faster-rcnn/#demo-and-test-with-pre-trained-models). Then run: ```Shell -python tools/convert_from_tensorflow.py --tensorflow_model model_name.ckpt +python tools/convert_from_tensorflow.py --tensorflow_model resnet_model.ckpt +python tools/convert_from_tensorflow_vgg.py --tensorflow_model vgg_model.ckpt ``` This script will create a `.pth` file with the same name in the same folder as the tensorflow model. diff --git a/tools/convert_from_tensorflow_vgg.py b/tools/convert_from_tensorflow_vgg.py new file mode 100644 index 0000000..8edda11 --- /dev/null +++ b/tools/convert_from_tensorflow_vgg.py @@ -0,0 +1,83 @@ +import tensorflow as tf +from tensorflow.python import pywrap_tensorflow +from collections import OrderedDict +import re +import torch + +import argparse +parser = argparse.ArgumentParser(description='Convert tf-faster-rcnn model to pytorch-faster-rcnn model') +parser.add_argument('--tensorflow_model', + help='the path of tensorflow_model', + default=None, type=str) + +args = parser.parse_args() + +reader = pywrap_tensorflow.NewCheckpointReader(args.tensorflow_model) +var_to_shape_map = reader.get_variable_to_shape_map() +var_dict = {k:reader.get_tensor(k) for k in var_to_shape_map.keys()} + +del var_dict['Variable'] + +for k in var_dict.keys(): + if 'Momentum' in k: + del var_dict[k] + +for k in var_dict.keys(): + if k.find('/') >= 0: + var_dict['vgg' + k[k.find('/'):]] = var_dict[k] + del var_dict[k] + +dummy_replace = OrderedDict([ + ('weights', 'weight'),\ + ('biases', 'bias'),\ + ('vgg/rpn_conv/3x3', 'rpn_net'),\ + ('vgg/rpn_cls_score', 'rpn_cls_score_net'),\ + ('vgg/cls_score', 'cls_score_net'),\ + ('vgg/rpn_bbox_pred', 'rpn_bbox_pred_net'),\ + ('vgg/bbox_pred', 'bbox_pred_net'),\ + ('/', '.')]) + +for a, b in dummy_replace.items(): + for k in var_dict.keys(): + if a in k: + var_dict[k.replace(a,b)] = var_dict[k] + del var_dict[k] + +layer_map = OrderedDict([ + ('conv1.conv1_1', 'features.0'),\ + ('conv1.conv1_2', 'features.2'),\ + ('conv2.conv2_1', 'features.5'),\ + ('conv2.conv2_2', 'features.7'),\ + ('conv3.conv3_1', 'features.10'),\ + ('conv3.conv3_2', 'features.12'),\ + ('conv3.conv3_3', 'features.14'),\ + ('conv4.conv4_1', 'features.17'),\ + ('conv4.conv4_2', 'features.19'),\ + ('conv4.conv4_3', 'features.21'),\ + ('conv5.conv5_1', 'features.24'),\ + ('conv5.conv5_2', 'features.26'),\ + ('conv5.conv5_3', 'features.28'),\ + ('fc6', 'classifier.0'),\ + ('fc7', 'classifier.3')]) + +for a, b in layer_map.items(): + for k in var_dict.keys(): + if a in k: + var_dict[k.replace(a,b)] = var_dict[k] + del var_dict[k] + +for k in var_dict.keys(): + if 'classifier.0' in k: + if var_dict[k].ndim == 2: # weight + var_dict[k] = var_dict[k].reshape(7,7,512,4096).transpose((3, 2, 0, 1)).reshape(4096, -1).copy(order='C') + else: + if var_dict[k].ndim == 4: + var_dict[k] = var_dict[k].transpose((3, 2, 0, 1)).copy(order='C') + if var_dict[k].ndim == 2: + var_dict[k] = var_dict[k].transpose((1, 0)).copy(order='C') + # assert x[k].shape == var_dict[k].shape, k + +for k in var_dict.keys(): + var_dict[k] = torch.from_numpy(var_dict[k]) + +torch.save(var_dict, args.tensorflow_model[:args.tensorflow_model.find('.ckpt')]+'.pth')