UNET可以理解为FCN的一种扩展。 相关资料:
1. U-Net网络架构
U-Net这个命名是很形象的,因为它的架构看起来就是个“U”,让人记忆深刻:
左侧可以理解为编码器,右侧可以理解为解码器。编码器又分为4个子模块,每个子模块包含2个卷积层和1个max pool下采样层,编码器同样分为4个子模块,每个子模块也是2个卷积层和1个上采样层。下采样的时候,分辨率减半,上采样的时候分辨率乘以2,但是这并不代表该网络输入和输出的分辨率是一样的,因为每次经过卷积层分辨率都在减少。另外,该网络还使用了跳层连接,处于同一水平上的模块连接在一起,类似RenNet的残差模块。从这点上看,它应该是比较适用于去噪去水印之类的任务的。
模型的最后使用1*1的卷积核,将通道数映射成和类别数一致。
对比一下FCN,它就是不断的下采样(特征图分辨率减少,通道数增加),最后用一个上采样还原分辨率,核心思想和UNet已经很接近了。
2. U-Net实现
看文字描述比较晕,还是看代码比较清晰:
# https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py
class UNet(nn.Module):
def __init__(self, n_channels, n_classes, bilinear=True):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
factor = 2 if bilinear else 1
self.down4 = Down(512, 1024 // factor)
self.up1 = Up(1024, 512 // factor, bilinear)
self.up2 = Up(512, 256 // factor, bilinear)
self.up3 = Up(256, 128 // factor, bilinear)
self.up4 = Up(128, 64, bilinear)
self.outc = OutConv(64, n_classes)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
这个实现个人觉得并不是太好,至少模块化上并不是很清晰。相关的基础模块都在https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py 这里实现,有兴趣可以自己看看。
我比较关心上采样是怎么实现的:
# https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
# ConvTranspose2d逆卷积
self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
显然,代码里有两种上采样的方式:
注意:在跳层连接的时候,只是使用torch.cat将两个张量拼接起来,这样计算简单。
这个代码的实现看起来跟论文的不太一样,它是对当前的特征图进行padding,以匹配左边对应层次的特征图,这样对应层次的特征图的分辨率就是一样的,整体上的输入输出就是一样的。
这种padding的方式是否真的好呢?毕竟图像分割是需要像素级别的,如果使用一个上采样来完成padding的操作会怎么样?我觉得会更好,只是计算量加大了。
3. Overlap-tile策略
这有一个说明文章:https://www.zhihu.com/question/268331470/answer/368865906
这个策略其实也不能理解:
这个策略有一个背景,那就是医学图像的size可能是非常大的,基本不太可能一次直接放进模型去做预测,于是作者搞出了这样一个策略,应该是类似滑动窗口的意思。例如当我们需要预测上图(左)中的橙色矩形时,我们就将该图的周边一起作为输入。
不过这样也会带来一个问题,就是如果橙色框刚好就是原图的边界怎么办(如上图右)?按照普通的策略,那也很简单,直接加空白padding嘛,但是这可能并不是最好的,于是论文又提出了一种镜像padding的方式。在上图中,从右往左看应该就能理解什么是镜像padding了。直接加空白padding,会造成边界的不连续,如果以镜像的方式加padding则基本能解决这个问题。
4. 怎么训练
要想理解怎么训练,那我们最重要的是要理解损失函数。我们看UNet的网络结构,输入的是572*572,输出的是388*388,size并不一样,那该怎么计算损失呢?当然,size要处理成一样也很容易,一次缩放就可以了,但是UNet也是这样的么?
从模型最后输出的通道数和类别数相等来看,因为unet本来是用来做实例分割的,应该是将每个类别做成了一个二值mask(属于该类别的像素点为1,其余为零),这样就能对每个通道计算损失,加起来也就是总的损失。总体思路应该是这样的,关键是每个通道的特征图和对应mask怎么计算损失呢?
我们看原论文,或者看一些网上的文章,说得好像挺复杂的,如果看代码实现,好简单的:
# https://github.com/milesial/Pytorch-UNet/blob/master/train.py
if net.n_classes > 1:
criterion = nn.CrossEntropyLoss()
else:
criterion = nn.BCEWithLogitsLoss()
# somethings ...
loss = criterion(masks_pred, true_masks)
这个实现,计算loss其实就是直接计算真实mask和预测mask的交叉熵。不过作者论文说的是使用了带边界权重的Softmax:
增大边界像素点的权重,显然是有助于模型对边缘像素的学习的。不过我看第三方的代码实现,并没有实现这个,可能是计算量问题?对于权重的计算,论文有给出其计算公式:
这个公式里,d1和d2分别代表像素x与最近的两个细胞(UNet是用于医学图像分割,所说的边界也是指细胞的边界)的距离(一个像素到细胞的距离就可能有两种定义方式,一种是到细胞中心的距离,另一种是到细胞边缘的最近距离,作者论文用的可能是后一种),但问题是为什么要设置两个距离值呢?因为如果用两个边界值的话,一个细胞如果距离别的细胞比较远,那它的边界可能权重也不高(可能是医学图像上,这种情况出现的概率很低吧)。
显然,d1和d2这两个值越大,exp的值就越小越接近0;而这两个值如果越小越接近零,exp的值就接近1。
我觉得这个权重的算法并不是很好,计算量很大。如果我来定义,就直接先检测细胞的边界,然后根据像素点到边界的距离来设置权重。不过即使这种,计算量也不小,而且怎么并行到GPU上去计算还是个问题。
5. 网络的输入输出分辨率不同的问题
我们回头看UNet的网络结构,输入是572*572,输出是388*388,原始的输入其实是512*512,因为上下左右都有30像素的padding。即使这样,512和388的差距该怎么理解呢?难不成mask已经缩放成388*388了?
6. 小结
还有几个问题是没有解决的,等我理解了再补充。不过我想论文的这些参数可能作者实验下的最佳选择,迁移到我们的场景下,可能未必是最佳的,我们完全可以自己的构造一个适合我们场景的UNet。
对于UNet,我觉得有几点收获:
20201016