`main.py` 实现了稀疏训练。下面这行代码添加了稀疏训练的惩罚系数 `--s`，注意它作用在 BN 层的缩放系数（γ）上；稀疏训练本身需要通过 `--sr` 参数开启（见下方 `train()` 中的 `if args.sr:`）：

```parser.add_argument('--s', type=float, default=0.0001,
help='scale sparse rate (default: 0.0001)')
```

```def updateBN():
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
```

```def train(epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if args.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data), Variable(target)
output = model(data)
loss = F.cross_entropy(output, target)
pred = output.data.max(1, keepdim=True)[1]
loss.backward()
if args.sr:
updateBN()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
100. * batch_idx / len(train_loader), loss.data[0]))

def test():
    """Evaluate ``model`` on ``test_loader`` and return the top-1 accuracy.

    Prints the average cross-entropy loss and the accuracy over the test set.

    Returns:
        Fraction of correctly classified test samples, in [0, 1]. The epoch
        loop compares this value against ``best_prec1``.
    """
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        test_loss += F.cross_entropy(output, target, size_average=False).data[0]  # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1]  # index of the highest logit = predicted class
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
        test_loss, correct, num_samples, 100. * correct / num_samples))
    return float(correct) / num_samples

def save_checkpoint(state, is_best, filepath):
    """Serialize ``state`` to ``<filepath>/checkpoint.pth.tar``.

    When ``is_best`` is true, the checkpoint just written is also copied to
    ``<filepath>/model_best.pth.tar``.
    """
    latest = os.path.join(filepath, 'checkpoint.pth.tar')
    torch.save(state, latest)
    if not is_best:
        return
    shutil.copyfile(latest, os.path.join(filepath, 'model_best.pth.tar'))

best_prec1 = 0.
# Decay the learning rate by 10x at 50% and 75% of training. The milestones
# are computed with integer division because `epoch` is an int. A float
# milestone such as 0.75 * 10 = 7.5 would never match, so that decay would
# be skipped.
lr_milestones = (args.epochs // 2, args.epochs * 3 // 4)
for epoch in range(args.start_epoch, args.epochs):
    if epoch in lr_milestones:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.1
    train(epoch)
    prec1 = test()
    is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
        'optimizer': optimizer.state_dict(),
    }, is_best, filepath=args.save)

print("Best accuracy: "+str(best_prec1))
```

# VGG16的剪枝

## 模型加载

```model = vgg(dataset=args.dataset, depth=args.depth)
if args.cuda:
model.cuda()

if args.model:
if os.path.isfile(args.model):
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
.format(args.model, checkpoint['epoch'], best_prec1))
else:
print("=> no checkpoint found at '{}'".format(args.resume))

print(model)
```

### 预剪枝

```# 计算需要剪枝的变量个数total
total = 0
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
total += m.weight.data.shape[0]

# 确定剪枝的全局阈值
bn = torch.zeros(total)
index = 0
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
size = m.weight.data.shape[0]
bn[index:(index+size)] = m.weight.data.abs().clone()
index += size
# 按照权值大小排序
y, i = torch.sort(bn)
thre_index = int(total * args.percent)
# 确定要剪枝的阈值
thre = y[thre_index]
#********************************预剪枝*********************************#
pruned = 0
cfg = []
for k, m in enumerate(model.modules()):
if isinstance(m, nn.BatchNorm2d):
weight_copy = m.weight.data.abs().clone()
# 剪枝掉的通道数个数
print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
elif isinstance(m, nn.MaxPool2d):
cfg.append('M')

pruned_ratio = pruned/total

print('Pre-processing Successful!')
```

## 对预剪枝后的模型进行测试

```# simple test model after Pre-processing prune (simple set BN scales to zeros)
#********************************预剪枝后model测试*********************************#
def test(model):
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
# 加载测试数据
if args.dataset == 'cifar10':
datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
transforms.ToTensor(),
# 对R, G，B通道应该减的均值
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
batch_size=args.test_batch_size, shuffle=True, **kwargs)
elif args.dataset == 'cifar100':
datasets.CIFAR100('./data.cifar100', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
batch_size=args.test_batch_size, shuffle=True, **kwargs)
else:
raise ValueError("No valid dataset is given.")
model.eval()
correct = 0
if args.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data, volatile=True), Variable(target)
output = model(data)
pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
# 记录类别预测正确的个数
correct += pred.eq(target.data.view_as(pred)).cpu().sum()

print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(

acc = test(model)
```

### 正式剪枝

```# 定义原始模型和新模型的每一层保留通道索引的mask
for [m0, m1] in zip(model.modules(), newmodel.modules()):
# 对BN层和ConV层都要剪枝
if isinstance(m0, nn.BatchNorm2d):
# np.squeeze 从数组的形状中删除单维度条目，即把shape中为1的维度去掉
# np.argwhere(a) 返回非0的数组元组的索引，其中a是要索引数组的条件。
# 如果维度是1，那么就新增一维，这是为了和BN层的weight的维度匹配
if idx1.size == 1:
idx1 = np.resize(idx1,(1,))
m1.weight.data = m0.weight.data[idx1.tolist()].clone()
m1.bias.data = m0.bias.data[idx1.tolist()].clone()
m1.running_mean = m0.running_mean[idx1.tolist()].clone()
m1.running_var = m0.running_var[idx1.tolist()].clone()
layer_id_in_cfg += 1
if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
elif isinstance(m0, nn.Conv2d):
print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
if idx0.size == 1:
idx0 = np.resize(idx0, (1,))
if idx1.size == 1:
idx1 = np.resize(idx1, (1,))
# 注意卷积核Tensor维度为[n, c, w, h]，两个卷积层连接，下一层的输入维度n就等于当前层的c
w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
w1 = w1[idx1.tolist(), :, :, :].clone()
m1.weight.data = w1.clone()
elif isinstance(m0, nn.Linear):
# 注意卷积核Tensor维度为[n, c, w, h]，两个卷积层连接，下一层的输入维度n'就等于当前层的c
if idx0.size == 1:
idx0 = np.resize(idx0, (1,))
m1.weight.data = m0.weight.data[:, idx0].clone()
m1.bias.data = m0.bias.data.clone()

torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save, 'pruned.pth.tar'))

print(newmodel)
model = newmodel
test(model)
```

剪枝得到的模型需要重新训练（fine-tune）以恢复精度，`--refine` 参数用于指定剪枝后的模型文件：

```bash
python main.py --refine [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch vgg --depth 16 --epochs 160
```

# ResNet的剪枝

## 设置通道选择层（channel_selection）

```class channel_selection(nn.Module):
"""
从BN层的输出中选择通道。它应该直接放在BN层之后，此层的输出形状由self.indexes中的1的个数决定
"""
def __init__(self, num_channels):
"""
使用长度和通道数相同的全1向量初始化"indexes", 剪枝过程中，将要剪枝的通道对应的indexes位置设为0
"""
super(channel_selection, self).__init__()
self.indexes = nn.Parameter(torch.ones(num_channels))

def forward(self, input_tensor):
"""
参数：
输入Tensor维度: (N,C,H,W)，这也是BN层的输出Tensor
"""
selected_index = np.squeeze(np.argwhere(self.indexes.data.cpu().numpy()))
if selected_index.size == 1:
selected_index = np.resize(selected_index, (1,))
output = input_tensor[:, selected_index, :, :]
return output
```

## 将通道选择层放入ResNet

```class Bottleneck(nn.Module):
expansion = 4

def __init__(self, inplanes, planes, cfg, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.bn1 = nn.BatchNorm2d(inplanes)
# 新增的通道鉴别层，放在BN之后
self.select = channel_selection(inplanes)
self.conv1 = nn.Conv2d(cfg[0], cfg[1], kernel_size=1, bias=False)
self.bn2 = nn.BatchNorm2d(cfg[1])
self.conv2 = nn.Conv2d(cfg[1], cfg[2], kernel_size=3, stride=stride,
self.bn3 = nn.BatchNorm2d(cfg[2])
self.conv3 = nn.Conv2d(cfg[2], planes * 4, kernel_size=1, bias=False)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride

def forward(self, x):
residual = x

out = self.bn1(x)
out = self.select(out)
out = self.relu(out)
out = self.conv1(out)

out = self.bn2(out)
out = self.relu(out)
out = self.conv2(out)

out = self.bn3(out)
out = self.relu(out)
out = self.conv3(out)

if self.downsample is not None:
residual = self.downsample(x)

out += residual

return out

class resnet(nn.Module):
    """Pre-activation ResNet for 32x32 CIFAR inputs, built from Bottleneck blocks.

    Args:
        depth: network depth; must be of the form 9n+2 (n blocks per stage).
        dataset: 'cifar10' (10 classes) or 'cifar100' (100 classes).
        cfg: flat list of channel counts, 3 per block followed by the number
            of channels fed to the classifier. ``None`` builds the unpruned
            configuration; a pruned configuration can be passed instead.

    Raises:
        ValueError: if ``dataset`` is neither 'cifar10' nor 'cifar100'.
    """
    def __init__(self, depth=164, dataset='cifar10', cfg=None):
        super(resnet, self).__init__()
        assert (depth - 2) % 9 == 0, 'depth should be 9n+2'

        n = (depth - 2) // 9
        block = Bottleneck

        # Validate up front; previously an unknown dataset left `self.fc`
        # undefined and only failed later in forward().
        if dataset == 'cifar10':
            num_classes = 10
        elif dataset == 'cifar100':
            num_classes = 100
        else:
            raise ValueError("No valid dataset is given.")

        if cfg is None:
            # Construct config variable.
            cfg = [[16, 16, 16], [64, 16, 16]*(n-1), [64, 32, 32], [128, 32, 32]*(n-1), [128, 64, 64], [256, 64, 64]*(n-1), [256]]
            cfg = [item for sub_list in cfg for item in sub_list]

        self.inplanes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
                               bias=False)
        self.layer1 = self._make_layer(block, 16, n, cfg = cfg[0:3*n])
        self.layer2 = self._make_layer(block, 32, n, cfg = cfg[3*n:6*n], stride=2)
        self.layer3 = self._make_layer(block, 64, n, cfg = cfg[6*n:9*n], stride=2)
        self.bn = nn.BatchNorm2d(64 * block.expansion)
        # Channel-selection layer placed after the final BN.
        self.select = channel_selection(64 * block.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)
        # Registered last, as before: the pruning script walks modules() in
        # registration order.
        self.fc = nn.Linear(cfg[-1], num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He initialization based on fan-out (k * k * out_channels).
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                # Every BN scale factor starts at 0.5, every shift at 0.
                m.weight.data.fill_(0.5)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, cfg, stride=1):
        """Stack `blocks` blocks; `cfg` holds 3 channel counts per block."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 projection so the shortcut matches the block output shape.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
            )

        layers = []
        layers.append(block(self.inplanes, planes, cfg[0:3], stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, cfg[3*i: 3*(i+1)]))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)

        x = self.layer1(x)  # 32x32
        x = self.layer2(x)  # 16x16
        x = self.layer3(x)  # 8x8
        x = self.bn(x)
        x = self.select(x)
        x = self.relu(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
```

## 对Resnet进行剪枝

```for layer_id in range(len(old_modules)):
m0 = old_modules[layer_id]
m1 = new_modules[layer_id]
# 对BN层和ConV层都要剪枝
if isinstance(m0, nn.BatchNorm2d):
# np.squeeze 从数组的形状中删除单维度条目，即把shape中为1的维度去掉
# np.argwhere(a) 返回非0的数组元组的索引，其中a是要索引数组的条件。
# 如果维度是1，那么就新增一维，这是为了和BN层的weight的维度匹配
if idx1.size == 1:
idx1 = np.resize(idx1,(1,))
# 如果下一层是通道选择层，这个是ResNet和VGG剪枝的唯一不同之处
if isinstance(old_modules[layer_id + 1], channel_selection):
# 如果下一层是通道选择层，这一层就不剪枝
m1.weight.data = m0.weight.data.clone()
m1.bias.data = m0.bias.data.clone()
m1.running_mean = m0.running_mean.clone()
m1.running_var = m0.running_var.clone()

# We need to set the channel selection layer.
m2 = new_modules[layer_id + 1]
m2.indexes.data.zero_()
m2.indexes.data[idx1.tolist()] = 1.0

layer_id_in_cfg += 1
else:
# 否则正常剪枝
m1.weight.data = m0.weight.data[idx1.tolist()].clone()
m1.bias.data = m0.bias.data[idx1.tolist()].clone()
m1.running_mean = m0.running_mean[idx1.tolist()].clone()
m1.running_var = m0.running_var[idx1.tolist()].clone()
layer_id_in_cfg += 1
if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
elif isinstance(m0, nn.Conv2d):
if conv_count == 0:
m1.weight.data = m0.weight.data.clone()
conv_count += 1
continue
# 正常剪枝就好
if isinstance(old_modules[layer_id-1], channel_selection) or isinstance(old_modules[layer_id-1], nn.BatchNorm2d):
# This convers the convolutions in the residual block.
# The convolutions are either after the channel selection layer or after the batch normalization layer.
conv_count += 1
print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
if idx0.size == 1:
idx0 = np.resize(idx0, (1,))
if idx1.size == 1:
idx1 = np.resize(idx1, (1,))
w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()

# If the current convolution is not the last convolution in the residual block, then we can change the
# number of output channels. Currently we use `conv_count` to detect whether it is such convolution.
if conv_count % 3 != 1:
w1 = w1[idx1.tolist(), :, :, :].clone()
m1.weight.data = w1.clone()
continue

# We need to consider the case where there are downsampling convolutions.
# For these convolutions, we just copy the weights.
m1.weight.data = m0.weight.data.clone()
elif isinstance(m0, nn.Linear):
if idx0.size == 1:
idx0 = np.resize(idx0, (1,))

m1.weight.data = m0.weight.data[:, idx0].clone()
m1.bias.data = m0.bias.data.clone()
```

# 备注

