专栏首页磐创AI技术团队的专栏使用PyTorch进行情侣幸福度测试指南

使用PyTorch进行情侣幸福度测试指南

DeepConnection模型框架

计算机视觉--图像和视频数据分析是深度学习目前最火的应用领域之一。因此,在学习深度学习的同时尝试运用某些计算机视觉技术做些有趣的事情会很有意思,也会让你发现些令人吃惊的事实。长话短说,我的搭档(Maximiliane Uhlich)和我决定将深度学习应用于浪漫情侣的形象分类上,因为Maximiliane是一位关系研究员和情感治疗师。具体来说,我们想知道我们是否可以准确地判断图像或视频中描绘的情侣是否对他们的关系感到满意?事实证明,我们可以!我们的最终模型(我们称之为DeepConnection)分类准确率接近97%,能够准确地区分幸福与不幸福的情侣。大家可以在我们的论文预览链接[1]里阅读完整介绍,上图是我们为这个任务设计的框架草图。

在数据集收集方面,我们使用这个Python脚本[2]进行网页数据抽取(webscraping)来获取幸福和不幸福的情侣数据。最后,我们整理出了大约包含1000张图像的训练集。这并不是特别多,所以我们使用数据增强与迁移学习来增强我们模型在数据集上的表现。数据增强--图像方向的微小变化,色调和色彩强度以及许多其他因素都会增强模型的泛化能力,从而避免学习一些不相关信息。例如,如果数据中幸福夫妻的图像平均比不幸福夫妻的图像更亮,我们并不希望我们的模型映射这种关联。我们使用了强大的ImgAug库[3]进行了相当多策略的数据扩充,以确保我们模型的鲁棒性。基本上对于每个批次的每个图像,我们都至少应用多种数据增强技术。下图是一张图片应用了48种数据增强策略的示例。

图像增强后数据示例

我们决定使用ResNet模型作为DeepConnection的基础网络,在大型数据集ImageNet上预先训练。通过预训练,模型已经具有了一定的识别能力。我们所有的模型都借用PyTorch实现,我们使用Google Colab上的免费GPU资源进行训练和测试。这个基础模型本身已经具备了良好的分类能力,但我们决定更进一步,用空间金字塔池化层(SPP)[4] 替换ResNet-34基础模型的最后一个自适应池模块。这里,处理后的图像数据被分成不同数量的正方形,并且仅传递最大值以进行进一步分析(最大池化)。这使得模型可以专注于重要的特征,使其对不同大小的图像具有鲁棒性,并且不受图像扰动的影响。之后,我们放置了一个均值变换(PMT)层[5],用数学函数转换数据以引入非线性,使得DeepConnection可以从数据中捕获更复杂的关系。这两个模块均提高了我们的分类准确度,我们在单独的验证集上得到了大约97%准确率。SPP / PMT和后续分类层的代码如下所示:

class SPP(nn.Module):
  def __init__(self):
    super(SPP, self).__init__()

    ## features incoming from ResNet-34 (after SPP/PMT)
    self.lin1 = nn.Linear(2*43520, 100)

    self.relu = nn.ReLU()
    self.bn1 = nn.BatchNorm1d(100)
    self.dp1 = nn.Dropout(0.5)
    self.lin2 = nn.Linear(100, 2)

  def forward(self, x):
    # SPP
    x = spatial_pyramid_pool(x, x.shape[0], [x.shape[2], x.shape[3]], [8, 4, 2, 1])

    # PMT
    x_1 = torch.sign(x)*torch.log(1 + abs(x))
    x_2 = torch.sign(x)*(torch.log(1 + abs(x)))**2
    x = torch.cat((x_1, x_2), dim = 1)

    # fully connected classification part
    x = self.lin1(x)
    x = self.bn1(self.relu(x))

    #1
    x1 = self.lin2(self.dp1(x))
    #2
    x2 = self.lin2(self.dp1(x))
    #3
    x3 = self.lin2(self.dp1(x))
    #4
    x4 = self.lin2(self.dp1(x))
    #5
    x5 = self.lin2(self.dp1(x))
    #6
    x6 = self.lin2(self.dp1(x))
    #7
    x7 = self.lin2(self.dp1(x))
    #8
    x8 = self.lin2(self.dp1(x))

    x = torch.mean(torch.stack([x1, x2, x3, x4, x5, x6, x7, x8]), dim = 0)

    return x

仔细观察代码可以看出,最终分类层上有八个变种。看似浪费了算力实际上恰恰相反。这个概念是最近提出的,叫做multi-sample dropout(多样本随机丢弃),它在训练期间显着加速了收敛[6]。它基本上是防止模型学习虚假关系(过度拟合)和试图不丢弃丢失掩码中的信息之间的折衷。

我们在项目中对这个方法进行了其他一些调整优化,具体参看我们在GitHub放出的项目代码[7]以获取更多信息。简单地提一下:我们使用混合精度(使用Apex库[8]实现)训练模型,以大大降低内存使用率,使用早停(earlystopping)来防止过度拟合,并根据余弦函数进行学习率退火。

在达到令人满意的分类准确度(具有相应高的召回率和精确度)后,我们想知道我们是否可以从DeepConnection执行的分类中学到一些东西。因此,我们尝试模型解释性探索并使用梯度加权类激活映射技术(Grad-CAM)进行分析[9]。基本地,Grad-CAM获取最终卷积层的输入梯度以确定显著区域,其可以被视为原始图像之上的上采样热图。具体实现与可视化结果如下:

热度图对比

## from https://github.com/eclique/pytorch-gradcam/blob/master/gradcam.ipynb

def GradCAM(img, c, features_fn, classifier_fn):
    feats = modulelist_conv(img.cuda().half())
    feats = feats.cuda()
    _, N, H, W = feats.size()

    out = modulelist_fc(feats)
    c_score = out[0, c]
    grads = torch.autograd.grad(c_score, feats)
    w = grads[0][0].mean(-1).mean(-1)

    sal = torch.matmul(w, feats.view(N, H*W))
    sal = sal.view(H, W).cpu().detach().numpy()
    sal = np.maximum(sal, 0)

    return sal

我们在论文中对此进行了进一步讨论,并将其嵌入到了现有的心理学研究中,但DeepConnection似乎主要关注面部区域。从研究的角度来看,这很有意义,因为面部表情会传达沟通和情感。除了Grad-CAM获得的视觉感知之外,我们还想看看我们是否可以通过模型解释得出实际特征。为此,我们创建了激活状态图,以显示最终分类层的哪些神经元被哪些给定图像区域激活。

不同幸福程度代表性激活状态图

与其他模型相比,DeepConnection还学习到了代表不幸福的特征,并不仅仅将缺乏代表幸福的特征的分类为不幸福。但是,我们需要进一步的研究才能将这些特征实际映射到人类行为可解释性方面。我们还尝试过在未知的情侣视频帧上使用DeepConnection,效果非常好。

总体而言,该模型的稳健性是其强大优势之一。准确的分类同样适用于同性恋伴侣不同肤色人种除情侣外包含其他人的视频帧中不能完整显示情侣人脸的视频帧中等等。对于图像中存在其他人的情况,DeepConnection甚至可以识别其他人是否感到满意,但仍然将其预测集中在这对情侣身上。

除了进一步的模型解释之外,下一步的工作将是使用更大的训练数据集,从而训练更复杂的模型。使用DeepConnection作为情侣治疗师的助手将会很有意思,可以在会话期间或之后对情侣的当前关系状态进行实时反馈。此外,我建议您与女票/男票一起输入你们的合照,看看DeepConnection对你们的关系有何看法!希望这会是一个好的开始!

1: https://psyarxiv.com/df25j/ 2: https://github.com/Bribak/DeepConnection 3: https://github.com/aleju/imgaug 4: https://arxiv.org/abs/1406.4729 5: https://www.sciencedirect.com/science/article/pii/S0031320318304503 6: https://arxiv.org/abs/1905.09788 7: https://github.com/Bribak/DeepConnection 8: https://github.com/NVIDIA/apex 9: https://arxiv.org/abs/1610.02391

本文分享自微信公众号 - 磐创AI(xunixs)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-09-21

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 无梯度强化学习:使用Numpy进行神经进化

    如果我告诉你训练神经网络不需要计算梯度,只需要前项传播你会怎么样?这就是神经进化的魔力!同时,我要展示的是,所有这一切只用Numpy都可以很容易地做到!学习统计...

    磐创AI
  • 探索学习率设置技巧以提高Keras中模型性能 | 炼丹技巧

    学习率是一个控制每次更新模型权重时响应估计误差而调整模型程度的超参数。学习率选取是一项具有挑战性的工作,学习率设置的非常小可能导致训练过程过长甚至训练进程被卡住...

    磐创AI
  • 在PyTorch中使用深度自编码器实现图像重建

    人工神经网络有许多流行的变体,可用于有监督和无监督学习问题。自编码器也是神经网络的一个变种,主要用于无监督学习问题。

    磐创AI
  • Python 还能实现图片去雾?FFA 去雾算法、暗通道去雾算法用起来! | 附代码

    Pytorch模块用来模型训练和网络层建立;其底层和Torch框架一样,但是使用Python重新写了很多内容,不仅更加灵活,支持动态图,而且提供了Python接...

    AI科技大本营
  • 分享两个小程序

      小编也不知道大家能不能用的到,我只是把我学到的知识分享出来,有需要的可以看一下。python本身就是一个不断更新改进的语言,不存在抄袭,有需要就可以拿过来用...

    py3study
  • 我的小工具,用C和python实现远程读卡器,远程读写消费卡片

    这个远程读卡器就是一普通usb口或串口的读卡器,只不过配合一个电脑软件作为tcp服务器。这样,程序员可以在公司电脑上运行程序连到服务器上。服务器端操作控制现场...

    特立独行的猫a
  • Pytorch实现卷积神经网络训练量化(QAT)

    深度学习在移动端的应用越来越广泛,而移动端相对于GPU服务来讲算力较低并且存储空间也相对较小。基于这一点我们需要为移动端定制一些深度学习网络来满足我们的日常续需...

    BBuf
  • 10.带人机对战的五子棋程序

    今天我们带来一个带人机对战功能的五子棋程序。程序基于前面文章中的框架搭建,新增人机对战的策略。程序基于规则进行决策,不考虑禁手,玩家执黑子先行。棋盘规模采用15...

    用户4381798
  • 6.wxPython防止窗体重画棋子消失的机制

    可以画图的类中wx.ClientDC不必依赖窗体绘画事件,可以随时实例化,随时画图。但是窗体最小化之后再恢复,重画的窗体上通过wx.ClientDC绘制的棋子会...

    用户4381798
  • 手把手教你用Python开发“剪刀石头布”小游戏【附源码】

    最近在学习PyQt5可视化界面,这是一个内容非常丰富的gui库,相对于tkinter库,功能更加强大,界面更加美观,操作也不难。于是我开始小试牛刀,用PyQt...

    python学习教程

扫码关注云+社区

领取腾讯云代金券