谷歌正式开源 Hinton 胶囊理论代码,即刻用 TensorFlow 实现吧

雷锋网(公众号:雷锋网) AI 研习社消息,相信大家对于「深度学习教父」Geoffery Hinton 在去年年底发表的胶囊网络还记忆犹新,在论文 Dynamic Routing between Capsules 中,Hinton 团队提出了一种全新的网络结构。为了避免网络结构的杂乱无章,他们提出把关注同一个类别或者同一个属性的神经元打包集合在一起,好像胶囊一样。在神经网络工作时,这些胶囊间的通路形成稀疏激活的树状结构(整个树中只有部分路径上的胶囊被激活)。这样一来,Capsule 也就具有更好的解释性。

在实验结果上,CapsNet 在数字识别和健壮性上都取得了不错的效果。详情可以

日前,该论文的第一作者 Sara Sabour 在 GitHub 上公布了论文代码,大家可以马上动手实践起来。雷锋网 AI 研习社将教程编译整理如下:终于盼来了Hinton的Capsule新论文,它能开启深度神经网络的新时代吗?

所需配置:

  • TensorFlow(点击 http://www.tensorflow.org 进行安装或升级)
  • NumPy (详情点击 http://www.numpy.org/ )
  • GPU

执行 test 程序,来验证安装是否正确,诸如:

python layers_test.py

快速 MNIST 测试:

下载并提取 MNIST tfrecord 到 $DATA_DIR/ 下:

https://storage.googleapis.com/capsule_toronto/mnist_data.tar.gz

下载并提取 MNIST 模型 checkpoint 到 $CKPT_DIR 下:

https://storage.googleapis.com/capsule_toronto/mnist_checkpoints.tar.gz

python experiment.py --data_dir=$DATA_DIR/mnist_data/ --train=false \ --summary_dir=/tmp/ --checkpoint=$CKPT_DIR/mnist_checkpoint/model.ckpt-1

快速 CIFAR10 ensemble 测试:

下载并提取 cifar10 二进制文件到 $DATA_DIR/ 下:

https://www.cs.toronto.edu/~kriz/cifar.html

下载并提取 cifar10 模型 checkpoint 到 $CKPT_DIR 下:

https://storage.googleapis.com/capsule_toronto/cifar_checkpoints.tar.gz

将目录($DATA_DIR)作为 data_dir 来传递:

python experiment.py --data_dir=$DATA_DIR --train=false --dataset=cifar10 \ --hparams_override=num_prime_capsules=64,padding=SAME,leaky=true,remake=false \ --summary_dir=/tmp/ --checkpoint=$CKPT_DIR/cifar/cifar{}/model.ckpt-600000 \ --num_trials=7

CIFAR10 训练指令:

python experiment.py --data_dir=$DATA_DIR --dataset=cifar10 --max_steps=600000\ --hparams_override=num_prime_capsules=64,padding=SAME,leaky=true,remake=false \ --summary_dir=/tmp/

MNIST full 训练指令:

  • 也可以执行--validate=true as well 在训练-测试集上训练
  • 执行 --num_gpus=NUM_GPUS 在多块GPU上训练

python experiment.py --data_dir=$DATA_DIR/mnist_data/ --max_steps=300000\ --summary_dir=/tmp/attempt0/

MNIST baseline 训练指令:

python experiment.py --data_dir=$DATA_DIR/mnist_data/ --max_steps=300000\ --summary_dir=/tmp/attempt1/ --model=baseline

To test on validation during training of the above model:

训练如上模型时,在验证集上进行测试(记住,在训练过程中会持续执行指令):

  • 在训练时执行 --validate=true 也一样
  • 可能需要两块 GPU,一块用于训练集,一块用于验证集
  • 如果所有的测试都在一台机器上,你需要对训练集、验证集的测试中限制 RAM 消耗。如果不这样,TensorFlow 会在一开始占用所有的 RAM,这样就不能执行其他工作了

python experiment.py --data_dir=$DATA_DIR/mnist_data/ --max_steps=300000\ --summary_dir=/tmp/attempt0/ --train=false --validate=true

大家可以通过 --num_targets=2 和 --data_dir=$DATA_DIR/multitest_6shifted_mnist.tfrecords@10 在 MultiMNIST 上进行测试或训练,生成 multiMNIST/MNIST 记录的代码在 input_data/mnist/mnist_shift.py 目录下。

multiMNIST 测试代码:

python mnist_shift.py --data_dir=$DATA_DIR/mnist_data/ --split=test --shift=6 --pad=4 --num_pairs=1000 --max_shard=100000 --multi_targets=true

可以通过 --shift=6 --pad=6 来构造 affNIST expanded_mnist

论文地址:https://arxiv.org/pdf/1710.09829.pdf

GitHub 地址:https://github.com/Sarasra/models/tree/master/research/capsules

原文发布于微信公众号 - AI研习社(okweiwu)

原文发表时间:2018-02-02

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏AI研习社

Github 项目推荐 | 用 JavaScript 实现的神经网络 —— brain.js

不过,一般的开发者应该都不会用神经网络来实现异或的功能吧,所以这里有一个更加实际的例子:训练一个神经网络来识别颜色对比 https://brain.js.org...

18220
来自专栏YoungGy

ML基石_9_LinearRegression

linear regression problem linear regression algorithm 优化问题 求梯度 算法 generalization...

25360
来自专栏吉浦迅科技

确认过的眼神:这是一份NVIDIA TensorRT 4.0的实战教程

NVIDIA TensorRT是一个高性能的深度学习推理优化器和runtime,为深度学习推理应用程序提供低延迟和高吞吐量。您可以从每个深度学习框架中导入经过训...

53620
来自专栏AI研习社

Github 项目推荐 | 微软开源 MMdnn,模型可在多框架间转换

近期,微软开源了 MMdnn,这是一套能让用户在不同深度学习框架间做相互操作的工具。比如,模型的转换和可视化,并且可以让模型在 Caffe、Keras、MXNe...

40880
来自专栏YG小书屋

Query Auto Completion自动完成查询(一)

当我们用搜索引擎或其他工具搜索内容时,输入框下方的提示内容会根据你的输入进行调整展示。这个过程我们称之为Query Auto Completion(QAC)。用...

16910
来自专栏深度学习之tensorflow实战篇

R包—iGraph

这几天收到师兄的任务,熟悉iGRaph包的使用,通过查资料,外加自己的实践,在此做个简单的学习笔记。 以下例子均是在R 3.0.1版本下测试的。 1.用igr...

35450
来自专栏企鹅号快讯

基于自搭建BP神经网络的运动轨迹跟踪控制(二)

1 前言 朋友们~好久没见~。在上一篇基于自搭建BP神经网络的运动轨迹跟踪控制(一)中,首次给大家介绍了如何将BP神经网络模型用于运动控制,并基于matlab做...

29890
来自专栏AI研习社

Github 项目推荐 | ANSI C 的简单神经网络库

Genann是一个经过精心测试的库,用于在 C 中训练和使用前馈人工神经网络(ANN)。它的主要特点是简单、快速、可靠和可魔改(hackable),它只需要提供...

9210
来自专栏ATYUN订阅号

【深度学习】图片风格转换应用程序:使用CoreML创建Prisma

WWDC 2017让我们了解了苹果公司对机器学习的看法以及它在移动设备上的应用。CoreML框架使得将ML模型引入iOS应用程序变得非常容易。 ? 大约一年前,...

49180
来自专栏量化投资与机器学习

深度学习项目

Github上比较受欢迎的深度学习项目(Top Deep Learning Projects),按照获得星星个数的排名,包括一些教程项目等。 ? ? ? ?

20760

扫码关注云+社区

领取腾讯云代金券