Paper: https://arxiv.org/pdf/2305.12972.pdf
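Researchers from Huawei Noah's Ark Lab and the University of Sydney have proposed VanillaNet, a deliberately minimal neural network built around a minimalist design philosophy. The network contains only the simplest convolution operations, with residual connections and attention modules removed, yet it achieves solid results across a range of computer vision tasks.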
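VanillaNet is an elegantly designed neural network architecture. By avoiding complex elements such as great depth, shortcut connections, and self-attention, it stays simple while remaining powerful.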
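To make "no shortcuts, no self-attention" concrete, here is a toy sketch of a plain network in PyTorch in which every layer feeds only the next one. The class names `PlainBlock` and `PlainNet`, the channel widths, the 4×4 stride-4 stem and the use of plain ReLU are illustrative assumptions, not the paper's configuration.

```python
import torch
import torch.nn as nn


class PlainBlock(nn.Module):
    """1x1 conv -> BN -> max-pool -> ReLU; no residual branch, no attention."""

    def __init__(self, dim_in: int, dim_out: int, stride: int = 2):
        super().__init__()
        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size=1)
        self.bn = nn.BatchNorm2d(dim_out)
        self.pool = nn.MaxPool2d(stride) if stride > 1 else nn.Identity()
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The output depends only on this layer's input: there is no "x + f(x)" shortcut.
        return self.act(self.pool(self.bn(self.conv(x))))


class PlainNet(nn.Module):
    """Toy plain network: stem, a straight chain of PlainBlocks, and a linear head."""

    def __init__(self, dims=(32, 64, 128, 256), num_classes: int = 10):
        super().__init__()
        # Hypothetical stem: a 4x4 conv with stride 4 to downsample the image early.
        self.stem = nn.Sequential(
            nn.Conv2d(3, dims[0], kernel_size=4, stride=4),
            nn.BatchNorm2d(dims[0]),
            nn.ReLU(),
        )
        self.stages = nn.Sequential(
            *[PlainBlock(dims[i], dims[i + 1]) for i in range(len(dims) - 1)]
        )
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(dims[-1], num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.stages(self.stem(x)))


model = PlainNet().eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 64, 64))
print(logits.shape)  # torch.Size([2, 10])
```

Because each block is a single path, the forward pass is one sequential chain of ops, with no branches to join and no earlier tensors to keep alive for a later addition. VanillaNet itself adds its own activation function and training strategy, which this toy omits.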
A network only 6 layers deep already reaches 76.36% top-1 accuracy on ImageNet, and the 13-layer VanillaNet reaches an impressive 83.1%.
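VanillaNet offers a striking speed/accuracy trade-off. With only 9 layers, VanillaNet-9 reaches close to 80% accuracy on ImageNet and runs more than twice as fast as ResNet-50 at similar accuracy (2.91 ms vs. 7.64 ms latency). The 13-layer VanillaNet reaches 83% top-1 accuracy and is more than twice as fast as Swin-S at the same accuracy. Although VanillaNet has considerably more parameters and FLOPs than these more complex networks, its shallow, shortcut-free design makes it faster in practice.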
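As a quick check of "more than twice as fast", the two latencies quoted above give the ratio below. This is plain arithmetic on the quoted numbers only.

```python
# Latencies quoted in the text, in milliseconds.
latency_ms = {"VanillaNet-9": 2.91, "ResNet-50": 7.64}

speedup = latency_ms["ResNet-50"] / latency_ms["VanillaNet-9"]
reduction = 1 - latency_ms["VanillaNet-9"] / latency_ms["ResNet-50"]
print(f"speedup: {speedup:.2f}x, latency reduction: {reduction:.1%}")
# speedup: 2.63x, latency reduction: 61.9%
```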
Core code:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# NOTE: `activation` (VanillaNet's series-informed activation) is defined elsewhere
# in the VanillaNet source and is not reproduced here. It must accept
# (dim, act_num) and provide a switch_to_deploy() method.


class VanillaBlock(nn.Module):
    def __init__(self, dim, dim_out, act_num=3, stride=2, deploy=False, ada_pool=None):
        super().__init__()
        # Negative slope of the leaky ReLU between the two 1x1 convs. The training
        # schedule raises it towards 1; at 1 the leaky ReLU is the identity, so the
        # two convs become a single linear map and can be merged.
        self.act_learn = 1
        self.deploy = deploy
        if self.deploy:
            # Inference form: one 1x1 conv with both BNs already folded in.
            self.conv = nn.Conv2d(dim, dim_out, kernel_size=1)
        else:
            # Training form: two 1x1 conv + BN layers.
            self.conv1 = nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size=1),
                nn.BatchNorm2d(dim, eps=1e-6),
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(dim, dim_out, kernel_size=1),
                nn.BatchNorm2d(dim_out, eps=1e-6),
            )
        # Downsampling: fixed max-pool, or adaptive max-pool to a fixed output size.
        if not ada_pool:
            self.pool = nn.Identity() if stride == 1 else nn.MaxPool2d(stride)
        else:
            self.pool = nn.Identity() if stride == 1 else nn.AdaptiveMaxPool2d((ada_pool, ada_pool))
        self.act = activation(dim_out, act_num)

    def forward(self, x):
        if self.deploy:
            x = self.conv(x)
        else:
            x = self.conv1(x)
            x = F.leaky_relu(x, self.act_learn)
            x = self.conv2(x)
        x = self.pool(x)
        x = self.act(x)
        return x

    def _fuse_bn_tensor(self, conv, bn):
        # Fold eval-mode BN into the preceding conv:
        # W' = W * gamma / std,  b' = beta + (b - mean) * gamma / std
        kernel = conv.weight
        bias = conv.bias
        running_mean = bn.running_mean
        running_var = bn.running_var
        gamma = bn.weight
        beta = bn.bias
        eps = bn.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta + (bias - running_mean) * gamma / std

    @torch.no_grad()
    def switch_to_deploy(self):
        # Only exact once act_learn == 1, i.e. the leaky ReLU is the identity.
        kernel, bias = self._fuse_bn_tensor(self.conv1[0], self.conv1[1])
        self.conv1[0].weight.data = kernel
        self.conv1[0].bias.data = bias
        kernel, bias = self._fuse_bn_tensor(self.conv2[0], self.conv2[1])
        self.conv = self.conv2[0]
        # Merge the two 1x1 convs: W = W2 @ W1, b = W2 @ b1 + b2.
        self.conv.weight.data = torch.matmul(
            kernel.transpose(1, 3),
            self.conv1[0].weight.data.squeeze(3).squeeze(2),
        ).transpose(1, 3)
        self.conv.bias.data = bias + (self.conv1[0].bias.data.view(1, -1, 1, 1) * kernel).sum(3).sum(2).sum(1)
        del self.conv1
        del self.conv2
        self.act.switch_to_deploy()
        self.deploy = True
```
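The folding in `switch_to_deploy` can be checked numerically. The standalone sketch below mirrors `_fuse_bn_tensor` (W' = W·γ/σ, b' = β + (b − μ)·γ/σ, with σ = sqrt(var + eps)) and the two-conv merge (W = W2·W1, b = W2·b1 + b2), but uses hypothetical helper names `fuse_conv_bn` and `merge_two_1x1` and a `flatten` + `@` formulation instead of the transpose/`matmul` trick. It assumes eval-mode BN and `act_learn == 1`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d):
    """Fold an eval-mode BatchNorm into the preceding conv; returns (weight, bias)."""
    std = (bn.running_var + bn.eps).sqrt()
    scale = bn.weight / std
    return conv.weight * scale.reshape(-1, 1, 1, 1), bn.bias + (conv.bias - bn.running_mean) * scale


@torch.no_grad()
def merge_two_1x1(conv1, bn1, conv2, bn2) -> nn.Conv2d:
    """Collapse conv1 -> bn1 -> (identity) -> conv2 -> bn2 into one 1x1 conv."""
    w1, b1 = fuse_conv_bn(conv1, bn1)  # w1: (C, C_in, 1, 1)
    w2, b2 = fuse_conv_bn(conv2, bn2)  # w2: (C_out, C, 1, 1)
    merged = nn.Conv2d(conv1.in_channels, conv2.out_channels, kernel_size=1)
    # Treat the 1x1 kernels as matrices: W = W2 @ W1, b = W2 @ b1 + b2.
    merged.weight.copy_((w2.flatten(1) @ w1.flatten(1)).reshape_as(merged.weight))
    merged.bias.copy_(w2.flatten(1) @ b1 + b2)
    return merged


torch.manual_seed(0)
dim, dim_out = 8, 16
conv1, bn1 = nn.Conv2d(dim, dim, 1), nn.BatchNorm2d(dim, eps=1e-6)
conv2, bn2 = nn.Conv2d(dim, dim_out, 1), nn.BatchNorm2d(dim_out, eps=1e-6)
with torch.no_grad():
    # Give both BNs non-trivial statistics so the folding is actually exercised.
    for bn in (bn1, bn2):
        bn.running_mean.uniform_(-1.0, 1.0)
        bn.running_var.uniform_(0.5, 2.0)
        bn.weight.uniform_(0.5, 1.5)
        bn.bias.uniform_(-0.5, 0.5)
for m in (conv1, bn1, conv2, bn2):
    m.eval()  # BN must use running statistics for the fusion to be exact

x = torch.randn(2, dim, 5, 5)
with torch.no_grad():
    # negative_slope=1.0 turns leaky_relu into the identity (act_learn == 1).
    reference = bn2(conv2(F.leaky_relu(bn1(conv1(x)), 1.0)))
    fused = merge_two_1x1(conv1, bn1, conv2, bn2)(x)
print(f"max abs difference: {(reference - fused).abs().max().item():.2e}")
```

With any slope other than 1 the intermediate leaky ReLU is nonlinear and the two convolutions can no longer be collapsed into one, which is why the block should only be switched to deploy mode once `act_learn` has reached 1.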
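Original content statement: this article is published on the Tencent Cloud Developer Community with the author's authorization and may not be reproduced without permission.
In case of infringement, please contact cloudcommunity@tencent.com for removal.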