【前沿】TensorFlow Pytorch Keras代码实现深度学习大神Hinton NIPS2017 Capsule论文

【导读】10月26日,深度学习元老Hinton的NIPS2017 Capsule论文《Dynamic Routing Between Capsules》终于在arxiv上发表。今天相关关于这篇论文的TensorFlow\Pytorch\Keras实现相继开源出来,让我们来看下。

论文地址:https://arxiv.org/pdf/1710.09829.pdf

摘要:Capsule 是一组神经元,其活动向量(activity vector)表示特定实体类型的实例化参数,如对象或对象部分。我们使用活动向量的长度表征实体存在的概率,向量方向表示实例化参数。同一水平的活跃 capsule 通过变换矩阵对更高级别的 capsule 的实例化参数进行预测。当多个预测相同时,更高级别的 capsule 变得活跃。我们展示了判别式训练的多层 capsule 系统在 MNIST 数据集上达到了最好的性能效果,比识别高度重叠数字的卷积网络的性能优越很多。为了达到这些结果,我们使用迭代的路由协议机制:较低级别的 capsule 偏向于将输出发送至高级别的 capsule,有了来自低级别 capsule 的预测,高级别 capsule 的活动向量具备较大的标量积。

CapsNet-PyTorch

python依赖包

  • Python 3
  • PyTorch
  • TorchVision
  • TorchNet
  • TQDM
  • Visdom

使用说明

第一步capsule_network.py文件中设置训练epochs,batch size等

BATCH_SIZE = 100NUM_CLASSES = 10NUM_EPOCHS = 30NUM_ROUTING_ITERATIONS = 3

Step 2 开始训练. 如果本地文件夹中没有MNIST数据集,将运行脚本自动下载到本地. 确保 PyTorch可视化工具Visdom正在运行。

$ sudo python3 -m visdom.server & python3 capsule_network.py

基准数据

经过30个epoche的训练手写体数字的识别率达到99.48%. 从下图的训练进度和损失图的趋势来看,这一识别率可以被进一步的提高 。

采用了PyTorch中默认的Adam梯度优化参数并没有用到动态学习率的调整。 batch size 使用100个样本的时候,在雷蛇GTX 1050 GPU上每个Epochs 用时3分钟。

待完成

  • 扩展到除MNIST以外的其他数据集。

Credits

主要借鉴了以下两个 TensorFlow 和 Keras 的实现:

  1. Keras implementation by @XifengGuo
  2. TensorFlow implementation by @naturomics

Many thanks to @InnerPeace-Wu for a discussion on the dynamic routing procedure outlined in the paper.

CapsNet-Tensorflow

Python依赖包

  • Python
  • NumPy
  • Tensorflow (I'm using 1.3.0, not yet tested for older version)
  • tqdm (for displaying training progress info)
  • scipy (for saving image)

使用说明

训练

*第一步 * 用git命令下载代码到本地.

$ git clone https://github.com/naturomics/CapsNet-Tensorflow.git
$ cd CapsNet-Tensorflow

第二部 下载MNIST数据集(http://yann.lecun.com/exdb/mnist/), 移动并解压到data/mnist 文件夹.(当你用复制wget 命令到你的终端是注意渠道花括号里的反斜杠)

$ mkdir -p data/mnist
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/{train-images-idx3-ubyte.gz,train-labels-idx1-ubyte.gz,t10k-images-idx3-ubyte.gz,t10k-labels-idx1-ubyte.gz}
$ gunzip data/mnist/*.gz

第三部 开始训练:

$ pip install tqdm  # install it if you haven't installed yet
$ python train.py

tqdm包并不是必须的,只是为了可视化训练过程。如果你不想要在train.py中将循环for in step ... 改成 ``for step in range(num_batch)就行了。

评估

$ python eval.py --is_training False

结果

错误的运行结果(Details in Issues #8):

  • training loss
  • test acc

Epoch

49

51

test acc

94.69

94.71


Results after fixing Issues #8:

关于capsule的一点见解

  1. 一种新的神经单元(输入向量输出向量,而不是标量)
  2. 常规算法类似于Attention机制
  3. 总之是一项很有潜力的工作,有很多工作可以在之上开展

待办:

  • 完成MNIST的实现Finish the MNIST version of capsNet (progress:90%)
  • 在其他数据集上验证capsNet
    • 调整模型结构
  • 一篇新的投稿在ICLR2018上的后续论文(https://openreview.net/pdf?id=HJWLfGWRb) about capsules(submitted to ICLR 2018),

CapsNet-Keras

依赖包

  • Keras
  • matplotlib

使用方法

训练

第一步 安装 Keras:

$ pip install keras

第二步git命令下载代码到本地.

$ git clone https://github.com/xifengguo/CapsNet-Keras.git
$ cd CapsNet-Keras

第三步 训练:

$ python capsulenet.py

一次迭代训练(default 3).

$ python capsulenet.py --num_routing 1

其他参数包括想 batch_size, epochs, lam_recon, shift_fraction, save_dir 可以以同样的方式使用。 具体可以参考 capsulenet.py

测试

假设你已经有了用上面命令训练好的模型,训练模型将被保存在 result/trained_model.h5. 现在只需要使用下面的命令来得到测试结果。

$ python capsulenet.py --is_training 0 --weights result/trained_model.h5

将会输出测试结果并显示出重构后的图片。测试数据使用的和验证集一样 ,同样也可以很方便的在新数据上验证,至于要按照你的需要修改下代码就行了。

如果你的电脑没有GPU来训练模型,你可以从https://pan.baidu.com/s/1hsF2bvY下载预先训练好的训练模型

结果

主要结果 运行 python capsulenet.py: epoch=1 代表训练一个epoch 后的结果 在保存的日志文件中,epoch从0开始。

Epoch

1

5

10

15

20

train_acc

90.65

98.95

99.36

99.63

99.75

vali_acc

98.51

99.30

99.34

99.49

99.59

损失和准确度:

一次常规迭代后的结果 运行 python CapsNet.py --num_routing 1

Epoch

1

5

10

15

20

train_acc

89.64

99.02

99.42

99.66

99.73

vali_acc

98.55

99.33

99.43

99.57

99.58

每个 epoch 在单卡GTX 1070 GPU上大概需要110s 注释: 训练任然是欠拟合的,欢迎在你自己的机器上验证。学习率decay还没有经过调试, 我只是试了一次,你可以接续微调。

测试结果 运行 python capsulenet.py --is_training 0 --weights result/trained_model.h5

模型结构:

其他实现代码

  • Kaggle (this version as self-contained notebook):
    • MNIST Dataset running on the standard MNIST and predicting for test data
    • MNIST Fashion running on the more challenging Fashion images.
  • TensorFlow:
    • naturomics/CapsNet-Tensorflow Very good implementation. I referred to this repository in my code.
    • InnerPeace-Wu/CapsNet-tensorflow I referred to the use of tf.scan when optimizing my CapsuleLayer.
    • LaoDar/tf_CapsNet_simple
  • PyTorch:
    • nishnik/CapsNet-PyTorch
    • timomernick/pytorch-capsule
    • gram-ai/capsule-networks
    • andreaazzini/capsnet.pytorch
    • leftthomas/CapsNet
  • MXNet:
    • AaronLeong/CapsNet_Mxnet
  • Lasagne (Theano):
    • DeniskaMazur/CapsNet-Lasagne
  • Chainer:
    • soskek/dynamic_routing_between_capsules

参考网址链接:

https://github.com/gram-ai/capsule-networks

https://github.com/naturomics/CapsNet-Tensorflow

https://github.com/XifengGuo/CapsNet-Keras

特别提示:

请关注专知公众号,后台回复“MLDL” 就可以获取机器学习&深度学习知识资料大全集的pdf下载链接

原文发布于微信公众号 - 专知(Quan_Zhuanzhi)

原文发表时间:2017-11-05

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏专知

【专知国庆特刊-PyTorch手把手深度学习教程系列01】一文带你入门优雅的PyTorch

【导读】主题链路知识是我们专知的核心功能之一,为用户提供AI领域系统性的知识学习服务,一站式学习人工智能的知识,包含人工智能( 机器学习、自然语言处理、计算机视...

1.2K80
来自专栏ATYUN订阅号

深度学习图像识别项目(中):Keras和卷积神经网络(CNN)

在下篇文章中,我还会演示如何将训练好的Keras模型,通过几行代码将其部署到智能手机上。

3K60
来自专栏逍遥剑客的游戏开发

边缘高亮效果(三)

16120
来自专栏AI研习社

PyTorch 的预训练,是时候学习一下了

前言 最近使用 PyTorch 感觉妙不可言,有种当初使用 Keras 的快感,而且速度还不慢。各种设计直接简洁,方便研究,比 tensorflow 的臃...

488100
来自专栏QQ音乐前端团队专栏

前端图片主题色提取

对于需要根据用户“定制”、“生成”的图片,这样的方式就有了一个上传图片---->后端计算---->返回结果的时间,等待时间也许就比较长了。由此,我尝试着利用 c...

1.8K150
来自专栏数据小魔方

ggplot2双坐标轴的解决方案

本来没有打算写这一篇的,因为在一幅图表中使用双坐标轴确实不是一个很好地习惯,无论是信息传递的效率还是数据表达的准确性而言。 但是最近有好几个小伙伴儿跟我咨询关于...

48790
来自专栏AI科技大本营的专栏

前端工程师掌握这18招,就能在浏览器里玩转深度学习

【导读】TensorFlow.js 的发布可以说是 JS 社区开发者的福音!但是在浏览器中训练一些模型还是会存在一些问题与不同,如何可以让训练效果更好?本文的作...

12810
来自专栏尾尾部落

使用自己的语料训练word2vec模型

先对新闻文本进行分词,使用的是结巴分词工具,将分词后的文本保存在seg201708.txt,以备后期使用。

98030
来自专栏CreateAMind

代码解读Top-down Neural Attention by Excitation Backprop及模型架构图

34530
来自专栏AI研习社

PyTorch 模型不适合自己怎么办?预训练教程在这里

前言 最近使用 PyTorch 感觉妙不可言,有种当初使用 Keras 的快感,而且速度还不慢。各种设计直接简洁,方便研究,比 tensorflow 的臃肿...

2.5K40

扫码关注云+社区

领取腾讯云代金券