前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >基于Pytorch的CapsNet源码详解CapsNet基本结构代码实现参考

基于Pytorch的CapsNet源码详解CapsNet基本结构代码实现参考

作者头像
月见樽
发布2018-04-27 11:07:43
1.5K0
发布2018-04-27 11:07:43
举报

CapsNet基本结构

参考CapsNet的论文,提出的基本结构如下所示:

capsnet_mnist.jpg

可以看出,CapsNet的基本结构如下所示:

  • 普通卷积层Conv1:基本的卷积层,感受野较大,达到了9x9
  • 预胶囊层PrimaryCaps:为胶囊层准备,运算为卷积运算,最终输出为[batch,caps_num,caps_length]的三维数据:
    • batch为批大小
    • caps_num为胶囊的数量
    • caps_length为每个胶囊的长度(每个胶囊为一个向量,该向量包括caps_length个分量)
  • 胶囊层DigitCaps:胶囊层,目的是代替最后一层全连接层,输出为10个胶囊

代码实现

胶囊相关组件

激活函数Squash

胶囊网络有特有的激活函数Squash函数: $$ Squash(S) = \cfrac{||S||2}{1+||S||2} \cdot \cfrac{S}{||S||} $$ 其中输入为S胶囊,该激活函数可以将胶囊的长度压缩,代码实现如下:

代码语言:javascript
复制
def squash(inputs, axis=-1):
    norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)
    scale = norm**2 / (1 + norm**2) / (norm + 1e-8)
    return scale * inputs

其中:

  • norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)计算输入胶囊的长度,p=2表示计算的是二范数,keepdim=True表示保持原有的空间形状。
  • scale = norm**2 / (1 + norm**2) / (norm + 1e-8)计算缩放因子,即$ \cfrac{||S||2}{1+||S||2} \cdot \cfrac{1}{||S||}$
  • return scale * inputs完成计算

预胶囊层PrimaryCaps

代码语言:javascript
复制
class PrimaryCapsule(nn.Module):
    """
    Apply Conv2D with `out_channels` and then reshape to get capsules
    :param in_channels: input channels
    :param out_channels: output channels
    :param dim_caps: dimension of capsule
    :param kernel_size: kernel size
    :return: output tensor, size=[batch, num_caps, dim_caps]
    """
    def __init__(self, in_channels, out_channels, dim_caps, kernel_size, stride=1, padding=0):
        super(PrimaryCapsule, self).__init__()
        self.dim_caps = dim_caps
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        outputs = self.conv2d(x)
        outputs = outputs.view(x.size(0), -1, self.dim_caps)
        return squash(outputs)

预胶囊层使用卷积层实现,其前向传播包括三个部分:

  • outputs = self.conv2d(x):对输入进行卷积处理,这一步output的形状是[batch,out_channels,p_w,p_h]
  • outputs = outputs.view(x.size(0), -1, self.dim_caps):将4D的卷积输出变为3D的胶囊输出形式,output的形状为[batch,caps_num,dim_caps],其中caps_num为胶囊数量,可自动计算;dim_caps为胶囊长度,需要预先指定。
  • return squash(outputs):激活函数,并返回激活后的胶囊

胶囊层DigitCaps

参数定义
代码语言:javascript
复制
def __init__(self, in_num_caps, in_dim_caps, out_num_caps, out_dim_caps, routings=3):
    super(DenseCapsule, self).__init__()
    self.in_num_caps = in_num_caps
    self.in_dim_caps = in_dim_caps
    self.out_num_caps = out_num_caps
    self.out_dim_caps = out_dim_caps
    self.routings = routings
    self.weight = nn.Parameter(0.01 * torch.randn(out_num_caps, in_num_caps, out_dim_caps, in_dim_caps))

参数定义如下:

  • in_num_caps:输入胶囊的数量
  • in_dim_caps:输入胶囊的长度(维数)
  • out_num_caps:输出胶囊的数量
  • out_dim_caps:输出胶囊的长度(维数)
  • routings:动态路由迭代的次数

另外,还定义了权值weight,尺寸为[out_num_caps, in_num_caps, out_dim_caps, in_dim_caps],即每个输出和每个输出胶囊都有连接

前向传播
代码语言:javascript
复制
def forward(self, x):
    x_hat = torch.squeeze(torch.matmul(self.weight, x[:, None, :, :, None]), dim=-1)
    x_hat_detached = x_hat.detach()

    b = Variable(torch.zeros(x.size(0), self.out_num_caps, self.in_num_caps)).cuda()
    assert self.routings > 0, 'The \'routings\' should be > 0.'
    for i in range(self.routings):
        c = F.softmax(b, dim=1)
        if i == self.routings - 1:
            outputs = squash(torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True))
        else:
            outputs = squash(torch.sum(c[:, :, :, None] * x_hat_detached, dim=-2, keepdim=True))
            b = b + torch.sum(outputs * x_hat_detached, dim=-1)
    return torch.squeeze(outputs, dim=-2)

前向传播分为两个部分:输入映射和动态路由。输入映射如下所示:

  1. x_hat = torch.squeeze(torch.matmul(self.weight, x[:, None, :, :, None]), dim=-1)
    • x[:, None, :, :, None]将数据维度从[batch, in_num_caps, in_dim_caps]扩展到[batch, 1,in_num_caps, in_dim_caps,1]
    • torch.matmul()将weight和扩展后的输入相乘,weight的尺寸是[out_num_caps, in_num_caps, out_dim_caps, in_dim_caps],相乘后结果尺寸为[batch, out_num_caps, in_num_caps,out_dim_caps, 1]
    • torch.squeeze()去除多余的维度,去除后结果尺寸[batch,out_num_caps,in_num_caps,out_dim_caps]
  2. x_hat_detached = x_hat.detach()截断梯度反向传播

这一部分结束后,每个输入胶囊都产生了out_num_caps个输出胶囊,所以目前共有in_num_caps*out_num_caps个胶囊,第二部分是动态路由,动态路由的算法图如下所示:

dynamic_route.jpg

以下部分实现了该过程:

代码语言:javascript
复制
b = Variable(torch.zeros(x.size(0), self.out_num_caps, self.in_num_caps)).cuda()
    for i in range(self.routings):
        c = F.softmax(b, dim=1)
        if i == self.routings - 1:
            outputs = squash(torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True))
        else:
            outputs = squash(torch.sum(c[:, :, :, None] * x_hat_detached, dim=-2, keepdim=True))
            b = b + torch.sum(outputs * x_hat_detached, dim=-1)
  1. 第一部分是softmax函数,使用c = F.softmax(b, dim=1)实现,该步骤不改变b的尺寸
  2. 第二部分是计算路由结果:outputs = squash(torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True))
    • c[:, :, :, None]扩展c的维度,以便按位置相乘时广播维度
    • torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True)计算出每个胶囊与对应权值的积,即算法中的$s_j$,同时在倒数第二维上求和,则该步输出的结果尺寸为[batch, out_num_caps, 1,out_dim_caps]
    • 通过激活函数squash()
  3. 第三部分更新权重b = b + torch.sum(outputs * x_hat_detached, dim=-1),两个按位相乘的变量尺寸分别为[batch, out_num_caps, in_num_caps, out_dim_caps]和[batch, out_num_caps, 1,out_dim_caps],倒数第二维上有广播行为,因此最终结果为[batch, out_num_caps, in_num_caps]

其他组件

网络结构

代码语言:javascript
复制
class CapsuleNet(nn.Module):
    """
    A Capsule Network on MNIST.
    :param input_size: data size = [channels, width, height]
    :param classes: number of classes
    :param routings: number of routing iterations
    Shape:
        - Input: (batch, channels, width, height), optional (batch, classes) .
        - Output:((batch, classes), (batch, channels, width, height))
    """
    def __init__(self, input_size, classes, routings):
        super(CapsuleNet, self).__init__()
        self.input_size = input_size
        self.classes = classes
        self.routings = routings

        # Layer 1: Just a conventional Conv2D layer
        self.conv1 = nn.Conv2d(input_size[0], 256, kernel_size=9, stride=1, padding=0)

        # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_caps, dim_caps]
        self.primarycaps = PrimaryCapsule(256, 256, 8, kernel_size=9, stride=2, padding=0)

        # Layer 3: Capsule layer. Routing algorithm works here.
        self.digitcaps = DenseCapsule(in_num_caps=32*6*6, in_dim_caps=8,
                                      out_num_caps=classes, out_dim_caps=16, routings=routings)

        # Decoder network.
        self.decoder = nn.Sequential(
            nn.Linear(16*classes, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, input_size[0] * input_size[1] * input_size[2]),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU()

    def forward(self, x, y=None):
        x = self.relu(self.conv1(x))
        x = self.primarycaps(x)
        x = self.digitcaps(x)
        length = x.norm(dim=-1)
        if y is None:  # during testing, no label given. create one-hot coding using `length`
            index = length.max(dim=1)[1]
            y = Variable(torch.zeros(length.size()).scatter_(1, index.view(-1, 1).cpu().data, 1.).cuda())
        reconstruction = self.decoder((x * y[:, :, None]).view(x.size(0), -1))
        return length, reconstruction.view(-1, *self.input_size)

网络组件包括两个部分:胶囊网络和重建网络,重建网络为多层感知机,根据胶囊的结果重建了图像,这表示胶囊除了包括结果外,还可以包括一些空间信息。

注意胶囊网络的前向传播部分为:

代码语言:javascript
复制
x = self.relu(self.conv1(x))
x = self.primarycaps(x)
x = self.digitcaps(x)
length = x.norm(dim=-1)

最终的输出为每个胶囊的二范数,即向量的长度

代价函数

胶囊神经网络的胶囊部分的代价函数如下所示 $$ L_c = T_c max(0,m^+ - ||V_c||)^2 + \lambda (1 - T_c)max(0,||v_c|| - m^-) ^ 2 $$

以下代码实现了这个部分,其中L为胶囊的代价函数计算,这里$m+=0.9,m-=0.1$,L_recon为重建的代价函数,为输入图像与复原图像的MSELoss函数。

代码语言:javascript
复制
def caps_loss(y_true, y_pred, x, x_recon, lam_recon):
    L = y_true * torch.clamp(0.9 - y_pred, min=0.) ** 2 + \
        0.5 * (1 - y_true) * torch.clamp(y_pred - 0.1, min=0.) ** 2
    L_margin = L.sum(dim=1).mean()
    L_recon = nn.MSELoss()(x_recon, x)
    return L_margin + lam_recon * L_recon

参考

CapsNet论文

CapsNet开源代码

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2018.04.17 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • CapsNet基本结构
  • 代码实现
    • 胶囊相关组件
      • 激活函数Squash
      • 预胶囊层PrimaryCaps
      • 胶囊层DigitCaps
    • 其他组件
      • 网络结构
      • 代价函数
  • 参考
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档