💡💡💡本文独家原创改进:轻量级 ViT 的高效架构选择,逐步增强标准轻量级 CNN(特别是 MobileNetV3)的移动友好性。 最终产生了一个新的纯轻量级 CNN 系列,即 RepViT
RepViTBlock即插即用,助力检测 | 亲测在多个数据集能够实现涨点,并实现轻量化
RepViT介绍
论文:https://arxiv.org/pdf/2307.09283.pdf
重点探讨了在资源有限的移动设备上,通过重新审视轻量级卷积神经网络的设计,并整合轻量级 ViTs
的有效架构选择,来提升轻量级 CNNs
的性能。
本文贡献:
MSHA
)可以让模型学习全局表示。然而,轻量级 ViTs 和轻量级 CNNs 之间的架构差异尚未得到充分研究。MobileNetV3
的移动友好性。这便衍生出一个新的纯轻量级 CNN 家族的诞生,即RepViT
。值得注意的是,尽管 RepViT 具有 MetaFormer 结构,但它完全由卷积组成。RepViT
超越了现有的最先进的轻量级 ViTs,并在各种视觉任务上显示出优于现有最先进轻量级ViTs的性能和效率,包括 ImageNet 分类、COCO-2017 上的目标检测和实例分割,以及 ADE20k 上的语义分割。特别地,在ImageNet
上,RepViT
在 iPhone 12
上达到了近乎 1ms 的延迟和超过 80% 的Top-1 准确率,这是轻量级模型的首次突破。
通过集成轻量级 ViT 的高效架构选择,逐步增强标准轻量级 CNN(特别是 MobileNetV3)的移动友好性。 最终产生了一个新的纯轻量级 CNN 系列,即 RepViT。
RepViT 通过逐层微观设计来调整轻量级 CNN,这包括选择合适的卷积核大小和优化挤压-激励(Squeeze-and-excitation,简称SE)层的位置。这两种方法都能显著改善模型性能。
浅层网络使用卷积提取器
更深的下采样层,作者们首先使用一个 1x1 卷积来调整通道维度,然后将两个 1x1 卷积的输入和输出通过残差连接,形成一个前馈网络。此外,他们还在前面增加了一个 RepViT 块以进一步加深下采样层,这一步提高了 top-1 准确率到 75.4%,同时延迟为 0.96ms。
最终,通过整合上述改进策略,我们便得到了模型RepViT
的整体架构,该模型有多个变种,例如RepViT-M1/M2/M3
。同样地,不同的变种主要通过每个阶段的通道数和块数来区分。
核心代码:
class RepViTBlock(nn.Module):
def __init__(self,in1, inp, hidden_dim, oup, kernel_size=3, stride=2, use_se=0, use_hs=0):
super(RepViTBlock, self).__init__()
assert stride in [1, 2]
self.identity = stride == 1 and inp == oup
print(inp)
print(hidden_dim)
print(oup)
assert(hidden_dim == 2 * inp)
if stride == 2:
self.token_mixer = nn.Sequential(
Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp),
SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
)
self.channel_mixer = Residual(nn.Sequential(
# pw
Conv2d_BN(oup, 2 * oup, 1, 1, 0),
nn.GELU() if use_hs else nn.GELU(),
# pw-linear
Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
))
else:
assert(self.identity)
self.token_mixer = nn.Sequential(
RepVGGDW(inp),
SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
)
self.channel_mixer = Residual(nn.Sequential(
# pw
Conv2d_BN(inp, hidden_dim, 1, 1, 0),
nn.GELU() if use_hs else nn.GELU(),
# pw-linear
Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
))
def forward(self, x):
return self.channel_mixer(self.token_mixer(x))
详见:
https://blog.csdn.net/m0_63774211/article/details/131939062
2023腾讯技术创作特训营第二期有奖征文,瓜分万元奖池和键盘手表
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。