This post covers some super-resolution machine-learning work I reproduced recently. The papers are selected from the Zhihu article [从SRCNN到EDSR,总结深度学习端到端超分辨率方法发展历程](https://zhuanlan.zhihu.com/p/31664818), continuing in order from the previous post 【AI】超分辨率经典论文复现(1)——2016年. The post is about 4.7k characters, so not a long read. Having reproduced this many networks, I have become fairly familiar with the field. I will keep reproducing other networks, but no longer in the order of that survey; instead I will reproduce whatever interests me and publish a post once enough pieces have accumulated.
My knowledge is limited and mistakes are hard to avoid; if any of my reproductions misread a paper, please point it out in the comments. This post is also kept in my GitHub repository, where errors will be fixed as they are found: https://github.com/ZFhuang/Study-Notes/tree/main/Content/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/%E8%B6%85%E5%88%86%E8%BE%A8%E7%8E%87%E5%AE%9E%E8%B7%B5
Image Super-Resolution via Deep Recursive Residual Network
DRRN
DRRN can be understood as a combination of VDSR and DRCN. The overall skeleton is VDSR-style, with one global residual connection, while each RB block in the figure is made up of several recursive small residual units: the conv-conv pairs in the right figure share parameters in a recursive fashion, which lets the network keep representing deep features well as training proceeds. Like other very deep networks, however, the deeply unrolled recursion makes training quite time-consuming, and gradient clipping should be applied in the training loop to curb exploding or vanishing gradients.
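The gradient clipping mentioned above can be added to a training step with `torch.nn.utils.clip_grad_norm_`. A minimal sketch follows; the single-conv stand-in model, the random tensors, and the 0.4 threshold are placeholders of mine, not the paper's settings:

```python
import torch
from torch import nn

# Placeholder model standing in for DRRN; the clipping call is the point here
model = nn.Conv2d(1, 1, 3, padding=1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

lr_img = torch.randn(4, 1, 16, 16)   # fake low-res batch
hr_img = torch.randn(4, 1, 16, 16)   # fake high-res target

optimizer.zero_grad()
loss = criterion(model(lr_img), hr_img)
loss.backward()
# Rescale gradients so their global norm is at most 0.4, guarding against
# exploding gradients in very deep recursive networks
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.4)
optimizer.step()
```

`clip_grad_norm_` rescales all gradients in place when their combined norm exceeds `max_norm`, so it must be called between `backward()` and `step()`.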
import torch
from torch import nn

class DRRN(nn.Module):
    # The paper suggests choosing d = (1+2*num_resid_units)*num_recur_blocks+1 to be around 20
    def __init__(self, num_recur_blocks=1, num_resid_units=15, num_filter=128, filter_size=3):
        super(DRRN, self).__init__()
        # Chain of recursive blocks
        seq = []
        for i in range(num_recur_blocks):
            if i == 0:
                # The first recursive block maps the single input channel to num_filter
                seq.append(RecursiveBlock(
                    num_resid_units, 1, num_filter, filter_size))
            else:
                seq.append(RecursiveBlock(num_resid_units,
                                          num_filter, num_filter, filter_size))
        self.residual_blocks = nn.Sequential(*seq)
        # Final output convolution
        self.last_conv = nn.Conv2d(
            num_filter, 1, filter_size, padding=filter_size//2)

    def forward(self, img):
        skip = img
        img = self.residual_blocks(img)
        img = self.last_conv(img)
        # Global residual connection
        img = skip + img
        return img
class RecursiveBlock(nn.Module):
    # DRCN-style recursion: the residual units inside one RecursiveBlock share weights
    def __init__(self, num_resid_units=3, input_channel=128, output_channel=128, filter_size=3):
        super(RecursiveBlock, self).__init__()
        self.num_resid_units = num_resid_units
        # Entry convolution of the recursive block
        self.input_conv = nn.Conv2d(
            input_channel, output_channel, filter_size, padding=filter_size//2)
        self.residual_unit = nn.Sequential(
            # Two conv groups, each with its own activation and 1x1 weighting
            nn.Conv2d(output_channel, output_channel, filter_size,
                      padding=filter_size//2),
            nn.BatchNorm2d(output_channel),
            nn.ReLU(True),
            nn.Conv2d(output_channel, output_channel, 1),
            nn.Conv2d(output_channel, output_channel, filter_size,
                      padding=filter_size//2),
            nn.BatchNorm2d(output_channel),
            nn.ReLU(True),
            nn.Conv2d(output_channel, output_channel, 1)
        )

    def forward(self, x_b):
        x_b = self.input_conv(x_b)
        skip = x_b
        # Apply the SAME residual unit repeatedly (weight sharing)
        for i in range(self.num_resid_units):
            x_b = self.residual_unit(x_b)
            x_b = skip + x_b
        return x_b
Deep Laplacian Pyramid Networks for Fast and Accurate Super-Resolution
LapSRN
Borrowing the network diagram from 【超分辨率】Laplacian Pyramid Networks(LapSRN): the key features of LapSRN are its recursive, parameter-shared structure, which keeps the parameter count low, and its progressive level-by-level upsampling, which keeps reconstruction stable even at large upscaling factors.
Thanks to the shared recursive weights and the pyramid structure that splits the flow into a residual branch and an image branch, the network runs efficiently and is not hard to train; it is well worth studying.
Charbonnier loss
The network also uses the Charbonnier loss, a variant of the L1 loss that the authors claim yields reconstructions less blurry than those trained with the MSE loss; in my tests the practical difference felt limited.
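The Charbonnier loss is easy to reproduce: it is a smooth, differentiable relative of L1, sqrt(diff² + ε²) averaged over the tensor. A minimal sketch (ε=1e-3 follows the LapSRN paper; the class name is my own):

```python
import torch
from torch import nn

class CharbonnierLoss(nn.Module):
    # Differentiable variant of L1: mean(sqrt(diff^2 + eps^2))
    def __init__(self, eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, pred, target):
        diff = pred - target
        return torch.mean(torch.sqrt(diff * diff + self.eps * self.eps))
```

Unlike plain L1, the loss is differentiable at zero (where it evaluates to ε), which makes optimization smoother near convergence.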
import math
import torch
from torch import nn

class LapSRN(nn.Module):
    def __init__(self, fea_chan=64, scale=2, conv_num=3):
        super(LapSRN, self).__init__()
        # One 2x pyramid level per octave of the requested scale
        # (each level upsamples by exactly 2; the levels multiply up to `scale`)
        self.level_num = int(math.log2(scale))
        # Layers whose names contain "share" share parameters across levels
        self.share_ski_upsample = nn.ConvTranspose2d(
            1, 1, 4, stride=2, padding=1)
        self.input_conv = nn.Conv2d(1, fea_chan, 3, padding=1)
        seq = []
        for _ in range(conv_num):
            seq.append(nn.Conv2d(fea_chan, fea_chan, 3, padding=1))
            seq.append(nn.LeakyReLU(0.2, True))
        self.share_embedding = nn.Sequential(*seq)
        self.share_fea_upsample = nn.ConvTranspose2d(
            fea_chan, fea_chan, 4, stride=2, padding=1)
        self.share_output_conv = nn.Conv2d(fea_chan, 1, 3, padding=1)

    def forward(self, img):
        # Keep two streams: a feature stream feeding the residual branch
        # and an image stream that is upsampled level by level
        tmp = self.input_conv(img)
        for _ in range(self.level_num):
            skip = self.share_ski_upsample(img)
            img = self.share_embedding(tmp)
            img = self.share_fea_upsample(img)
            tmp = img
            img = self.share_output_conv(img)
            img = img + skip
        return img
Image Super-Resolution Using Dense Skip Connections
SRDenseNet
SRDenseNet has a relatively involved structure that combines dense blocks with residual-style skip connections. Its full form is the structure in the figure above: several dense blocks are linked with skip connections to extract deep, effective features, a bottleneck layer then reduces the accumulated channel count, and a final deconvolution produces the super-resolved output. A small change to the hyper-parameters can make this network grow drastically in size; although dense blocks reuse earlier features and make deep networks easier to train, the dense and skip connections consume a lot of GPU memory and slow down both training and inference. Since I did not use a dataset as large as the paper's (roughly 50k images), my local results were not very good.
DenseNet
A dense block has the structure shown above and is widely described online; its core trait is that each convolution layer's output is concatenated onto the features passed to every later layer, which makes the network easier to train.
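The channel growth produced by this concatenation can be seen in a tiny standalone sketch (the channel sizes here are illustrative only, much smaller than the real network's):

```python
import torch
from torch import nn

growth_rate = 4
x = torch.randn(1, 8, 16, 16)        # start with 8 feature channels
# Each layer must accept all channels accumulated so far
layers = [nn.Conv2d(8 + i * growth_rate, growth_rate, 3, padding=1)
          for i in range(3)]

for conv in layers:
    # Each layer sees ALL previous feature maps and appends growth_rate new ones
    x = torch.cat((x, conv(x)), dim=1)

# Channels grow from 8 to 8 + 3*4 = 20
print(x.shape)  # torch.Size([1, 20, 16, 16])
```

This is exactly why the bottleneck layer below is needed: the channel count grows linearly with depth and must be compressed before the expensive deconvolutions.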
class SRDenseNet(nn.Module):
    def __init__(self, scale=4, dense_input_chan=16, growth_rate=16, bottleneck_channel=256, num_dense_conv=8, num_dense=8):
        super(SRDenseNet, self).__init__()
        self.dense_input_chan = dense_input_chan
        self.growth_rate = growth_rate
        self.num_dense_conv = num_dense_conv
        self.num_dense = num_dense
        # Input layer, converts the channel count to dense_input_chan
        self.input = nn.Sequential(
            nn.Conv2d(1, dense_input_chan, 3, padding=1),
            nn.ReLU(True)
        )
        # Dense stage: several dense blocks with skip connections;
        # output channels: num_dense*num_dense_conv*growth_rate+dense_input_chan
        seq = []
        for i in range(num_dense):
            seq.append(DenseBlock((i*num_dense_conv*growth_rate) +
                                  dense_input_chan, growth_rate, num_dense_conv))
        self.dense_blocks = nn.Sequential(*seq)
        # Bottleneck layer shrinking the channel count to bottleneck_channel
        self.bottleneck = bottleneck_layer(
            num_dense*num_dense_conv*growth_rate+dense_input_chan, bottleneck_channel)
        # Deconvolution layers for upsampling, channels stay at bottleneck_channel
        # (output_padding=1 makes each transposed conv an exact 2x upsampling)
        seq = []
        for _ in range(scale//2):
            seq.append(nn.ConvTranspose2d(bottleneck_channel, bottleneck_channel,
                                          3, stride=2, padding=1, output_padding=1))
        self.deconv = nn.Sequential(*seq)
        # Output layer, 1 output channel
        self.output = nn.Conv2d(bottleneck_channel, 1, 3, padding=1)

    def forward(self, x):
        x = self.input(x)
        dense_skip = x
        for i in range(self.num_dense):
            x = self.dense_blocks[i](x)
            # Dense skip connection: keep concatenating the newly produced channels
            dense_skip = torch.cat(
                (dense_skip, x[:, (i*self.num_dense_conv*self.growth_rate)+self.dense_input_chan:, :, :]), dim=1)
        x = self.bottleneck(dense_skip)
        x = self.deconv(x)
        x = self.output(x)
        return x
def bottleneck_layer(in_channel, out_channel):
    # 1x1 convolution that compresses the channel count
    return nn.Sequential(
        nn.Conv2d(in_channel, out_channel, 1),
    )

def conv_layer(in_channel, out_channel):
    # BN-ReLU-Conv ordering used inside the dense blocks
    return nn.Sequential(
        nn.BatchNorm2d(in_channel),
        nn.ReLU(True),
        nn.Conv2d(in_channel, out_channel, 3, padding=1),
    )
class DenseBlock(nn.Module):
    def __init__(self, in_channel=16, growth_rate=16, num_convs=8):
        super(DenseBlock, self).__init__()
        self.num_convs = num_convs
        seq = []
        for _ in range(num_convs):
            # The feature map keeps growing as outputs are concatenated
            seq.append(conv_layer(in_channel, growth_rate))
            in_channel = in_channel + growth_rate
        self.convs = nn.Sequential(*seq)

    def forward(self, x):
        for i in range(self.num_convs):
            # Concatenate the new features onto everything produced so far
            y = self.convs[i](x)
            x = torch.cat((x, y), dim=1)
        return x
Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
SRGAN
A GAN trains two networks together: a generator that produces the super-resolved image, and a discriminator that judges whether a generated image looks like a real high-resolution one. The two are trained alternately within each iteration: the discriminator's output serves as a loss that sharpens the generator, and the stronger generator in turn makes the discriminator's job harder. This adversarial structure yields visually more convincing super-resolution results, although the quantitative metrics fall short of other methods.
SRGAN损失
Beyond the interplay of the two networks, the core of GAN-based super-resolution is the loss that ties them together. The paper defines the generator loss as a weighted sum of the MSE loss and a term derived from the discriminator's classification probability, as in the equation above; changing this weight shifts the output between optimizing the MSE metric and matching human perception.
class SRGAN_generator(nn.Module):
    # Based on SRResNet; produces the super-resolved image
    def __init__(self, scale=4, in_channel=3, num_filter=64, num_resiblk=16):
        super(SRGAN_generator, self).__init__()
        self.num_filter = num_filter
        self.num_resiblk = num_resiblk
        self.input_conv = nn.Sequential(
            nn.Conv2d(in_channel, num_filter, 9, padding=4),
            nn.PReLU()
        )
        # A long chain of residual blocks
        seq = []
        for _ in range(num_resiblk):
            seq.append(nn.Sequential(
                nn.Conv2d(num_filter, num_filter, 3, padding=1),
                nn.BatchNorm2d(num_filter),
                nn.PReLU(),
                nn.Conv2d(num_filter, num_filter, 3, padding=1),
                nn.BatchNorm2d(num_filter),
            ))
        self.residual_blocks = nn.Sequential(*seq)
        self.resi_out = nn.Sequential(
            nn.Conv2d(num_filter, num_filter, 3, padding=1),
            nn.BatchNorm2d(num_filter),
        )
        # Upsampling via sub-pixel convolution
        seq = []
        for _ in range(scale//2):
            seq.append(nn.Sequential(
                nn.Conv2d(num_filter, num_filter*4, 3, padding=1),
                nn.PixelShuffle(2),
                nn.PReLU()
            ))
        self.upsample = nn.Sequential(*seq)
        self.output_conv = nn.Conv2d(num_filter, in_channel, 3, padding=1)

    def forward(self, x):
        x = self.input_conv(x)
        # Two kinds of residual connections: per-block and global
        skip = x
        resi_skip = x
        for i in range(self.num_resiblk):
            x = self.residual_blocks[i](x) + resi_skip
            resi_skip = x
        x = self.resi_out(x) + skip
        x = self.upsample(x)
        return self.output_conv(x)
class SRGAN_discriminator(nn.Module):
    # VGG19-style network that judges whether the current image is real
    def __init__(self, in_channel=3):
        super(SRGAN_discriminator, self).__init__()
        self.input_conv = nn.Sequential(
            nn.Conv2d(in_channel, 64, 3, padding=1),
            nn.LeakyReLU(inplace=True)
        )
        # A deep stack of convolutions to extract features
        self.convs = nn.Sequential(
            conv_layer(64, 64, 2),
            conv_layer(64, 128, 1),
            conv_layer(128, 128, 2),
            conv_layer(128, 256, 1),
            conv_layer(256, 256, 2),
            conv_layer(256, 512, 1),
            conv_layer(512, 512, 2)
        )
        self.output_conv = nn.Sequential(
            # Pooling plus 1x1 convolutions reduce the features to a single score
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, 1, padding=0),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, 1, padding=0)
        )

    def forward(self, x):
        batch_size = x.size(0)
        x = self.input_conv(x)
        x = self.convs(x)
        x = self.output_conv(x)
        # Note the final activation of a classification network
        return torch.sigmoid(x.view(batch_size))
def conv_layer(in_channel, out_channel, stride):
    return nn.Sequential(
        nn.Conv2d(in_channel, out_channel, 3, stride=stride, padding=1),
        nn.BatchNorm2d(out_channel),
        nn.LeakyReLU(inplace=True)
    )
class Adversarial_loss(nn.Module):
    # Generator loss: a combination of the content loss and the adversarial term
    def __init__(self, disc_alpha=1e-3):
        super(Adversarial_loss, self).__init__()
        self.alpha = disc_alpha
        self.mse_loss = nn.MSELoss()

    def forward(self, X, Y, prob_disc):
        # Content loss on the image itself is MSE
        loss = self.mse_loss(X, Y)
        # Adversarial term from the discriminator's probabilities; do NOT
        # detach it here, the gradient must flow back into the generator
        loss_disc = torch.sum(-torch.log(prob_disc))
        # Weighted combination
        loss = loss + self.alpha * loss_disc
        return loss
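The alternating training described earlier can be sketched end to end. The tiny stand-in networks, random tensors, and learning rates below are placeholders of mine, not the paper's configuration; the labels follow the usual real=1 / fake=0 convention:

```python
import torch
from torch import nn

# Tiny stand-ins for SRGAN_generator / SRGAN_discriminator
G = nn.Conv2d(3, 3, 3, padding=1)
D = nn.Sequential(nn.Conv2d(3, 1, 3, padding=1),
                  nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Sigmoid())
opt_g = torch.optim.Adam(G.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(D.parameters(), lr=1e-4)
bce = nn.BCELoss()
mse = nn.MSELoss()

lr_img = torch.randn(4, 3, 16, 16)   # fake "low-res" batch (same size, for brevity)
hr_img = torch.randn(4, 3, 16, 16)   # fake "high-res" batch
real = torch.ones(4, 1)
fake = torch.zeros(4, 1)

# 1) Discriminator step: push D(real)->1 and D(G(lr))->0;
#    detach so this step does not update the generator
opt_d.zero_grad()
loss_d = bce(D(hr_img), real) + bce(D(G(lr_img).detach()), fake)
loss_d.backward()
opt_d.step()

# 2) Generator step: MSE content loss plus the weighted adversarial term,
#    with gradient flowing through D back into G
opt_g.zero_grad()
sr = G(lr_img)
loss_g = mse(sr, hr_img) + 1e-3 * torch.sum(-torch.log(D(sr) + 1e-8))
loss_g.backward()
opt_g.step()
```

The `detach()` in step 1 is the crucial asymmetry: the discriminator trains on frozen generator outputs, while in step 2 the generator receives gradients through the discriminator.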
Enhanced Deep Residual Networks for Single Image Super-Resolution
EDSR
EDSR is an improvement on SRResNet. The main change is removing SRResNet's many batch-norm layers, which the paper argues greatly restrict the network's flexibility; with batch norm gone, a scaling layer is appended to each residual block to damp the residual output, reducing the numerical instability that deep stacks of residual blocks are prone to. In my tests EDSR performed well: training is reasonably fast and the results are good.
class EDSR(nn.Module):
    def __init__(self, scale=4, in_channel=3, num_filter=256, num_resiblk=32, resi_scale=0.1):
        super(EDSR, self).__init__()
        self.num_filter = num_filter
        self.num_resiblk = num_resiblk
        self.resi_scale = resi_scale
        self.input_conv = nn.Sequential(
            nn.Conv2d(in_channel, num_filter, 9, padding=4),
        )
        # Many residual blocks; BN removed, and no ReLU outside the residual branch
        seq = []
        for _ in range(num_resiblk):
            seq.append(nn.Sequential(
                nn.Conv2d(num_filter, num_filter, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(num_filter, num_filter, 3, padding=1),
            ))
        self.residual_blocks = nn.Sequential(*seq)
        self.resi_out = nn.Sequential(
            nn.Conv2d(num_filter, num_filter, 3, padding=1),
        )
        # Upsampling via sub-pixel convolution
        seq = []
        for _ in range(scale//2):
            seq.append(nn.Sequential(
                nn.Conv2d(num_filter, num_filter*4, 3, padding=1),
                nn.PixelShuffle(2),
            ))
        self.upsample = nn.Sequential(*seq)
        self.output_conv = nn.Conv2d(num_filter, in_channel, 3, padding=1)

    def forward(self, x):
        x = self.input_conv(x)
        # Per-block and global residual connections
        skip = x
        resi_skip = x
        for i in range(self.num_resiblk):
            # Scale each block's residual output by resi_scale for stability
            x = self.residual_blocks[i](x) * self.resi_scale + resi_skip
            resi_skip = x
        x = self.resi_out(x) + skip
        x = self.upsample(x)
        return self.output_conv(x)
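The sub-pixel (pixel-shuffle) upsampling used by both SRGAN and EDSR above can be checked in isolation: a convolution expands the channel count by a factor of 2², then nn.PixelShuffle(2) rearranges those channels into a feature map with twice the spatial size. The filter count below is illustrative:

```python
import torch
from torch import nn

up = nn.Sequential(
    nn.Conv2d(64, 64 * 4, 3, padding=1),  # expand channels by 2^2
    nn.PixelShuffle(2),                   # rearrange them into 2x spatial size
)
x = torch.randn(1, 64, 24, 24)
y = up(x)
print(y.shape)  # torch.Size([1, 64, 48, 48])
```

Stacking one such stage per octave (as the `scale//2` loops above do for 2x and 4x) reaches the requested overall scale without any transposed convolutions.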