diff --git a/docs/source/experiments/cifar.rst b/docs/source/experiments/cifar.rst
index 34bd16ce..5e057720 100644
--- a/docs/source/experiments/cifar.rst
+++ b/docs/source/experiments/cifar.rst
@@ -57,19 +57,6 @@ Train Your Own Model
--eval evaluating
-Extending the Software
-----------------------
-
-This code is well written, easy to use and extendable for your own models or datasets:
-
-- Write your own Dataloader ``mydataset.py`` to ``dataset/`` folder
-
-- Write your own Model ``mymodel.py`` to ``model/`` folder
-
-- Run the program::
-
- python main.py --dataset mydataset --model mymodel
-
Citation
--------
diff --git a/docs/source/experiments/segmentation.rst b/docs/source/experiments/segmentation.rst
index 16ac660d..7d1e2a60 100644
--- a/docs/source/experiments/segmentation.rst
+++ b/docs/source/experiments/segmentation.rst
@@ -38,25 +38,19 @@ Test Pre-trained Model
.. role:: raw-html(raw)
:format: html
-+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
-| Model | pixAcc | mIoU | Command | Logs |
-+==================================+===========+===========+==============================================================================================+============+
-| Encnet_ResNet50_PContext | 79.2% | 51.0% | :raw-html:`cmd` | ENC50PC_ |
-+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
-| EncNet_ResNet101_PContext | 80.7% | 54.1% | :raw-html:`cmd` | ENC101PC_ |
-+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
-| EncNet_ResNet50_ADE | 80.1% | 41.5% | :raw-html:`cmd` | ENC50ADE_ |
-+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
-| EncNet_ResNet101_ADE | 81.3% | 44.4% | :raw-html:`cmd` | ENC101ADE_ |
-+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
-| EncNet_ResNet101_VOC | N/A | 85.9% | :raw-html:`cmd` | ENC101VOC_ |
-+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
-
-.. _ENC50PC: https://github.com/zhanghang1989/image-data/blob/master/encoding/segmentation/logs/encnet_resnet50_pcontext.log?raw=true
-.. _ENC101PC: https://github.com/zhanghang1989/image-data/blob/master/encoding/segmentation/logs/encnet_resnet101_pcontext.log?raw=true
-.. _ENC50ADE: https://github.com/zhanghang1989/image-data/blob/master/encoding/segmentation/logs/encnet_resnet50_ade.log?raw=true
-.. _ENC101ADE: https://github.com/zhanghang1989/image-data/blob/master/encoding/segmentation/logs/encnet_resnet101_ade.log?raw=true
-.. _ENC101VOC: https://github.com/zhanghang1989/image-data/blob/master/encoding/segmentation/logs/encnet_resnet101_voc.log?raw=true
++----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+
+| Model | pixAcc | mIoU | Command |
++==================================+===========+===========+==============================================================================================+
+| Encnet_ResNet50_PContext | 79.2% | 51.0% | :raw-html:`cmd` |
++----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+
+| EncNet_ResNet101_PContext | 80.7% | 54.1% | :raw-html:`cmd` |
++----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+
+| EncNet_ResNet50_ADE | 80.1% | 41.5% | :raw-html:`cmd` |
++----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+
+| EncNet_ResNet101_ADE | 81.3% | 44.4% | :raw-html:`cmd` |
++----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+
+| EncNet_ResNet101_VOC | N/A | 85.9% | :raw-html:`cmd` |
++----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+
.. raw:: html
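
For reference, the two metric columns above are the standard segmentation measures: pixAcc is the fraction of correctly labelled pixels and mIoU is the intersection-over-union averaged over classes. A minimal NumPy sketch of both, assuming integer label maps with the ignore label already masked out::

    import numpy as np

    def seg_metrics(pred, target, nclass):
        # pred, target: integer arrays of identical shape with values in [0, nclass)
        pixacc = float((pred == target).mean())
        ious = []
        for c in range(nclass):
            inter = np.logical_and(pred == c, target == c).sum()
            union = np.logical_or(pred == c, target == c).sum()
            if union > 0:
                ious.append(inter / union)
        return pixacc, float(np.mean(ious))  # pixAcc, mIoU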
diff --git a/docs/source/experiments/texture.rst b/docs/source/experiments/texture.rst
index 769c2a49..1278c429 100644
--- a/docs/source/experiments/texture.rst
+++ b/docs/source/experiments/texture.rst
@@ -22,16 +22,11 @@ Test Pre-trained Model
cd PyTorch-Encoding/
python scripts/prepare_minc.py
-- Download pre-trained model (pre-trained on train-1 split using single training size of 224, with an error rate of :math:`18.96\%` using single crop on test-1 set)::
+- Test the pre-trained model on MINC-2500. The pre-trained weights are downloaded automatically (the model was pre-trained on the train-1 split with a single training size of 224, and reaches an error rate of :math:`18.96\%` with a single crop on the test-1 set)::
- cd experiments/recognition
- python model/download_models.py
-
-- Test pre-trained model on MINC-2500::
-
- python main.py --dataset minc --model deepten --nclass 23 --resume deepten_minc.pth --eval
+ python main.py --dataset minc --model deepten_resnet50_minc --nclass 23 --pretrained --eval
     # Terminal Output:
- # Loss: 1.005 | Err: 18.96% (1090/5750): 100%|████████████████████| 23/23 [00:18<00:00, 1.26it/s]
+ # Loss: 0.995 | Err: 18.957% (1090/5750): 100%|████████████████████| 23/23 [00:18<00:00, 1.26it/s]
Train Your Own Model
@@ -39,7 +34,7 @@ Train Your Own Model
- Example training command for training above model::
- CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --dataset minc --model deepten --nclass 23 --model deepten --batch-size 512 --lr 0.004 --epochs 80 --lr-step 60 --lr-scheduler step
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --dataset minc --model deepten_resnet50_minc --batch-size 512 --lr 0.004 --epochs 80 --lr-step 60 --lr-scheduler step --weight-decay 5e-4
- Detail training options::
@@ -62,20 +57,6 @@ Train Your Own Model
--eval evaluating
-Extending the Software
-----------------------
-
-This code is well written, easy to use and extendable for your own models or datasets:
-
-- Write your own Dataloader ``mydataset.py`` to ``dataset/`` folder
-
-- Write your own Model ``mymodel.py`` to ``model/`` folder
-
-- Run the program::
-
- python main.py --dataset mydataset --model mymodel
-
-
Citation
--------
diff --git a/docs/source/functions.rst b/docs/source/functions.rst
deleted file mode 100644
index e3f3c8c4..00000000
--- a/docs/source/functions.rst
+++ /dev/null
@@ -1,32 +0,0 @@
-.. role:: hidden
- :class: hidden-section
-
-encoding.functions
-==================
-
-.. automodule:: encoding.functions
-
-.. currentmodule:: encoding.functions
-
-
-:hidden:`batchnormtrain`
-~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. autofunction:: batchnormtrain
-
-:hidden:`aggregate`
-~~~~~~~~~~~~~~~~~~~
-
-.. autofunction:: aggregate
-
-
-:hidden:`scaled_l2`
-~~~~~~~~~~~~~~~~~~~
-
-.. autofunction:: scaled_l2
-
-
-:hidden:`sum_square`
-~~~~~~~~~~~~~~~~~~~~
-
-.. autofunction:: sum_square
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 5302df1f..fb8a9567 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -30,8 +30,7 @@ An optimized PyTorch package with CUDA backend.
nn
parallel
- dilated
- functions
+ models
utils
Indices and tables
diff --git a/docs/source/dilated.rst b/docs/source/models.rst
similarity index 91%
rename from docs/source/dilated.rst
rename to docs/source/models.rst
index 5aef805f..0ec7a81f 100644
--- a/docs/source/dilated.rst
+++ b/docs/source/models.rst
@@ -1,9 +1,15 @@
.. role:: hidden
:class: hidden-section
-encoding.dilated
+encoding.models
================
+.. automodule:: encoding.models.resnet
+.. currentmodule:: encoding.models.resnet
+
+ResNet
+------
+
We provide correct dilated pre-trained ResNet and DenseNet (stride of 8) for semantic segmentation.
For dilation of DenseNet, we provide :class:`encoding.nn.DilatedAvgPool2d`.
All provided models have been verified.
@@ -14,12 +20,6 @@ All provided models have been verified.
* Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, Amit Agrawal. "Context Encoding for Semantic Segmentation" *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
-.. automodule:: encoding.dilated
-.. currentmodule:: encoding.dilated
-
-ResNet
-------
-
:hidden:`ResNet`
~~~~~~~~~~~~~~~~
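
A minimal usage sketch for the renamed module; the ``encoding.models.resnet.resnet50`` entry point and the ``dilated`` keyword are assumptions based on the description above, not something this patch verifies::

    import encoding

    # assumed API: a stride-8 (dilated) ResNet backbone for segmentation
    net = encoding.models.resnet.resnet50(pretrained=False, dilated=True)
    print(sum(p.numel() for p in net.parameters()))  # parameter count of the backbone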
diff --git a/docs/source/nn.rst b/docs/source/nn.rst
index 927d7036..7310ac15 100644
--- a/docs/source/nn.rst
+++ b/docs/source/nn.rst
@@ -14,10 +14,10 @@ Customized NN modules in Encoding Package. For Synchronized Cross-GPU Batch Norm
.. autoclass:: Encoding
:members:
-:hidden:`BatchNorm2d`
+:hidden:`SyncBatchNorm`
~~~~~~~~~~~~~~~~~~~~~~~~
-.. autoclass:: BatchNorm2d
+.. autoclass:: SyncBatchNorm
:members:
:hidden:`BatchNorm1d`
@@ -26,6 +26,12 @@ Customized NN modules in Encoding Package. For Synchronized Cross-GPU Batch Norm
.. autoclass:: BatchNorm1d
:members:
+:hidden:`BatchNorm2d`
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: BatchNorm2d
+ :members:
+
:hidden:`BatchNorm3d`
~~~~~~~~~~~~~~~~~~~~~~~~
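
A minimal sketch of how the renamed ``SyncBatchNorm`` is meant to be dropped in for ``torch.nn.BatchNorm2d``; the constructor is assumed to take ``num_features`` like the standard BatchNorm modules::

    import torch.nn as nn
    import encoding

    norm_layer = encoding.nn.SyncBatchNorm  # assumed drop-in replacement for nn.BatchNorm2d
    block = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
        norm_layer(64),
        nn.ReLU(inplace=True),
    )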
diff --git a/docs/source/notes/compile.rst b/docs/source/notes/compile.rst
index 22b1ea20..40d570d8 100644
--- a/docs/source/notes/compile.rst
+++ b/docs/source/notes/compile.rst
@@ -2,13 +2,10 @@ Install and Citations
=====================
-Install from Source
--------------------
+Installation
+------------
- * Install PyTorch by following the `PyTorch instructions `_.
- This package relies on PyTorch master branch (higher than stable released v0.4.0), please follow
- `the instruction `_ to install
- PyTorch from source.
+ * Install PyTorch 1.0 by following the `PyTorch instructions `_.
* PIP Install::
diff --git a/encoding/__init__.py b/encoding/__init__.py
index 5dc68d55..2c33ce83 100644
--- a/encoding/__init__.py
+++ b/encoding/__init__.py
@@ -10,4 +10,4 @@
"""An optimized PyTorch package with CUDA backend."""
from .version import __version__
-from . import nn, functions, dilated, parallel, utils, models, datasets
+from . import nn, functions, parallel, utils, models, datasets, transforms
diff --git a/encoding/datasets/__init__.py b/encoding/datasets/__init__.py
index cdab5d76..ed9be3cf 100644
--- a/encoding/datasets/__init__.py
+++ b/encoding/datasets/__init__.py
@@ -1,3 +1,5 @@
+import warnings
+from torchvision.datasets import *
from .base import *
from .coco import COCOSegmentation
from .ade20k import ADE20KSegmentation
@@ -5,6 +7,10 @@
from .pascal_aug import VOCAugSegmentation
from .pcontext import ContextSegmentation
from .cityscapes import CitySegmentation
+from .imagenet import ImageNetDataset
+from .minc import MINCDataset
+
+from ..utils import EncodingDeprecationWarning
datasets = {
'coco': COCOSegmentation,
@@ -13,7 +19,40 @@
'pascal_aug': VOCAugSegmentation,
'pcontext': ContextSegmentation,
'citys': CitySegmentation,
+ 'imagenet': ImageNetDataset,
+ 'minc': MINCDataset,
+ 'cifar10': CIFAR10,
+}
+
+acronyms = {
+ 'coco': 'coco',
+ 'pascal_voc': 'voc',
+ 'pascal_aug': 'voc',
+    'pcontext': 'pcontext',
+ 'ade20k': 'ade',
+ 'citys': 'citys',
+ 'minc': 'minc',
+ 'cifar10': 'cifar10',
}
-def get_segmentation_dataset(name, **kwargs):
+def get_dataset(name, **kwargs):
return datasets[name.lower()](**kwargs)
+
+def _make_deprecate(meth, old_name):
+ new_name = meth.__name__
+
+ def deprecated_init(*args, **kwargs):
+        warnings.warn("encoding.datasets.{} is now deprecated in favor of encoding.datasets.{}."
+ .format(old_name, new_name), EncodingDeprecationWarning)
+ return meth(*args, **kwargs)
+
+ deprecated_init.__doc__ = r"""
+ {old_name}(...)
+ .. warning::
+        This function is now deprecated in favor of :func:`encoding.datasets.{new_name}`.
+        See :func:`~encoding.datasets.{new_name}` for details.""".format(
+ old_name=old_name, new_name=new_name)
+ deprecated_init.__name__ = old_name
+ return deprecated_init
+
+get_segmentation_dataset = _make_deprecate(get_dataset, 'get_segmentation_dataset')
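
A minimal usage sketch for the new ``get_dataset`` entry point and the deprecation shim above; the keyword arguments are assumptions based on the dataset constructors in this patch, and the data is expected under ``~/.encoding/data``::

    import torchvision.transforms as transforms
    import encoding

    transform = transforms.Compose([transforms.ToTensor()])
    trainset = encoding.datasets.get_dataset('minc', split='train', transform=transform)

    # the old name still works, but now emits an EncodingDeprecationWarning
    valset = encoding.datasets.get_segmentation_dataset('ade20k', split='val')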
diff --git a/encoding/datasets/ade20k.py b/encoding/datasets/ade20k.py
index 4ad1f853..56b172d1 100644
--- a/encoding/datasets/ade20k.py
+++ b/encoding/datasets/ade20k.py
@@ -57,6 +57,39 @@ def __getitem__(self, index):
mask = self.target_transform(mask)
return img, mask
+ def _sync_transform(self, img, mask):
+ # random mirror
+ if random.random() < 0.5:
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
+ mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
+ crop_size = self.crop_size
+ w, h = img.size
+ long_size = random.randint(int(self.base_size*0.5), int(self.base_size*2.5))
+ if h > w:
+ oh = long_size
+ ow = int(1.0 * w * long_size / h + 0.5)
+ short_size = ow
+ else:
+ ow = long_size
+ oh = int(1.0 * h * long_size / w + 0.5)
+ short_size = oh
+ img = img.resize((ow, oh), Image.BILINEAR)
+ mask = mask.resize((ow, oh), Image.NEAREST)
+ # pad crop
+ if short_size < crop_size:
+ padh = crop_size - oh if oh < crop_size else 0
+ padw = crop_size - ow if ow < crop_size else 0
+ img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
+ mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
+ # random crop crop_size
+ w, h = img.size
+ x1 = random.randint(0, w - crop_size)
+ y1 = random.randint(0, h - crop_size)
+ img = img.crop((x1, y1, x1+crop_size, y1+crop_size))
+ mask = mask.crop((x1, y1, x1+crop_size, y1+crop_size))
+ # final transform
+ return img, self._mask_transform(mask)
+
def _mask_transform(self, mask):
target = np.array(mask).astype('int64') - 1
return torch.from_numpy(target)
diff --git a/encoding/datasets/base.py b/encoding/datasets/base.py
index d2d476f9..52b38fd5 100644
--- a/encoding/datasets/base.py
+++ b/encoding/datasets/base.py
@@ -67,15 +67,16 @@ def _sync_transform(self, img, mask):
img = img.transpose(Image.FLIP_LEFT_RIGHT)
mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
crop_size = self.crop_size
- # random scale (short edge from 480 to 720)
- short_size = random.randint(int(self.base_size*0.5), int(self.base_size*2.0))
w, h = img.size
+ long_size = random.randint(int(self.base_size*0.5), int(self.base_size*2.0))
if h > w:
- ow = short_size
- oh = int(1.0 * h * ow / w)
+ oh = long_size
+ ow = int(1.0 * w * long_size / h + 0.5)
+ short_size = ow
else:
- oh = short_size
- ow = int(1.0 * w * oh / h)
+ ow = long_size
+ oh = int(1.0 * h * long_size / w + 0.5)
+ short_size = oh
img = img.resize((ow, oh), Image.BILINEAR)
mask = mask.resize((ow, oh), Image.NEAREST)
# pad crop
@@ -90,10 +91,6 @@ def _sync_transform(self, img, mask):
y1 = random.randint(0, h - crop_size)
img = img.crop((x1, y1, x1+crop_size, y1+crop_size))
mask = mask.crop((x1, y1, x1+crop_size, y1+crop_size))
- # gaussian blur as in PSP
- if random.random() < 0.5:
- img = img.filter(ImageFilter.GaussianBlur(
- radius=random.random()))
# final transform
return img, self._mask_transform(mask)
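
The change above draws the random size along the long edge instead of the short edge; a small standalone sketch of the same arithmetic (the example numbers are illustrative)::

    import random

    def random_scale_size(w, h, base_size=520):
        # draw the target long edge, then scale the other side to keep the aspect ratio
        long_size = random.randint(int(base_size * 0.5), int(base_size * 2.0))
        if h > w:
            oh, ow = long_size, int(1.0 * w * long_size / h + 0.5)
            short_size = ow
        else:
            ow, oh = long_size, int(1.0 * h * long_size / w + 0.5)
            short_size = oh
        return ow, oh, short_size

    # e.g. a 640x480 image with long_size drawn as 780 is resized to 780x585,
    # and short_size (585) decides whether padding up to crop_size is needed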
diff --git a/encoding/datasets/cityscapes.py b/encoding/datasets/cityscapes.py
index c5eeaaa2..8e3b2842 100644
--- a/encoding/datasets/cityscapes.py
+++ b/encoding/datasets/cityscapes.py
@@ -87,46 +87,6 @@ def __getitem__(self, index):
mask = self.target_transform(mask)
return img, mask
- def _sync_transform(self, img, mask):
- # random mirror
- if random.random() < 0.5:
- img = img.transpose(Image.FLIP_LEFT_RIGHT)
- mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
- crop_size = self.crop_size
- # random scale (short edge from 480 to 720)
- short_size = random.randint(int(self.base_size*0.5), int(self.base_size*2.0))
- w, h = img.size
- if h > w:
- ow = short_size
- oh = int(1.0 * h * ow / w)
- else:
- oh = short_size
- ow = int(1.0 * w * oh / h)
- img = img.resize((ow, oh), Image.BILINEAR)
- mask = mask.resize((ow, oh), Image.NEAREST)
- # random rotate -10~10, mask using NN rotate
- deg = random.uniform(-10, 10)
- img = img.rotate(deg, resample=Image.BILINEAR)
- mask = mask.rotate(deg, resample=Image.NEAREST)
- # pad crop
- if short_size < crop_size:
- padh = crop_size - oh if oh < crop_size else 0
- padw = crop_size - ow if ow < crop_size else 0
- img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
- mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
- # random crop crop_size
- w, h = img.size
- x1 = random.randint(0, w - crop_size)
- y1 = random.randint(0, h - crop_size)
- img = img.crop((x1, y1, x1+crop_size, y1+crop_size))
- mask = mask.crop((x1, y1, x1+crop_size, y1+crop_size))
- # gaussian blur as in PSP
- if random.random() < 0.5:
- img = img.filter(ImageFilter.GaussianBlur(
- radius=random.random()))
- # final transform
- return img, self._mask_transform(mask)
-
def _mask_transform(self, mask):
#target = np.array(mask).astype('int32') - 1
target = self._class_to_index(np.array(mask).astype('int32'))
diff --git a/encoding/datasets/coco.py b/encoding/datasets/coco.py
index 6bd39194..0cce3564 100644
--- a/encoding/datasets/coco.py
+++ b/encoding/datasets/coco.py
@@ -23,6 +23,7 @@ def __init__(self, root=os.path.expanduser('~/.encoding/data'), split='train',
self.root = os.path.join(root, 'train2017')
else:
print('val set')
+ assert split == 'val'
ann_file = os.path.join(root, 'annotations/instances_val2017.json')
ids_file = os.path.join(root, 'annotations/val_ids.pth')
self.root = os.path.join(root, 'val2017')
@@ -99,6 +100,7 @@ def _preprocess(self, ids, ids_file):
print('Found number of qualified images: ', len(new_ids))
torch.save(new_ids, ids_file)
return new_ids
+
"""
NUM_CHANNEL = 91
[] background
@@ -123,4 +125,3 @@ def _preprocess(self, ids, ids_file):
[7] train
[72] tv
"""
-
diff --git a/encoding/datasets/imagenet.py b/encoding/datasets/imagenet.py
new file mode 100644
index 00000000..78b375f3
--- /dev/null
+++ b/encoding/datasets/imagenet.py
@@ -0,0 +1,21 @@
+##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+## Created by: Hang Zhang
+## Email: zhanghang0704@gmail.com
+## Copyright (c) 2018
+##
+## This source code is licensed under the MIT-style license found in the
+## LICENSE file in the root directory of this source tree
+##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+
+import os
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+
+class ImageNetDataset(datasets.ImageFolder):
+ BASE_DIR = "ILSVRC2012"
+ def __init__(self, root=os.path.expanduser('~/.encoding/data'), transform=None,
+ target_transform=None, train=True, **kwargs):
+        split = 'train' if train else 'val'
+ root = os.path.join(root, self.BASE_DIR, split)
+ super(ImageNetDataset, self).__init__(
+ root, transform, target_transform)
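
A minimal usage sketch for the new ``ImageNetDataset``; it expects the data under ``~/.encoding/data/ILSVRC2012/{train,val}`` and the transform below is only illustrative::

    import torchvision.transforms as transforms
    from encoding.datasets.imagenet import ImageNetDataset

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    valset = ImageNetDataset(transform=transform, train=False)
    print(len(valset), valset.classes[:5])  # attributes inherited from ImageFolder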
diff --git a/encoding/datasets/minc.py b/encoding/datasets/minc.py
new file mode 100644
index 00000000..f64d1a4a
--- /dev/null
+++ b/encoding/datasets/minc.py
@@ -0,0 +1,63 @@
+##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+## Created by: Hang Zhang
+## ECE Department, Rutgers University
+## Email: zhang.hang@rutgers.edu
+## Copyright (c) 2017
+##
+## This source code is licensed under the MIT-style license found in the
+## LICENSE file in the root directory of this source tree
+##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+
+import os
+from PIL import Image
+
+import torch
+import torch.utils.data as data
+
+class MINCDataset(data.Dataset):
+ NUM_CLASS = 23
+ def __init__(self, root=os.path.expanduser('~/.encoding/data/minc-2500/'),
+ split='train', transform=None):
+ self.transform = transform
+ classes, class_to_idx = find_classes(root + '/images')
+ if split=='train':
+ filename = os.path.join(root, 'labels/train1.txt')
+ else:
+ filename = os.path.join(root, 'labels/test1.txt')
+
+ self.images, self.labels = make_dataset(filename, root,
+ class_to_idx)
+ assert (len(self.images) == len(self.labels))
+
+ def __getitem__(self, index):
+ _img = Image.open(self.images[index]).convert('RGB')
+ _label = self.labels[index]
+ if self.transform is not None:
+ _img = self.transform(_img)
+
+ return _img, _label
+
+ def __len__(self):
+ return len(self.images)
+
+def find_classes(dir):
+ classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
+ classes.sort()
+ class_to_idx = {classes[i]: i for i in range(len(classes))}
+ return classes, class_to_idx
+
+
+def make_dataset(filename, datadir, class_to_idx):
+ images = []
+ labels = []
+ with open(os.path.join(filename), "r") as lines:
+ for line in lines:
+ _image = os.path.join(datadir, line.rstrip('\n'))
+ _dirname = os.path.split(os.path.dirname(_image))[1]
+ assert os.path.isfile(_image)
+ label = class_to_idx[_dirname]
+ images.append(_image)
+ labels.append(label)
+
+ return images, labels
+
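A minimal sketch of wiring ``MINCDataset`` into a ``DataLoader``; the paths follow the defaults above and the transform is only illustrative::

    import torch.utils.data as data
    import torchvision.transforms as transforms
    from encoding.datasets.minc import MINCDataset

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    trainset = MINCDataset(split='train', transform=transform)
    loader = data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4)
    images, labels = next(iter(loader))  # images: (64, 3, 224, 224)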
diff --git a/encoding/dilated/__init__.py b/encoding/dilated/__init__.py
deleted file mode 100644
index ed888108..00000000
--- a/encoding/dilated/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-"""Dilated ResNet and DenseNet"""
-from .resnet import *
diff --git a/encoding/functions/syncbn.py b/encoding/functions/syncbn.py
index cf7dd615..e989f4a1 100644
--- a/encoding/functions/syncbn.py
+++ b/encoding/functions/syncbn.py
@@ -9,71 +9,291 @@
"""Synchronized Cross-GPU Batch Normalization functions"""
import torch
+import torch.cuda.comm as comm
from torch.autograd import Variable, Function
+from torch.autograd.function import once_differentiable
from .. import lib
-__all__ = ['sum_square', 'batchnormtrain']
-
-def sum_square(input):
- r"""Calculate sum of elements and sum of squares for Batch Normalization"""
- return _sum_square.apply(input)
+__all__ = ['moments', 'syncbatchnorm', 'inp_syncbatchnorm']
+class moments(Function):
+ @staticmethod
+    def forward(ctx, x):
+        ctx.save_for_backward(x)
+        if x.is_cuda:
+ ex, ex2 = lib.gpu.expectation_forward(x)
+ else:
+            raise NotImplementedError
+ return ex, ex2
-class _sum_square(Function):
@staticmethod
- def forward(ctx, input):
- ctx.save_for_backward(input)
- if input.is_cuda:
- xsum, xsqusum = lib.gpu.sumsquare_forward(input)
+    def backward(ctx, dex, dex2):
+        x, = ctx.saved_tensors
+        if x.is_cuda:
+            dx = lib.gpu.expectation_backward(x, dex, dex2)
else:
- xsum, xsqusum = lib.cpu.sumsquare_forward(input)
- return xsum, xsqusum
+            raise NotImplementedError
+ return dx
+
+class syncbatchnorm_(Function):
+ @classmethod
+ def forward(cls, ctx, x, gamma, beta, running_mean, running_var,
+ extra, sync=True, training=True, momentum=0.1, eps=1e-05,
+ activation="none", slope=0.01):
+ # save context
+ cls._parse_extra(ctx, extra)
+ ctx.sync = sync
+ ctx.training = training
+ ctx.momentum = momentum
+ ctx.eps = eps
+ ctx.activation = activation
+ ctx.slope = slope
+ assert activation == 'none'
+
+        # contiguous inputs
+ x = x.contiguous()
+ gamma = gamma.contiguous()
+ beta = beta.contiguous()
+
+ if ctx.training:
+ if x.is_cuda:
+ _ex, _exs = lib.gpu.expectation_forward(x)
+ else:
+                raise NotImplementedError
+
+ if ctx.sync:
+ if ctx.is_master:
+ _ex, _exs = [_ex.unsqueeze(0)], [_exs.unsqueeze(0)]
+ for _ in range(ctx.master_queue.maxsize):
+ _ex_w, _exs_w = ctx.master_queue.get()
+ ctx.master_queue.task_done()
+ _ex.append(_ex_w.unsqueeze(0))
+ _exs.append(_exs_w.unsqueeze(0))
+
+ _ex = comm.gather(_ex).mean(0)
+ _exs = comm.gather(_exs).mean(0)
+
+ tensors = comm.broadcast_coalesced((_ex, _exs), [_ex.get_device()] + ctx.worker_ids)
+ for ts, queue in zip(tensors[1:], ctx.worker_queues):
+ queue.put(ts)
+ else:
+ ctx.master_queue.put((_ex, _exs))
+ _ex, _exs = ctx.worker_queue.get()
+ ctx.worker_queue.task_done()
+
+ # Update running stats
+ _var = _exs - _ex ** 2
+ running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * _ex)
+ running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * _var)
+
+ # Mark in-place modified tensors
+ ctx.mark_dirty(running_mean, running_var)
+ else:
+ _ex, _var = running_mean.contiguous(), running_var.contiguous()
+ _exs = _var + _ex ** 2
+
+ # BN forward + activation
+ if x.is_cuda:
+ y = lib.gpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
+ else:
+ y = lib.cpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
+
+ # Output
+ ctx.save_for_backward(x, _ex, _exs, gamma, beta)
+ return y
@staticmethod
- def backward(ctx, gradSum, gradSquare):
- input, = ctx.saved_variables
- if input.is_cuda:
- gradInput = lib.gpu.sumsquare_backward(input, gradSum, gradSquare)
+ @once_differentiable
+ def backward(ctx, dz):
+ x, _ex, _exs, gamma, beta = ctx.saved_tensors
+ dz = dz.contiguous()
+
+ # BN backward
+ if dz.is_cuda:
+ dx, _dex, _dexs, dgamma, dbeta = \
+ lib.gpu.batchnorm_backward(dz, x, _ex, _exs, gamma, beta, ctx.eps)
else:
raise NotImplemented
- return gradInput
+ if ctx.training:
+ if ctx.sync:
+ if ctx.is_master:
+ _dex, _dexs = [_dex.unsqueeze(0)], [_dexs.unsqueeze(0)]
+ for _ in range(ctx.master_queue.maxsize):
+ _dex_w, _dexs_w = ctx.master_queue.get()
+ ctx.master_queue.task_done()
+ _dex.append(_dex_w.unsqueeze(0))
+ _dexs.append(_dexs_w.unsqueeze(0))
+
+ _dex = comm.gather(_dex).mean(0)
+ _dexs = comm.gather(_dexs).mean(0)
+
+ tensors = comm.broadcast_coalesced((_dex, _dexs), [_dex.get_device()] + ctx.worker_ids)
+ for ts, queue in zip(tensors[1:], ctx.worker_queues):
+ queue.put(ts)
+ else:
+ ctx.master_queue.put((_dex, _dexs))
+ _dex, _dexs = ctx.worker_queue.get()
+ ctx.worker_queue.task_done()
+
+ if x.is_cuda:
+ dx_ = lib.gpu.expectation_backward(x, _dex, _dexs)
+ else:
+                raise NotImplementedError
+ dx = dx + dx_
+
+ return dx, dgamma, dbeta, None, None, None, None, None, None, None, None, None
-class _batchnormtrain(Function):
@staticmethod
- def forward(ctx, input, mean, std, gamma, beta):
- ctx.save_for_backward(input, mean, std, gamma, beta)
- if input.is_cuda:
- output = lib.gpu.batchnorm_forward(input, mean, std, gamma, beta)
+ def _parse_extra(ctx, extra):
+ ctx.is_master = extra["is_master"]
+ if ctx.is_master:
+ ctx.master_queue = extra["master_queue"]
+ ctx.worker_queues = extra["worker_queues"]
+ ctx.worker_ids = extra["worker_ids"]
+ else:
+ ctx.master_queue = extra["master_queue"]
+ ctx.worker_queue = extra["worker_queue"]
+
+def _act_forward(ctx, x):
+ if ctx.activation.lower() == "leaky_relu":
+ if x.is_cuda:
+ lib.gpu.leaky_relu_forward(x, ctx.slope)
+ else:
+            raise NotImplementedError
+    else:
+        assert ctx.activation == 'none'
+
+def _act_backward(ctx, x, dx):
+ if ctx.activation.lower() == "leaky_relu":
+ if x.is_cuda:
+ lib.gpu.leaky_relu_backward(x, dx, ctx.slope)
else:
- output = lib.cpu.batchnorm_forward(input, mean, std, gamma, beta)
- return output
+            raise NotImplementedError
+    else:
+        assert ctx.activation == 'none'
+
+class inp_syncbatchnorm_(Function):
+ @classmethod
+ def forward(cls, ctx, x, gamma, beta, running_mean, running_var,
+ extra, sync=True, training=True, momentum=0.1, eps=1e-05,
+ activation="none", slope=0.01):
+ # save context
+ cls._parse_extra(ctx, extra)
+ ctx.sync = sync
+ ctx.training = training
+ ctx.momentum = momentum
+ ctx.eps = eps
+ ctx.activation = activation
+ ctx.slope = slope
+
+        # contiguous inputs
+ x = x.contiguous()
+ gamma = gamma.contiguous()
+ beta = beta.contiguous()
+
+ if ctx.training:
+ if x.is_cuda:
+ _ex, _exs = lib.gpu.expectation_forward(x)
+ else:
+                raise NotImplementedError
+
+ if ctx.sync:
+ if ctx.is_master:
+ _ex, _exs = [_ex.unsqueeze(0)], [_exs.unsqueeze(0)]
+ for _ in range(ctx.master_queue.maxsize):
+ _ex_w, _exs_w = ctx.master_queue.get()
+ ctx.master_queue.task_done()
+ _ex.append(_ex_w.unsqueeze(0))
+ _exs.append(_exs_w.unsqueeze(0))
+
+ _ex = comm.gather(_ex).mean(0)
+ _exs = comm.gather(_exs).mean(0)
+
+ tensors = comm.broadcast_coalesced((_ex, _exs), [_ex.get_device()] + ctx.worker_ids)
+ for ts, queue in zip(tensors[1:], ctx.worker_queues):
+ queue.put(ts)
+ else:
+ ctx.master_queue.put((_ex, _exs))
+ _ex, _exs = ctx.worker_queue.get()
+ ctx.worker_queue.task_done()
+
+ # Update running stats
+ _var = _exs - _ex ** 2
+ running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * _ex)
+ running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * _var)
+
+ # Mark in-place modified tensors
+ ctx.mark_dirty(x, running_mean, running_var)
+ else:
+ _ex, _var = running_mean.contiguous(), running_var.contiguous()
+ _exs = _var + _ex ** 2
+ ctx.mark_dirty(x)
+
+ # BN forward + activation
+ if x.is_cuda:
+ lib.gpu.batchnorm_inp_forward(x, _ex, _exs, gamma, beta, ctx.eps)
+ else:
+            raise NotImplementedError
+
+ _act_forward(ctx, x)
+
+ # Output
+ ctx.save_for_backward(x, _ex, _exs, gamma, beta)
+ return x
@staticmethod
- def backward(ctx, gradOutput):
- input, mean, std, gamma, beta = ctx.saved_variables
- if gradOutput.is_cuda:
- gradInput, gradMean, gradStd, gradGamma, gradBeta = \
- lib.gpu.batchnorm_backward(gradOutput, input, mean,
- std, gamma, beta, True)
+ @once_differentiable
+ def backward(ctx, dz):
+ z, _ex, _exs, gamma, beta = ctx.saved_tensors
+ dz = dz.contiguous()
+
+ # Undo activation
+ _act_backward(ctx, z, dz)
+
+ # BN backward
+ if dz.is_cuda:
+ dx, _dex, _dexs, dgamma, dbeta = \
+ lib.gpu.batchnorm_inp_backward(dz, z, _ex, _exs, gamma, beta, ctx.eps)
else:
raise NotImplemented
- return gradInput, gradMean, gradStd, gradGamma, gradBeta
+ if ctx.training:
+ if ctx.sync:
+ if ctx.is_master:
+ _dex, _dexs = [_dex.unsqueeze(0)], [_dexs.unsqueeze(0)]
+ for _ in range(ctx.master_queue.maxsize):
+ _dex_w, _dexs_w = ctx.master_queue.get()
+ ctx.master_queue.task_done()
+ _dex.append(_dex_w.unsqueeze(0))
+ _dexs.append(_dexs_w.unsqueeze(0))
-def batchnormtrain(input, mean, std, gamma, beta):
- r"""Applies Batch Normalization over a 3d input that is seen as a
- mini-batch.
+ _dex = comm.gather(_dex).mean(0)
+ _dexs = comm.gather(_dexs).mean(0)
- .. _encoding.batchnormtrain:
+ tensors = comm.broadcast_coalesced((_dex, _dexs), [_dex.get_device()] + ctx.worker_ids)
+ for ts, queue in zip(tensors[1:], ctx.worker_queues):
+ queue.put(ts)
+ else:
+ ctx.master_queue.put((_dex, _dexs))
+ _dex, _dexs = ctx.worker_queue.get()
+ ctx.worker_queue.task_done()
- .. math::
+ if z.is_cuda:
+ lib.gpu.expectation_inp_backward(dx, z, _dex, _dexs, _ex, _exs, gamma, beta, ctx.eps)
+ else:
+                raise NotImplementedError
- y = \frac{x - \mu[x]}{ \sqrt{var[x] + \epsilon}} * \gamma + \beta
+ return dx, dgamma, dbeta, None, None, None, None, None, None, None, None, None
- Shape:
- - Input: :math:`(N, C)` or :math:`(N, C, L)`
- - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
+ @staticmethod
+ def _parse_extra(ctx, extra):
+ ctx.is_master = extra["is_master"]
+ if ctx.is_master:
+ ctx.master_queue = extra["master_queue"]
+ ctx.worker_queues = extra["worker_queues"]
+ ctx.worker_ids = extra["worker_ids"]
+ else:
+ ctx.master_queue = extra["master_queue"]
+ ctx.worker_queue = extra["worker_queue"]
- """
- return _batchnormtrain.apply(input, mean, std, gamma, beta)
+syncbatchnorm = syncbatchnorm_.apply
+inp_syncbatchnorm = inp_syncbatchnorm_.apply
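
For reference, the normalization that ``syncbatchnorm`` computes from the gathered statistics is the usual one, with the variance recovered as E[x^2] - E[x]^2; across GPUs the per-device ``_ex``/``_exs`` are averaged through the master/worker queues before this step. A single-device sketch of the same math in plain PyTorch (input is assumed to be flattened to N x C x L)::

    import torch

    def reference_syncbn(x, gamma, beta, eps=1e-5):
        # x: (N, C, L); statistics are taken per channel over N and L
        ex = x.mean(dim=(0, 2))           # E[x]
        exs = (x * x).mean(dim=(0, 2))    # E[x^2]
        std = (exs - ex * ex + eps).sqrt()
        ex, std = ex.view(1, -1, 1), std.view(1, -1, 1)
        return gamma.view(1, -1, 1) * (x - ex) / std + beta.view(1, -1, 1)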
diff --git a/encoding/lib/__init__.py b/encoding/lib/__init__.py
index ff821e05..5675dfc6 100644
--- a/encoding/lib/__init__.py
+++ b/encoding/lib/__init__.py
@@ -17,9 +17,11 @@
if torch.cuda.is_available():
gpu = load('enclib_gpu', [
os.path.join(gpu_path, 'operator.cpp'),
+ os.path.join(gpu_path, 'activation_kernel.cu'),
os.path.join(gpu_path, 'encoding_kernel.cu'),
os.path.join(gpu_path, 'encodingv2_kernel.cu'),
os.path.join(gpu_path, 'syncbn_kernel.cu'),
os.path.join(gpu_path, 'roi_align_kernel.cu'),
os.path.join(gpu_path, 'nms_kernel.cu'),
- ], build_directory=gpu_path, verbose=False)
+ ], extra_cuda_cflags=["--expt-extended-lambda"],
+ build_directory=gpu_path, verbose=False)
diff --git a/encoding/lib/cpu/__init__.py b/encoding/lib/cpu/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/encoding/lib/cpu/nms_cpu.cpp b/encoding/lib/cpu/nms_cpu.cpp
index 82f1c7b9..d078f30e 100644
--- a/encoding/lib/cpu/nms_cpu.cpp
+++ b/encoding/lib/cpu/nms_cpu.cpp
@@ -1,4 +1,4 @@
-#include
+#include
#include
#include
diff --git a/encoding/lib/cpu/roi_align.cpp b/encoding/lib/cpu/roi_align.cpp
deleted file mode 100644
index bfbbafff..00000000
--- a/encoding/lib/cpu/roi_align.cpp
+++ /dev/null
@@ -1,28 +0,0 @@
-#include
-// CPU declarations
-
-at::Tensor ROIAlignForwardCPU(
- const at::Tensor& input,
- const at::Tensor& bottom_rois,
- int64_t pooled_height,
- int64_t pooled_width,
- double spatial_scale,
- int64_t sampling_ratio);
-
-at::Tensor ROIAlignBackwardCPU(
- const at::Tensor& bottom_rois,
- const at::Tensor& grad_output, // gradient of the output of the layer
- int64_t b_size,
- int64_t channels,
- int64_t height,
- int64_t width,
- int64_t pooled_height,
- int64_t pooled_width,
- double spatial_scale,
- int64_t sampling_ratio);
-
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def("roi_align_forward", &ROIAlignForwardCPU, "ROI Align forward (CPU)");
- m.def("roi_align_backward", &ROIAlignBackwardCPU, "ROI Align backward (CPU)");
-}
diff --git a/encoding/lib/cpu/roi_align_cpu.cpp b/encoding/lib/cpu/roi_align_cpu.cpp
index 4472bc59..52a4295b 100644
--- a/encoding/lib/cpu/roi_align_cpu.cpp
+++ b/encoding/lib/cpu/roi_align_cpu.cpp
@@ -1,4 +1,4 @@
-#include
+#include
#include
//#include
diff --git a/encoding/lib/cpu/syncbn_cpu.cpp b/encoding/lib/cpu/syncbn_cpu.cpp
index 64cf5fbe..10e4dea2 100644
--- a/encoding/lib/cpu/syncbn_cpu.cpp
+++ b/encoding/lib/cpu/syncbn_cpu.cpp
@@ -1,4 +1,4 @@
-#include
+#include
#include
#include
diff --git a/encoding/lib/gpu/__init__.py b/encoding/lib/gpu/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/encoding/lib/gpu/activation_kernel.cu b/encoding/lib/gpu/activation_kernel.cu
new file mode 100644
index 00000000..d58118d7
--- /dev/null
+++ b/encoding/lib/gpu/activation_kernel.cu
@@ -0,0 +1,45 @@
+#include
+#include
+#include
+// #include
+
+#include
+
+#include
+#include
+
+
+namespace {
+
+template <typename T>
+inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
+ // Create thrust pointers
+  thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
+  thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
+
+ thrust::transform_if(th_dz, th_dz + count, th_z, th_dz,
+ [slope] __device__ (const T& dz) { return dz * slope; },
+ [] __device__ (const T& z) { return z < 0; });
+ thrust::transform_if(th_z, th_z + count, th_z,
+ [slope] __device__ (const T& z) { return z / slope; },
+ [] __device__ (const T& z) { return z < 0; });
+}
+
+}
+
+void LeakyRelu_Forward_CUDA(at::Tensor z, float slope) {
+ at::leaky_relu_(z, slope);
+}
+
+void LeakyRelu_Backward_CUDA(at::Tensor z, at::Tensor dz, float slope) {
+ int64_t count = z.numel();
+
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "LeakyRelu_Backward_CUDA", ([&] {
+    leaky_relu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), slope, count);
+ }));
+ /*
+ // unstable after scaling
+ at::leaky_relu_(z, 1.0 / slope);
+ at::leaky_relu_backward(dz, z, slope);
+ */
+}
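
The in-place backward above recovers the pre-activation values by dividing the negative outputs by the slope while scaling their gradients by the slope. A small PyTorch sketch of the same idea (illustrative only, not the CUDA kernel)::

    import torch

    def leaky_relu_inplace_backward(z, dz, slope=0.01):
        # z holds the *output* of leaky_relu; dz is the incoming gradient
        neg = z < 0
        dz[neg] *= slope   # d(leaky_relu)/dx = slope where the input was negative
        z[neg] /= slope    # recover the original input from the stored output
        return z, dz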
diff --git a/encoding/lib/gpu/encoding_kernel.cu b/encoding/lib/gpu/encoding_kernel.cu
index bd40e151..a3e91c55 100644
--- a/encoding/lib/gpu/encoding_kernel.cu
+++ b/encoding/lib/gpu/encoding_kernel.cu
@@ -1,5 +1,5 @@
#include
-#include
+#include
#include
#include
diff --git a/encoding/lib/gpu/encodingv2_kernel.cu b/encoding/lib/gpu/encodingv2_kernel.cu
index 97330309..068c2bd5 100644
--- a/encoding/lib/gpu/encodingv2_kernel.cu
+++ b/encoding/lib/gpu/encodingv2_kernel.cu
@@ -1,5 +1,5 @@
#include
-#include
+#include
#include
#include
#include
diff --git a/encoding/lib/gpu/nms_kernel.cu b/encoding/lib/gpu/nms_kernel.cu
index 464d0a6e..9c350a7f 100644
--- a/encoding/lib/gpu/nms_kernel.cu
+++ b/encoding/lib/gpu/nms_kernel.cu
@@ -1,4 +1,4 @@
-#include
+#include
#include
#include "ATen/NativeFunctions.h"
#include
diff --git a/encoding/lib/gpu/operator.cpp b/encoding/lib/gpu/operator.cpp
index 3faae98d..5d21a16e 100644
--- a/encoding/lib/gpu/operator.cpp
+++ b/encoding/lib/gpu/operator.cpp
@@ -9,9 +9,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("scaled_l2_forward", &ScaledL2_Forward_CUDA, "ScaledL2 forward (CUDA)");
m.def("scaled_l2_backward", &ScaledL2_Backward_CUDA, "ScaledL2 backward (CUDA)");
m.def("batchnorm_forward", &BatchNorm_Forward_CUDA, "BatchNorm forward (CUDA)");
+  m.def("batchnorm_inp_forward", &BatchNorm_Forward_Inp_CUDA, "Inplace BatchNorm forward (CUDA)");
m.def("batchnorm_backward", &BatchNorm_Backward_CUDA, "BatchNorm backward (CUDA)");
- m.def("sumsquare_forward", &Sum_Square_Forward_CUDA, "SumSqu forward (CUDA)");
- m.def("sumsquare_backward", &Sum_Square_Backward_CUDA, "SumSqu backward (CUDA)");
+  m.def("batchnorm_inp_backward", &BatchNorm_Inp_Backward_CUDA, "Inplace BatchNorm backward (CUDA)");
+ m.def("expectation_forward", &Expectation_Forward_CUDA, "Expectation forward (CUDA)");
+ m.def("expectation_backward", &Expectation_Backward_CUDA, "Expectation backward (CUDA)");
+ m.def("expectation_inp_backward", &Expectation_Inp_Backward_CUDA,
+ "Inplace Expectation backward (CUDA)");
m.def("encoding_dist_forward", &Encoding_Dist_Forward_CUDA, "EncDist forward (CUDA)");
m.def("encoding_dist_backward", &Encoding_Dist_Backward_CUDA, "Assign backward (CUDA)");
m.def("encoding_dist_inference_forward", &Encoding_Dist_Inference_Forward_CUDA,
@@ -20,4 +24,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Assign Inference backward (CUDA)");
m.def("aggregatev2_forward", &AggregateV2_Forward_CUDA, "AggregateV2 forward (CUDA)");
m.def("aggregatev2_backward", &AggregateV2_Backward_CUDA, "AggregateV2 backward (CUDA)");
+  m.def("leaky_relu_forward", &LeakyRelu_Forward_CUDA, "Leaky ReLU forward (CUDA)");
+  m.def("leaky_relu_backward", &LeakyRelu_Backward_CUDA, "Leaky ReLU backward (CUDA)");
}
diff --git a/encoding/lib/gpu/operator.h b/encoding/lib/gpu/operator.h
index 67e2972f..64dbe1de 100644
--- a/encoding/lib/gpu/operator.h
+++ b/encoding/lib/gpu/operator.h
@@ -1,4 +1,4 @@
-#include
+#include
#include
at::Tensor ROIAlign_Forward_CUDA(
@@ -54,36 +54,65 @@ at::Tensor BatchNorm_Forward_CUDA(
const at::Tensor mean_,
const at::Tensor std_,
const at::Tensor gamma_,
- const at::Tensor beta_);
+ const at::Tensor beta_,
+ float eps);
+
+at::Tensor BatchNorm_Forward_Inp_CUDA(
+ const at::Tensor input_,
+ const at::Tensor ex_,
+ const at::Tensor exs_,
+ const at::Tensor gamma_,
+ const at::Tensor beta_,
+ float eps);
std::vector BatchNorm_Backward_CUDA(
const at::Tensor gradoutput_,
const at::Tensor input_,
- const at::Tensor mean_,
- const at::Tensor std_,
+ const at::Tensor ex_,
+ const at::Tensor exs_,
+ const at::Tensor gamma_,
+ const at::Tensor beta_,
+ float eps);
+
+std::vector BatchNorm_Inp_Backward_CUDA(
+ const at::Tensor gradoutput_,
+ const at::Tensor output_,
+ const at::Tensor ex_,
+ const at::Tensor exs_,
const at::Tensor gamma_,
- const at::Tensor beta_,
- bool train);
+ const at::Tensor beta_,
+ float eps);
-std::vector Sum_Square_Forward_CUDA(
+std::vector Expectation_Forward_CUDA(
const at::Tensor input_);
-at::Tensor Sum_Square_Backward_CUDA(
+at::Tensor Expectation_Backward_CUDA(
const at::Tensor input_,
- const at::Tensor gradSum_,
- const at::Tensor gradSquare_);
+ const at::Tensor gradEx_,
+ const at::Tensor gradExs_);
+
+at::Tensor Expectation_Inp_Backward_CUDA(
+ const at::Tensor gradInput_,
+ const at::Tensor output_,
+ const at::Tensor gradEx_,
+ const at::Tensor gradExs_,
+ const at::Tensor ex_,
+ const at::Tensor exs_,
+ const at::Tensor gamma_,
+ const at::Tensor beta_,
+ float eps);
at::Tensor Encoding_Dist_Inference_Forward_CUDA(
- const at::Tensor X_,
- const at::Tensor C_,
- const at::Tensor STD_);
+ const at::Tensor X_,
+ const at::Tensor C_,
+ const at::Tensor STD_);
std::vector Encoding_Dist_Inference_Backward_CUDA(
- const at::Tensor GKD_,
- const at::Tensor KD_,
- const at::Tensor X_,
- const at::Tensor C_,
- const at::Tensor STD_);
+ const at::Tensor GKD_,
+ const at::Tensor KD_,
+ const at::Tensor X_,
+ const at::Tensor C_,
+ const at::Tensor STD_);
std::vector Encoding_Dist_Forward_CUDA(
const at::Tensor X,
@@ -91,12 +120,12 @@ std::vector Encoding_Dist_Forward_CUDA(
double eps);
std::vector Encoding_Dist_Backward_CUDA(
- const at::Tensor GKD_,
- const at::Tensor GSTD_,
- const at::Tensor KD_,
- const at::Tensor X_,
- const at::Tensor C_,
- const at::Tensor STD_);
+ const at::Tensor GKD_,
+ const at::Tensor GSTD_,
+ const at::Tensor KD_,
+ const at::Tensor X_,
+ const at::Tensor C_,
+ const at::Tensor STD_);
at::Tensor AggregateV2_Forward_CUDA(
const at::Tensor A_,
@@ -111,3 +140,7 @@ std::vector AggregateV2_Backward_CUDA(
const at::Tensor X_,
const at::Tensor C_,
const at::Tensor STD_);
+
+void LeakyRelu_Forward_CUDA(at::Tensor z, float slope);
+
+void LeakyRelu_Backward_CUDA(at::Tensor z, at::Tensor dz, float slope);
diff --git a/encoding/lib/gpu/roi_align_kernel.cu b/encoding/lib/gpu/roi_align_kernel.cu
index c55ee841..3c033537 100644
--- a/encoding/lib/gpu/roi_align_kernel.cu
+++ b/encoding/lib/gpu/roi_align_kernel.cu
@@ -1,4 +1,4 @@
-#include
+#include
#include
#include
diff --git a/encoding/lib/gpu/setup.py b/encoding/lib/gpu/setup.py
index 924b9998..f0ac8169 100644
--- a/encoding/lib/gpu/setup.py
+++ b/encoding/lib/gpu/setup.py
@@ -6,6 +6,7 @@
ext_modules=[
CUDAExtension('enclib_gpu', [
'operator.cpp',
+ 'activation_kernel.cu',
'encoding_kernel.cu',
'encodingv2_kernel.cu',
'syncbn_kernel.cu',
diff --git a/encoding/lib/gpu/syncbn_kernel.cu b/encoding/lib/gpu/syncbn_kernel.cu
index 930bb953..ed509869 100644
--- a/encoding/lib/gpu/syncbn_kernel.cu
+++ b/encoding/lib/gpu/syncbn_kernel.cu
@@ -1,5 +1,5 @@
#include
-#include
+#include
#include
#include
@@ -11,14 +11,14 @@ namespace {
template
struct GradOp {
__device__ GradOp(Acctype m, const DeviceTensor3 i, const DeviceTensor3 g)
- : mean(m), input(i), gradOutput(g) {}
+ : beta(m), output(i), gradOutput(g) {}
__device__ __forceinline__ Float2 operator()(int batch, int plane, int n) {
DType g = gradOutput[batch][plane][n];
- DType c = ScalarConvert::to(input[batch][plane][n] - mean);
+ DType c = ScalarConvert::to(output[batch][plane][n] - beta);
return Float2(g, g * c);
}
- const Acctype mean;
- const DeviceTensor3 input;
+ const Acctype beta;
+ const DeviceTensor3 output;
const DeviceTensor3 gradOutput;
};
@@ -88,6 +88,72 @@ __global__ void BatchNorm_Forward_kernel (
}
}
+template
+__global__ void BatchNorm_Forward_Inp_kernel (
+ DeviceTensor input,
+ DeviceTensor mean,
+ DeviceTensor std,
+ DeviceTensor gamma,
+ DeviceTensor beta) {
+ int c = blockIdx.x;
+ /* main operation */
+ for (int b = 0; b < input.getSize(0); ++b) {
+ for (int x = threadIdx.x; x < input.getSize(2); x += blockDim.x) {
+ DType inp = input[b][c][x];
+ input[b][c][x] = gamma[c] * (inp - mean[c]) /
+ std[c] + beta[c];
+ }
+ }
+}
+
+template
+__global__ void BatchNorm_Backward_Inp_kernel (
+ DeviceTensor gradoutput,
+ DeviceTensor output,
+ DeviceTensor gradinput,
+ DeviceTensor gradgamma,
+ DeviceTensor gradbeta,
+ DeviceTensor mean,
+ DeviceTensor std,
+ DeviceTensor gamma,
+ DeviceTensor beta,
+ DeviceTensor gradEx,
+ DeviceTensor gradExs) {
+ /* declarations of the variables */
+ /* Get the index and channels */
+ int c = blockIdx.x;
+ /* main operation */
+ GradOp> g(beta[c], output, gradoutput);
+ Float2 res = reduce,
+ GradOp>,
+ DeviceTensor>(g, gradoutput, c);
+ DType gradOutputSum = res.v1;
+ DType dotP = res.v2;
+ DType invstd = DType(1.0) / std[c];
+ DType gradScale = invstd * gamma[c];
+ if (threadIdx.x == 0) {
+ gradEx[c] = - gradOutputSum * gradScale + mean[c] * invstd * invstd * dotP;
+ gradExs[c] = - 0.5 * invstd * invstd * dotP;
+ }
+ if (gradinput.numElements() > 0) {
+ for (int batch = 0; batch < gradoutput.getSize(0); ++batch) {
+ for (int x = threadIdx.x; x < gradoutput.getSize(2); x += blockDim.x) {
+ gradinput[batch][c][x] = gradoutput[batch][c][x] * gradScale;
+ }
+ }
+ }
+ if (gradgamma.numElements() > 0) {
+ if (threadIdx.x == 0) {
+ gradgamma[c] += dotP / gamma[c];
+ }
+ }
+ if (gradbeta.numElements() > 0) {
+ if (threadIdx.x == 0) {
+ gradbeta[c] += gradOutputSum;
+ }
+ }
+}
+
template
__global__ void BatchNorm_Backward_kernel (
DeviceTensor gradoutput,
@@ -99,9 +165,8 @@ __global__ void BatchNorm_Backward_kernel (
DeviceTensor std,
DeviceTensor gamma,
DeviceTensor beta,
- DeviceTensor gradMean,
- DeviceTensor gradStd,
- bool train) {
+ DeviceTensor gradEx,
+ DeviceTensor gradExs) {
/* declarations of the variables */
/* Get the index and channels */
int c = blockIdx.x;
@@ -114,9 +179,9 @@ __global__ void BatchNorm_Backward_kernel (
DType dotP = res.v2;
DType invstd = DType(1.0) / std[c];
DType gradScale = invstd * gamma[c];
- if (train && threadIdx.x == 0) {
- gradMean[c] = - gradOutputSum * gamma[c] * invstd;
- gradStd[c] = - dotP * gamma[c] * invstd * invstd;
+ if (threadIdx.x == 0) {
+ gradEx[c] = - gradOutputSum * gradScale + mean[c] * invstd * invstd * dotP * gradScale;
+ gradExs[c] = - 0.5 * invstd * invstd * dotP * gradScale;
}
if (gradinput.numElements() > 0) {
for (int batch = 0; batch < gradoutput.getSize(0); ++batch) {
@@ -139,10 +204,11 @@ __global__ void BatchNorm_Backward_kernel (
template
-__global__ void Sum_Square_Forward_kernel (
+__global__ void Expectation_Forward_kernel (
DeviceTensor input,
- DeviceTensor sum,
- DeviceTensor square) {
+ DeviceTensor ex,
+ DeviceTensor exs,
+ DType norm) {
int c = blockIdx.x;
/* main operation */
SumOp g(input);
@@ -151,37 +217,60 @@ __global__ void Sum_Square_Forward_kernel (
DType xsum = res.v1;
DType xsquare = res.v2;
if (threadIdx.x == 0) {
- sum[c] = xsum;
- square[c] = xsquare;
+ ex[c] = xsum * norm;
+ exs[c] = xsquare * norm;
}
}
template
-__global__ void Sum_Square_Backward_kernel (
+__global__ void Expectation_Backward_kernel (
DeviceTensor gradInput,
DeviceTensor input,
- DeviceTensor gradSum,
- DeviceTensor gradSquare) {
+ DeviceTensor gradEx,
+ DeviceTensor gradExs,
+ DType norm) {
+ int c = blockIdx.x;
+ /* main operation */
+ for (int batch = 0; batch < gradInput.getSize(0); ++batch) {
+ for (int x = threadIdx.x; x < gradInput.getSize(2); x += blockDim.x) {
+ gradInput[batch][c][x] = gradEx[c] * norm + 2 * gradExs[c] *
+ input[batch][c][x] * norm;
+ }
+ }
+}
+
+template
+__global__ void Expectation_Backward_Inp_kernel (
+ DeviceTensor gradInput,
+ DeviceTensor output,
+ DeviceTensor gradEx,
+ DeviceTensor gradExs,
+ DeviceTensor mean,
+ DeviceTensor std,
+ DeviceTensor gamma,
+ DeviceTensor beta,
+ DType norm) {
int c = blockIdx.x;
/* main operation */
for (int batch = 0; batch < gradInput.getSize(0); ++batch) {
- for (int x = threadIdx.x; x < gradInput.getSize(2); x += blockDim.x)
- {
- gradInput[batch][c][x] = gradSum[c] + 2 * gradSquare[c] *
- input[batch][c][x];
+ for (int x = threadIdx.x; x < gradInput.getSize(2); x += blockDim.x) {
+ gradInput[batch][c][x] += gradEx[c] * norm + 2 * gradExs[c] *
+ ((output[batch][c][x] - beta[c]) / gamma[c] * std[c] + mean[c]) * norm;
}
- }
+ }
}
-} // namespcae
+} // namespace
at::Tensor BatchNorm_Forward_CUDA(
const at::Tensor input_,
- const at::Tensor mean_,
- const at::Tensor std_,
+ const at::Tensor ex_,
+ const at::Tensor exs_,
const at::Tensor gamma_,
- const at::Tensor beta_) {
+ const at::Tensor beta_,
+ float eps) {
auto output_ = at::zeros_like(input_);
+ auto std_ = (exs_ - ex_ * ex_ + eps).sqrt();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(input_.size(1));
dim3 threads(getNumThreads(input_.size(2)));
@@ -189,85 +278,157 @@ at::Tensor BatchNorm_Forward_CUDA(
/* Device tensors */
DeviceTensor output = devicetensor(output_);
DeviceTensor input = devicetensor(input_);
- DeviceTensor mean = devicetensor(mean_);
+ DeviceTensor ex = devicetensor(ex_);
DeviceTensor std = devicetensor(std_);
DeviceTensor gamma = devicetensor(gamma_);
DeviceTensor beta = devicetensor(beta_);
/* kernel function */
BatchNorm_Forward_kernel<<>>(
- output, input, mean, std, gamma, beta);
+ output, input, ex, std, gamma, beta);
}));
AT_ASSERT(cudaGetLastError() == cudaSuccess);
return output_;
}
+at::Tensor BatchNorm_Forward_Inp_CUDA(
+ const at::Tensor input_,
+ const at::Tensor ex_,
+ const at::Tensor exs_,
+ const at::Tensor gamma_,
+ const at::Tensor beta_,
+ float eps) {
+ auto std_ = (exs_ - ex_ * ex_ + eps).sqrt();
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ dim3 blocks(input_.size(1));
+ dim3 threads(getNumThreads(input_.size(2)));
+ AT_DISPATCH_FLOATING_TYPES(input_.type(), "BatchNorm_Forward_CUDA", ([&] {
+ /* Device tensors */
+ DeviceTensor input = devicetensor(input_);
+ DeviceTensor ex = devicetensor(ex_);
+ DeviceTensor std = devicetensor(std_);
+ DeviceTensor gamma = devicetensor(gamma_);
+ DeviceTensor beta = devicetensor(beta_);
+ /* kernel function */
+ BatchNorm_Forward_Inp_kernel<<>>(
+ input, ex, std, gamma, beta);
+ }));
+ AT_ASSERT(cudaGetLastError() == cudaSuccess);
+ return input_;
+}
+
+
+std::vector BatchNorm_Inp_Backward_CUDA(
+ const at::Tensor gradoutput_,
+ const at::Tensor output_,
+ const at::Tensor ex_,
+ const at::Tensor exs_,
+ const at::Tensor gamma_,
+ const at::Tensor beta_,
+ float eps) {
+ /* outputs*/
+ auto std_ = (exs_ - ex_ * ex_ + eps).sqrt();
+ auto gradinput_ = at::zeros_like(output_);
+ auto gradgamma_ = at::zeros_like(gamma_);
+ auto gradbeta_ = at::zeros_like(beta_);
+ auto gradEx_ = at::zeros_like(ex_);
+ auto gradExs_ = at::zeros_like(std_);
+ /* cuda utils*/
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ dim3 blocks(output_.size(1));
+ dim3 threads(getNumThreads(output_.size(2)));
+ AT_DISPATCH_FLOATING_TYPES(output_.type(), "BatchNorm_Inp_Backward_CUDA", ([&] {
+ /* Device tensors */
+ DeviceTensor gradoutput = devicetensor(gradoutput_);
+ DeviceTensor output = devicetensor(output_);
+ DeviceTensor gradinput = devicetensor(gradinput_);
+ DeviceTensor gradgamma = devicetensor(gradgamma_);
+ DeviceTensor gradbeta = devicetensor(gradbeta_);
+ DeviceTensor ex = devicetensor(ex_);
+ DeviceTensor std = devicetensor(std_);
+ DeviceTensor gamma = devicetensor(gamma_);
+ DeviceTensor beta = devicetensor(beta_);
+ DeviceTensor gradEx = devicetensor(gradEx_);
+ DeviceTensor gradExs = devicetensor(gradExs_);
+ /* kernel function */
+ BatchNorm_Backward_Inp_kernel
+ <<>>(
+ gradoutput, output, gradinput, gradgamma, gradbeta, ex, std,
+ gamma, beta, gradEx, gradExs);
+ }));
+ AT_ASSERT(cudaGetLastError() == cudaSuccess);
+ return {gradinput_, gradEx_, gradExs_, gradgamma_, gradbeta_};
+}
+
+
std::vector BatchNorm_Backward_CUDA(
const at::Tensor gradoutput_,
const at::Tensor input_,
- const at::Tensor mean_,
- const at::Tensor std_,
+ const at::Tensor ex_,
+ const at::Tensor exs_,
const at::Tensor gamma_,
- const at::Tensor beta_,
- bool train) {
+ const at::Tensor beta_,
+ float eps) {
/* outputs*/
- at::Tensor gradinput_ = at::zeros_like(input_);
- at::Tensor gradgamma_ = at::zeros_like(gamma_);
- at::Tensor gradbeta_ = at::zeros_like(beta_);
- at::Tensor gradMean_ = at::zeros_like(mean_);
- at::Tensor gradStd_ = at::zeros_like(std_);
+ auto std_ = (exs_ - ex_ * ex_ + eps).sqrt();
+ auto gradinput_ = at::zeros_like(input_);
+ auto gradgamma_ = at::zeros_like(gamma_);
+ auto gradbeta_ = at::zeros_like(beta_);
+ auto gradEx_ = at::zeros_like(ex_);
+ auto gradExs_ = at::zeros_like(std_);
/* cuda utils*/
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(input_.size(1));
dim3 threads(getNumThreads(input_.size(2)));
- AT_DISPATCH_FLOATING_TYPES(input_.type(), "BatchNorm_Backward_CUDA", ([&] {
+ AT_DISPATCH_FLOATING_TYPES(input_.type(), "BatchNorm_Inp_Backward_CUDA", ([&] {
/* Device tensors */
DeviceTensor gradoutput = devicetensor(gradoutput_);
DeviceTensor input = devicetensor(input_);
DeviceTensor gradinput = devicetensor(gradinput_);
DeviceTensor gradgamma = devicetensor(gradgamma_);
DeviceTensor gradbeta = devicetensor(gradbeta_);
- DeviceTensor mean = devicetensor(mean_);
+ DeviceTensor ex = devicetensor(ex_);
DeviceTensor std = devicetensor(std_);
DeviceTensor gamma = devicetensor(gamma_);
DeviceTensor beta = devicetensor(beta_);
- DeviceTensor gradMean = devicetensor(gradMean_);
- DeviceTensor gradStd = devicetensor(gradStd_);
+ DeviceTensor gradEx = devicetensor(gradEx_);
+ DeviceTensor gradExs = devicetensor(gradExs_);
/* kernel function */
BatchNorm_Backward_kernel
<<>>(
- gradoutput, input, gradinput, gradgamma, gradbeta, mean, std,
- gamma, beta, gradMean, gradStd, train);
+ gradoutput, input, gradinput, gradgamma, gradbeta, ex, std,
+ gamma, beta, gradEx, gradExs);
}));
AT_ASSERT(cudaGetLastError() == cudaSuccess);
- return {gradinput_, gradMean_, gradStd_, gradgamma_, gradbeta_};
+ return {gradinput_, gradEx_, gradExs_, gradgamma_, gradbeta_};
}
-std::vector Sum_Square_Forward_CUDA(
+std::vector Expectation_Forward_CUDA(
const at::Tensor input_) {
/* outputs */
- at::Tensor sum_ = torch::zeros({input_.size(1)}, input_.options());
- at::Tensor square_ = torch::zeros({input_.size(1)}, input_.options());
+ auto ex_ = torch::zeros({input_.size(1)}, input_.options());
+ auto exs_ = torch::zeros({input_.size(1)}, input_.options());
/* cuda utils*/
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(input_.size(1));
dim3 threads(getNumThreads(input_.size(2)));
AT_DISPATCH_FLOATING_TYPES(input_.type(), "SumSquare_forward_CUDA", ([&] {
+ scalar_t norm = scalar_t(1) / (input_.size(0) * input_.size(2));
/* Device tensors */
DeviceTensor input = devicetensor(input_);
- DeviceTensor sum = devicetensor(sum_);
- DeviceTensor square = devicetensor(square_);
+ DeviceTensor ex = devicetensor(ex_);
+ DeviceTensor exs = devicetensor(exs_);
/* kernel function */
- Sum_Square_Forward_kernel
- <<>>(input, sum, square);
+ Expectation_Forward_kernel
+ <<>>(input, ex, exs, norm);
}));
AT_ASSERT(cudaGetLastError() == cudaSuccess);
- return {sum_, square_};
+ return {ex_, exs_};
}
-at::Tensor Sum_Square_Backward_CUDA(
+at::Tensor Expectation_Backward_CUDA(
const at::Tensor input_,
- const at::Tensor gradSum_,
- const at::Tensor gradSquare_) {
+ const at::Tensor gradEx_,
+ const at::Tensor gradExs_) {
/* outputs */
at::Tensor gradInput_ = at::zeros_like(input_);
/* cuda utils*/
@@ -275,14 +436,52 @@ at::Tensor Sum_Square_Backward_CUDA(
dim3 blocks(input_.size(1));
dim3 threads(getNumThreads(input_.size(2)));
AT_DISPATCH_FLOATING_TYPES(input_.type(), "SumSquare_Backward_CUDA", ([&] {
+ scalar_t norm = scalar_t(1) / (input_.size(0) * input_.size(2));
/* Device tensors */
DeviceTensor gradInput = devicetensor(gradInput_);
DeviceTensor input = devicetensor(input_);
- DeviceTensor gradSum = devicetensor(gradSum_);
- DeviceTensor gradSquare =devicetensor(gradSquare_);
+ DeviceTensor gradEx = devicetensor(gradEx_);
+ DeviceTensor gradExs =devicetensor(gradExs_);
+ /* kernel function */
+ Expectation_Backward_kernel
+ <<>>(gradInput, input, gradEx, gradExs, norm);
+ }));
+ AT_ASSERT(cudaGetLastError() == cudaSuccess);
+ return gradInput_;
+}
+
+at::Tensor Expectation_Inp_Backward_CUDA(
+ const at::Tensor gradInput_,
+ const at::Tensor output_,
+ const at::Tensor gradEx_,
+ const at::Tensor gradExs_,
+ const at::Tensor ex_,
+ const at::Tensor exs_,
+ const at::Tensor gamma_,
+ const at::Tensor beta_,
+ float eps) {
+ /* outputs */
+ //auto gradInput_ = at::zeros_like(output_);
+ auto std_ = (exs_ - ex_ * ex_ + eps).sqrt();
+ /* cuda utils*/
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ dim3 blocks(output_.size(1));
+ dim3 threads(getNumThreads(output_.size(2)));
+ AT_DISPATCH_FLOATING_TYPES(output_.type(), "SumSquare_Backward_CUDA", ([&] {
+ scalar_t norm = scalar_t(1) / (output_.size(0) * output_.size(2));
+ /* Device tensors */
+ DeviceTensor gradInput = devicetensor(gradInput_);
+ DeviceTensor input = devicetensor(output_);
+ DeviceTensor gradEx = devicetensor(gradEx_);
+ DeviceTensor gradExs =devicetensor(gradExs_);
+ DeviceTensor ex = devicetensor(ex_);
+ DeviceTensor std = devicetensor(std_);
+ DeviceTensor gamma = devicetensor(gamma_);
+ DeviceTensor beta = devicetensor(beta_);
/* kernel function */
- Sum_Square_Backward_kernel
- <<