Pruning #14

Open · Cydia2018 opened this issue Dec 11, 2021 · 9 comments

@Cydia2018

Could the author also release the code for the pruning part? I still have some questions about the pruning section of the paper and the current code. Thanks!

@fxmeng
Owner

fxmeng commented Dec 24, 2021

> Could the author also release the code for the pruning part? I still have some questions about the pruning section of the paper and the current code. Thanks!

I hadn't planned to release the pruning code, partly because it is rather messy, and partly because I felt my implementation is a bit simplistic; RMNet pruning should be able to do better, and I hoped others could build on the released models and improve on it. But since there is a need, I have cleaned up the code:
https://github.com/fxmeng/RMNet/blob/242f849c6e5e891646bbc90f89310268d183c310/train_pruning.py

@Cydia2018
Author

> Could the author also release the code for the pruning part? I still have some questions about the pruning section of the paper and the current code. Thanks!

> I hadn't planned to release the pruning code, partly because it is rather messy, and partly because I felt my implementation is a bit simplistic; RMNet pruning should be able to do better, and I hoped others could build on the released models and improve on it. But since there is a need, I have cleaned up the code: https://github.com/fxmeng/RMNet/blob/242f849c6e5e891646bbc90f89310268d183c310/train_pruning.py

Thank you for your work!

@Cydia2018
Author

Cydia2018 commented Dec 29, 2021

Hi, I ran the pruning training code and found that it does not converge, so I made some changes based on the idea of Network Slimming:

```python
# Methods on the RM model (torch and torch.nn as nn assumed imported).
def update_mask(self, sr, threshold):
    # Network Slimming-style L1 penalty: push the per-channel 1x1 mask
    # weights toward zero by adding sr * sign(w) to their gradients.
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            if m.kernel_size == (1, 1) and m.groups != 1:
                m.weight.grad.data.add_(sr * torch.sign(m.weight.data))
                # m1 = m.weight.data.abs() > threshold
                # m.weight.grad.data *= m1
                # m.weight.data *= m1

def prune(self, use_bn=True, threshold=0.1):
    # Drop channels whose BN scale falls below the threshold.
    features = []
    in_mask = torch.ones(3) > 0
    blocks = self.deploy()
    for i, m in enumerate(blocks):
        if isinstance(m, nn.BatchNorm2d):
            mask = m.weight.data.abs().reshape(-1) > threshold
            ...
```
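
For reference, a minimal sketch of where such an update fits in a standard training step (the `model`, `optimizer`, `criterion`, and `train_loader` objects and the sr/threshold values here are illustrative placeholders):

```python
# Hypothetical training step: the L1 subgradient from update_mask() must be
# applied after loss.backward() (so .grad exists) and before optimizer.step().
for inputs, targets in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    model.update_mask(sr=1e-4, threshold=2e-3)
    optimizer.step()
```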

Training ResNet-18 from scratch with sparsity (lr=0.1, sr=1e-4), the results on CIFAR-10 are as follows:

| Model | thresh | Params | FLOPs | Acc (%) |
| --- | --- | --- | --- | --- |
| original | – | 15.38M | 803.75M | 94.96 |
| pruned | 2e-3 | 4.06M | 397.52M | 94.81 |

@fxmeng
Owner

fxmeng commented Dec 29, 2021

If it does not converge, that means sr and threshold are set too large; I suggest adjusting those two values and trying again. I also encourage trying more pruning schemes; just note that the newly added channels also need to be pruned.

@fxmeng
Owner

fxmeng commented Jan 5, 2022

> Hi! How should the parameters be set for training / pruning / fine-tuning?

The sparsity factor is selected from 1e-4 to 1e-3, and the threshold is selected from 5e-4 to 5e-3.

@Serissa

Serissa commented Jan 5, 2022

Thanks! I have some questions about model training and pruning training. Are the following parameter settings correct? Thanks!
Train the model: python train_pruning.py --lr 0.1 --sr 1e-4 --threshold 2e-3 --finetune None --debn False --eval None
Test the model: python train_pruning.py --eval xxx/ckpt.pth
Pruning training: python train_pruning.py --lr 0.1 --sr 1e-4 --threshold 2e-3 --finetune xxx/ckpt.pth --debn False --eval None

@fxmeng
Owner

fxmeng commented Jan 5, 2022

> Thanks! I have some questions about model training and pruning training. Are the following parameter settings correct? Thanks!
> Train the model: python train_pruning.py --lr 0.1 --sr 1e-4 --threshold 2e-3 --finetune None --debn False --eval None
> Test the model: python train_pruning.py --eval xxx/ckpt.pth
> Pruning training: python train_pruning.py --lr 0.1 --sr 1e-4 --threshold 2e-3 --finetune xxx/ckpt.pth --debn False --eval None

Yes, that works. The range I tuned over was:
1e-3 >= sr >= 1e-4, 5e-3 >= threshold >= 5e-4.
Tune within that range and it should be fine. Also, when fine-tuning, reduce the learning rate to around 0.01.
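For example, a fine-tuning run with the reduced learning rate would presumably look like (checkpoint path is a placeholder):
python train_pruning.py --lr 0.01 --sr 1e-4 --threshold 2e-3 --finetune xxx/ckpt.pth --debn False --eval None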

@fxmeng
Owner

fxmeng commented Jan 19, 2022

> Hi, I converted the code below into RM form. Could you help me confirm whether it is correct? Thanks!

```python
import torch
import torch.nn as nn

class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, e=0.5):  # ch_in, ch_out, expansion
        super(Bottleneck, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = nn.Conv2d(c1, c_, 1, 1)
        self.bn1 = nn.BatchNorm2d(c_)
        self.relu1 = nn.ReLU(inplace=True)
        self.cv2 = nn.Conv2d(c_, c2, 3, 1, padding=1)  # padding=1 keeps the spatial size for the residual add
        self.bn2 = nn.BatchNorm2d(c2)  # c2 channels, matching cv2's output
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        return x + self.relu2(self.bn2(self.cv2(self.relu1(self.bn1(self.cv1(x))))))
```

> Converted to RM form:

```python
import torch
import torch.nn as nn

class Bottleneck(nn.Module):
    # Standard bottleneck, RM form: the residual is carried by extra
    # "identity" channels instead of a skip connection.
    def __init__(self, c1, c2, e=0.5):
        super(Bottleneck, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.in_planes1 = c1
        self.out_planes1 = c1 + c_
        self.in_planes2 = c1 + c_
        self.out_planes2 = c1 + c2
        self.out_planes = c2

        self.conv1 = nn.Conv2d(self.in_planes1, self.out_planes1 - c1, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.out_planes1 - c1)
        self.mask1 = nn.Conv2d(self.out_planes1 - c1, self.out_planes1 - c1, 1, groups=self.out_planes1 - c1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(self.in_planes2 - c1, self.out_planes2 - c1, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.out_planes2 - c1)
        self.mask2 = nn.Conv2d(self.out_planes2 - c1, self.out_planes2 - c1, 1, groups=self.out_planes2 - c1, bias=False)
        self.relu2 = nn.ReLU(inplace=False)

        self.mask_res = nn.Sequential(
            nn.Conv2d(self.in_planes1, self.in_planes1, 1, groups=self.in_planes1, bias=False),
            nn.ReLU(inplace=False))

        # Non-affine BNs that only track running statistics for deploy().
        self.running1 = nn.BatchNorm2d(self.in_planes1, affine=False)
        self.running2 = nn.BatchNorm2d(self.out_planes, affine=False)

        nn.init.ones_(self.mask1.weight)
        nn.init.ones_(self.mask2.weight)
        nn.init.ones_(self.mask_res[0].weight)

    def forward(self, x):
        self.running1(x)  # record input statistics for deploy()
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.mask1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.mask2(out)
        out = self.relu2(out)
        out += self.mask_res(x)
        self.running2(out)  # record output statistics for deploy()
        return out

    def deploy(self, merge_bn=False):
        # conv1: pack a dirac identity branch for the first in_planes1
        # channels together with the trained conv1 (1x1, stride 1).
        idconv1 = nn.Conv2d(self.in_planes1, self.out_planes1, kernel_size=1, stride=1, padding=0,
                            bias=False).eval()
        idbn1 = nn.BatchNorm2d(self.out_planes1).eval()
        # init dirac_ kernel weight, bias, mean, var into idconv1/idbn1
        nn.init.dirac_(idconv1.weight.data[:self.in_planes1])
        bn_var_sqrt = torch.sqrt(self.running1.running_var + self.running1.eps)
        idbn1.weight.data[:self.in_planes1] = bn_var_sqrt
        idbn1.bias.data[:self.in_planes1] = self.running1.running_mean
        idbn1.running_mean.data[:self.in_planes1] = self.running1.running_mean
        idbn1.running_var.data[:self.in_planes1] = self.running1.running_var
        # copy conv1/bn1 into the remaining channels
        idconv1.weight.data[self.in_planes1:] = self.conv1.weight.data
        idbn1.weight.data[self.in_planes1:] = self.bn1.weight.data
        idbn1.bias.data[self.in_planes1:] = self.bn1.bias.data
        idbn1.running_mean.data[self.in_planes1:] = self.bn1.running_mean
        idbn1.running_var.data[self.in_planes1:] = self.bn1.running_var
        # fold mask_res and mask1 into idbn1
        mask1 = nn.Conv2d(self.out_planes1, self.out_planes1, 1, groups=self.out_planes1, bias=False)
        mask1.weight.data[:self.in_planes1] = self.mask_res[0].weight.data * (self.mask_res[0].weight.data > 0)
        mask1.weight.data[self.in_planes1:] = self.mask1.weight.data
        idbn1.weight.data *= mask1.weight.data.reshape(-1)
        idbn1.bias.data *= mask1.weight.data.reshape(-1)

        # conv2: same trick with a 3x3 dirac identity (stride 1, padding 1)
        idconv2 = nn.Conv2d(self.in_planes2, self.out_planes2, kernel_size=3, stride=1, padding=1,
                            bias=False).eval()
        idbn2 = nn.BatchNorm2d(self.out_planes2).eval()
        # init dirac_ kernel weight, bias, mean, var into idconv2/idbn2
        nn.init.dirac_(idconv2.weight.data[:self.in_planes1])
        bn_var_sqrt = torch.sqrt(self.running1.running_var + self.running1.eps)
        idbn2.weight.data[:self.in_planes1] = bn_var_sqrt
        idbn2.bias.data[:self.in_planes1] = self.running1.running_mean
        idbn2.running_mean.data[:self.in_planes1] = self.running1.running_mean
        idbn2.running_var.data[:self.in_planes1] = self.running1.running_var
        # copy conv2/bn2 into the remaining channels
        idconv2.weight.data[self.in_planes1:, self.in_planes1:, :, :] = self.conv2.weight.data
        idbn2.weight.data[self.in_planes1:] = self.bn2.weight.data
        idbn2.bias.data[self.in_planes1:] = self.bn2.bias.data
        idbn2.running_mean.data[self.in_planes1:] = self.bn2.running_mean
        idbn2.running_var.data[self.in_planes1:] = self.bn2.running_var
        # fold mask_res and mask2 into idbn2
        mask2 = nn.Conv2d(self.out_planes2, self.out_planes2, 1, groups=self.out_planes2, bias=False)
        mask2.weight.data[:self.in_planes1] = self.mask_res[0].weight.data * (self.mask_res[0].weight.data > 0)
        mask2.weight.data[self.in_planes1:] = self.mask2.weight.data
        idbn2.weight.data *= mask2.weight.data.reshape(-1)
        idbn2.bias.data *= mask2.weight.data.reshape(-1)

        # conv3: 1x1 that sums the identity channels and the residual branch
        idconv3 = nn.Conv2d(self.out_planes2, self.out_planes, kernel_size=1, stride=1, padding=0, bias=False).eval()
        idbn3 = nn.BatchNorm2d(self.out_planes).eval()
        nn.init.dirac_(idconv3.weight.data[:, :self.in_planes1])
        nn.init.dirac_(idconv3.weight.data[:, self.in_planes1:])
        bn_var_sqrt = torch.sqrt(self.running2.running_var + self.running2.eps)
        idbn3.weight.data = bn_var_sqrt
        idbn3.bias.data = self.running2.running_mean
        idbn3.running_mean.data = self.running2.running_mean
        idbn3.running_var.data = self.running2.running_var
        # fold mask_res into idbn3
        mask3 = nn.Conv2d(self.out_planes, self.out_planes, 1, groups=self.out_planes, bias=False)
        mask3.weight.data = self.mask_res[0].weight.data
        idbn3.weight.data *= mask3.weight.data.reshape(-1)
        idbn3.bias.data *= mask3.weight.data.reshape(-1)

        return [idconv1, idbn1, nn.ReLU(True), idconv2, idbn2, nn.ReLU(True), idconv3, idbn3]
```

Just check whether the outputs before and after the conversion are equal; note that the comparison needs to be done in eval() mode.
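
A minimal sketch of such an equivalence check (assuming the RM-form `Bottleneck` above, with c1 == c2 as its residual add requires; the sizes are arbitrary):

```python
import torch
import torch.nn as nn

block = Bottleneck(c1=64, c2=64)
block.eval()  # important: BN must use running stats, not batch stats

x = torch.randn(1, 64, 32, 32)
y_before = block(x)

deployed = nn.Sequential(*block.deploy()).eval()
y_after = deployed(x)

# Should print True (up to numerical tolerance) if the conversion is correct.
print(torch.allclose(y_before, y_after, atol=1e-5))
```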

@Serissa

Serissa commented Jan 20, 2022

Hi! I modified the code, but when comparing in eval() mode, the outputs before and after the conversion are not equal. I have been debugging for a day and cannot find the problem. Could you take a look and tell me where the code is wrong?
