Hinton胶囊理论代码开源,上线即受热捧

当前的深度学习理论是由GeoffreyHinton大神在2007年确立起来的,但是如今他却认为,“CNN的特征提取层与次抽样层交叉存取,将相同类型的相邻特征检测器的输出汇集到一起”是大有问题的。

去年9月,在多伦多接受媒体采访时,Hinton大神断然宣称要放弃反向传播,让整个人工智能从头再造。10月,人们关注已久的Hinton大神那篇Capsule论文"Dynamic Routing between Capsules"终于揭开面纱。

在论文中,Capsule被Hinton大神定义为这样一组神经元:其活动向量所表示的是特定实体类型的实例化参数。他的实验表明,鉴别式训练的多层Capsule系统,在MNIST手写数据集上表现出目前最先进的性能,并且在识别高度重叠数字的效果要远好于CNN。

近日,该论文的一作Sara Sabour终于在GitHub上公开了该论文中的代码。该项目上线5天便获得了217个Star,并被fork了14218次。下面让我们一起来看看Sara Sabour开源的代码吧。

胶囊模型的代码在以下论文中使用:

"Dynamic Routing between Capsules" by Sara Sabour, Nickolas Frosst, Geoffrey E. Hinton.

要求

运行测试代码验证设置是否正确,比如:

pythonlayers_test.py

快速MNIST测试结果:

从以下网址下载并提取MNIST记录到 $DATA_DIR/:https://storage.googleapis.com/capsule_toronto/mnist_data.tar.gz

从以下网址下载并提取MNIST模型检测点(checkpoint)到$CKPT_DIR:

pythonexperiment.py --data_dir=$DATA_DIR/mnist_data/ --train=false\

--summary_dir=/tmp/ --

checkpoint=$CKPT_DIR/mnist_checkpoint/model.ckpt-1

快速CIFAR10 ensemble测试结果:

从以下网址下载并提取cifar10二进制版本到$DATA_DIR/:

https://www.cs.toronto.edu/~kriz/cifar.html

从以下网址下载并提取cifar10模型检测点(checkpoint)到$CKPT_DIR:

https://storage.googleapis.com/capsule_toronto/cifar_checkpoints.tar.gz

将提取的二进制文件目录作为data_dir传递给($DATA_DIR)

python experiment.py --data_dir=$DATA_DIR --train=false--dataset=cifar1\

--hparams_override=num_prime_capsules=64,padding=SAME,leaky=true,remake=false\

--summary_dir=/tmp/--checkpoint=$CKPT_DIR/cifar/cifar{}/model.ckpt-600000\

--num_trials=7

Sample CIFAR10训练命令:

pythonexperiment.py --data_dir=$DATA_DIR--dataset=cifar10 --max_steps=600000\

--hparams_override=num_prime_capsules=64,padding=SAME,leaky=true,remake=false\

--summary_dir=/tmp/

Sample MNIST完整训练命令:

python experiment.py --data_dir=$DATA_DIR/mnist_data/--max_steps=300000\

--summary_dir=/tmp/attempt0/

Sample MNIST 基线训练命令:

python experiment.py --data_dir=$DATA_DIR/mnist_data/--max_steps=300000\

--summary_dir=/tmp/attempt1/--model=baseline

上述模型的训练期间在验证集上进行测试

训练中连续运行的注意事项:

在训练中 --validate = true

总共需要总共2块GPU:一个用于训练,一个用于验证

如果训练和验证工作位于同一台机器上,则需要限制每个任务的RAM占用量,因为TensorFlow会默认为第一个任务分配所有的RAM,而第二个任务将无法进行。

在MultiMNIST上测试/训练:

--num_targets=2

--data_dir= $ DATA_DIR / multitest_6shifted_mnist.tfrecords@10

生成multiMNIST / MNIST记录的代码位于input_data / mnist / mnist_shift.py

生成multiMNIST测试分割的示例代码:

python mnist_shift.py --data_dir=$DATA_DIR/mnist_data/ --split=test--shift=6

--pad=4 --num_pairs=1000 --max_shard=100000 --multi_targets=true

为affNIST泛化能力建立expanded_mnist: --shift = 6;--pad = 6。

Github地址:

https://github.com/Sarasra/models/tree/master/research/capsules

论文地址:

https://arxiv.org/abs/1710.09829

本文来自企鹅号 - AI科技大本营媒体

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏AI研习社

GitHub项目推荐 | ChainerCV:计算机视觉中的深度学习图书馆

ChainerCV是一个使用Chainer训练和运行神经网络以进行计算机视觉任务的工具集合。

1725
来自专栏AI研习社

Github 项目推荐 | Basel Face Model 2017 完全参数化人脸

本软件可以从 Basel Face Model 2017 里生成完全参数化的人脸,论文链接: https://arxiv.org/abs/1712.01619 ...

6057
来自专栏云时之间

深度学习与神经网络:基于自建手写字体数据集上的模型测试

1083
来自专栏ATYUN订阅号

使用FastText(Facebook的NLP库)进行文本分类和word representatio...

介绍 现在, 社交软件Facebook面临诸多挑战。Facebook每天处理大量的各种形式的文本数据,例如状态更新、评论等等。而对Facebook来说,更重要...

1.9K5
来自专栏ATYUN订阅号

ChainerCV: 一个用于深度学习的计算机视觉库

ChainerCV是一个基于Chainer用于训练和运行计算机视觉任务的神经网络工具。它涵盖了计算机视觉模型的高质量实现,以及开展计算机视觉研究的必备工具集。 ...

4397
来自专栏机器学习算法工程师

分布式TensorFlow入门教程

深度学习在各个领域实现突破的一部分原因是我们使用了更多的数据(大数据)来训练更复杂的模型(深度神经网络),并且可以利用一些高性能并行计算设备如GPU和FPGA来...

2663
来自专栏专知

【干货】手把手教你用苹果Core ML和Swift开发人脸目标识别APP

【导读】CoreML是2017年苹果WWDC发布的最令人兴奋的功能之一。它可用于将机器学习整合到应用程序中,并且全部脱机。CoreML提供的机器学习 API,包...

3146
来自专栏瓜大三哥

基于FPGA的非线性滤波器(四)

基于FPGA的非线性滤波器(四) 之并行全比较排序模块设计 2.sort_2d模块设计 对于二维运算,采用同样的思路来处理,整个计算步骤如下: (1)计算一维行...

2279
来自专栏数据和云

算法分析:Oracle 11g 中基于哈希算法对唯一值数(NDV)的估算

1 为什么引入新 NDV 算法 字段的统计数据是 CBO 优化器估算执行计划代价的重要依据。而字段的统计数据可以分为两类: 1. 概要统计数据:如 NDV 字段...

3627
来自专栏大数据智能实战

facebook Faiss的基本使用示例(逐步深入)

针对上一篇文章,安装完毕之后,可以对faiss进行基本的案例学习,具体步骤如下: step1:构造实验数据 ? step2:为向量集构建IndexFlatL2索...

1K5

扫码关注云+社区

领取腾讯云代金券