前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >NAFNet :无需非线性激活,真“反直觉”!但复原性能也是真强!

NAFNet :无需非线性激活,真“反直觉”!但复原性能也是真强!

作者头像
AIWalker
发布2022-04-27 14:37:34
2.7K0
发布2022-04-27 14:37:34
举报
文章被收录于专栏:AIWalker

本文提出一种超简基线方案Baseline,它不仅计算高效同时性能优于之前SOTA方案;在所得Baseline基础上进一步简化得到了NAFNet:移除了非线性激活单元且性能进一步提升。所提方案在SIDD降噪与GoPro去模糊任务上均达到了新的SOTA性能,同时计算量大幅降低(可参考下图)。

1Building A Simple Baseline

在该部分内容中,我们将从头开始构建一个用于图像复原的简单基线(Simple Baseline)。为保证结构的简洁性,基本原则是:如无必要,勿增实体(奥卡姆剃刀)。参考HINet-Simple,我们主要在16GMACs计算量(输入为

256\times 256

)范围进行试验分析,其他计算量模型的结果见试验部分;在任务方面,我们主要在SIDD降噪、GoPro去模糊上进行验证。

Architecture

上图给出了图像复原领域常用架构示意图,包含多阶段架构、多尺度融合架构以及UNet架构。为减少块间(inter-block)复杂度,我们采用了带跳过连接的UNet架构

2A Plain Block

神经网络一般采用模块堆叠方式构建,所选用的UNet架构决定了模块堆叠方式,但模块的设计仍然是个问题。

上图a给出了Restormer一文所构建的模块,我们以其作为参考并进行简化:采用卷积替代Transformer(见上图b)。这里的替换主要是基于以下三个考量:

  • 尽管Transformer在CV领域表现出了惊人的优势,但一些研究表明:Transformer并非达成SOTA结果的必要条件;
  • depthwise卷积比自注意力更简单;
  • 本文并非旨在讨论Transformer与卷积的优劣,而仅在于提供了一个简单基线。

Normalization

归一化技术在high-level任务中已被广泛应用,但在low-level任务中应用极少。但是,依托于Transformer,LN得到了越来越多的应用。基于该事实,我们猜想:LN可能是达成SOTA复原器的关键,故在上述Plain模块中添加了LN(见上面图示c)。LN的引入使得训练更平滑,甚至可以将学习率放大10倍更大的学习率可以带来显著性能提升:0.44dB@SIDD(39.29dB→39.73dB),3.39dB@GoPro(28.51dB → 31.90dB)。

Activation

尽管ReLU是最常用的激活函数,现有SOTA方案中采用GELU进行代替。激活函数的替代在性能方面导致:-0.02dB@SIDD(39.73dB → 39.71dB),0.21dB@GoPro(31.90dB → 21.11dB)。由于GELU可以保持降噪性能相当且大幅提升去模糊性能,故我们采用GELU替代ReLU(见上面图示c)。

Attention

受启发于Restormer中的注意力机制,我们意识到:普通通道注意力可以满足计算效率需求并引入全局信息;此外通道注意力的有效性已在多个图像复原任务中得到验证。因此,我们进一步添加通道注意力,见上面图示c。通道注意力可以带来额外的性能提升:0.14dB@SIDD(37.71dB → 39.85dB)、0.24dB@GoPro(32.11dB → 32.35dB)。

Summary

到此,我们从头开始构建了本文的Baseline(结果见上表)。尽管所设计模块中的每个成分都非常简单,但组合后可以得到一个强基线方案:在SIDD与GoPro数据集上超越了其他SOTA方案,同时计算量大幅降低。

3Nonlinear Activation Free Network

尽管上述所提Baseline足够简单且竞争力,那么是否可能在确保简洁性的同时进一步提升性能呢是否可以更简介且无性能损失呢?我们尝试从SOTA方案(VRT, MAXIM, Restormer)中寻找共性点以回答上述问题,我们发现:这些方案均采用了Gated Linear Units(GLU,定义如下)。

Gate(X, f, g, \sigma) = f(X) \odot \sigma(g(X))

将GLU引入到Baseline中可能会改善性能,但同时会导致块内(intra-block)计算复杂度提升,而这并非我们所期望的。

为此,我们对Baseline中的激活函数进行了回顾,其定义与近似实现如下:

GELU(x) = x \Phi(x) \approx 0.5x(1+tanh[\sqrt{2/\pi}(x+0.044715 x^3)])

GELU与GLU的实现可以发现:GELU是GLU的一种特例。我们从另一个角度猜想:GLU可视作一种广义激活函数,它是可以用于替代非线性激活函数。此外,我们注意到:GLU自身已包含非线性且该非线性并不依赖

\sigma

基于上述,我们提出了一种简化版GLU变种(见上图c):直接将特征沿通道维度分成两部分并相乘。采用所提SimpleGate对GELU进行替换导致的性能提升为:0.08dB@SIDD(39.85dB → 39.93dB)、0.41dB@GoPro(32.35dB → 32.76dB)。相比GELU的复杂实现,SimpleGate的实现非常简单:

SimpleGate(X, Y) = X \odot Y

Simplified Channel Attention Baseline方案中采用了通道注意力(见上图a),它定义如下:

CA(X) = X * \Psi(X) = X \sigma(W_2 max(0, W_1 pool(X)))

可以看到:CA的定义与GLU非常像。这就是促使我们将CA视作GLU的一种特例并可进一步简化。通过保留通道注意力的两个重要作用(全局信息聚合、通道信息交互),我们提出了如下简化版通道注意力(见上面图示b):

SCA(X) = X * W pool(X)

为公平对比,我们调整了CA的特征维度以保持与SCA计算复杂度相当。尽管SCA足够简单,但它并未造成性能损失:0.03dB@SIDD(39.93dB → 39.96dB),0.09dB@GoPro(32.76dB → 32.85dB)。

以Baseline为基础,我们采用SimpleGate替换GELU、采用SCA替换CA达成了进一步的简化,且未噪声性能损失。值得一提的是:简化后的网络中不包含非线性激活函数。因此,我们将所得方案称之为NAFNet(Nonlinear Activation Free Network)。

4Experiments

上图与表为SIDD数据集上不同方案的性能对比,可以看到:

  • 所提Baseline与NAFNet以0.28dB指标优于此前最佳方案Restormer,同时计算量更低
  • 在重建效果方面,相比其他方案,所提方案可以重建更细粒度细节。

上图与表为GoPro数据集上不同方案的性能对比,可以看到:

  • 所提Baseline与NAFNet分别比此前最佳方案MPRNet-local高0.09dB与0.38dB,同时仅需8.4%NG)计算量
  • 在重建效果方面,相比其他方案,所提方案的重建结果更锐利。

上图与表为Raw降噪与JPEG伪影+去模糊组合任务(NTIRE2021图像去模糊Track2)上的性能对比,可以看到:

  • 相比PMRID,所提方案NAFNet(通道数与模块数进行了减少以确保计算量相当)具有更高的PSNR指标,同时具有更优的重建效果。该实验同时说明了NAFNet的灵活缩放性。
  • 从Table8可以看到:相比NTIRE2021竞赛冠军方案HINet与MAXIM,所提NAFNet取得了更优的PSNr与SSIM指标,同时具有更低的计算量(约三分之一)。

5Code Implement

代码语言:javascript
复制
class SimpleGate(nn.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2
        
class NAFBlock(nn.Module):
    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_ratio=0):
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(c, dw_channel, 1)
        self.conv2 = nn.Conv2d(dw_channel, dw_channel, 3, 1, 1 group=dw_channel)
        self.conv3 = nn.Conv2d(dw_channel//2, dw_channel, 1)
        
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(dw_channel//2, dw_channel//2, 1)
        )
        
        self.sg = SimpleGate()
        
        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(c, ffn_channel, 1)
        self.conv5 = nn.Conv2d(ffn_channel, c, 1)
        
        self.norm1 = LayerNorm2d()
        self.norm2 = LayerNorm2d()
        
        # skip-init trick to stabilize training.
        self.beta = nn.Parameter(torch.zeros((1,c,1,1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1,c,1,1)), requires_grad=True)
        
    def forward(self, inp):
        x = inp 
        
        x = self.norm(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)
        
        y = inp + x * self.beta
        
        x = self.norm2(y)
        x = self.conv4(x)
        x = self.sg(x)
        x = self.conv5(x)
        
        y = y + x * self.gamma
        return y

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-04-13,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AIWalker 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1Building A Simple Baseline
    • Architecture
    • 2A Plain Block
      • Normalization
        • Activation
          • Attention
            • Summary
            • 3Nonlinear Activation Free Network
            • 4Experiments
            • 5Code Implement
            领券
            问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档