资源 | 价值迭代网络的PyTorch实现与Visdom可视化

选自GitHub

作者:Xingdong Zuo

机器之心编译

参与:吴攀

《价值迭代网络(Value Iteration Networks)》是第 30 届神经信息处理系统大会(NIPS 2016)的最佳论文奖(Best Paper Award)获奖论文,机器之心曾在该论文获奖后第一时间采访了该论文作者之一吴翼(Yi Wu),参见《独家 | 机器之心对话 NIPS 2016 最佳论文作者:如何打造新型强化学习观?(附演讲和论文)》。吴翼在该文章中介绍说:

VIN 的目的主要是解决深度强化学习泛化能力较弱的问题。传统的深度强化学习(比如 deep Q-learning)目标一般是采用神经网络学习一个从状态(state)到决策(action)的直接映射。神经网络往往会记忆一些训练集中出现的场景。所以,即使模型在训练时表现很好,一旦我们换了一个与之前训练时完全不同的场景,传统深度强化学习方法就会表现的比较差。在 VIN 中,我们提出,不光需要利用神经网络学习一个从状态到决策的直接映射,还要让网络学会如何在当前环境下做长远的规划(learn to plan),并利用长远的规划辅助神经网络做出更好的决策。

该研究得到了广泛的关注,在原来 Theano 实现之外也出现了 TensorFlow 等其它版本的实现。近日,GitHub 用户 Xingdong Zuo 又公开发布了一个 PyTorch 的版本和另一个 TensorFlow 版本,机器之心在本文中对前者进行了介绍。

项目地址:

  • PyTorch 版本:https://github.com/zuoxingdong/VIN_PyTorch_Visdom
  • TensorFlow 版本:https://github.com/zuoxingdong/VIN_TensorFlow

相关项目地址:

  • 原作者的 Theano 实现:https://github.com/avivt/VIN
  • TensorFlow 实现:

https://github.com/TheAbhiKumar/tensorflow-value-iteration-networks

  • 原论文地址:https://arxiv.org/abs/1602.02867

关键想法

  • 一个完全可微分的神经网络,带有一个「规划(planning)」子模块
  • 价值迭代 = 卷积层+面向信道的最大池化(Value Iteration = Conv Layer + Channel-wise Max Pooling)
  • 用于新的未见过的任务时,能比反应策略更好地泛化

学习到的奖励图像(Reward Image)和其每次 VI 迭代时的价值图像(Value Images,访问原项目查看动图)

依赖包

该项目需要以下软件包:

  • Python >= 3.6
  • Numpy >= 1.12.1
  • PyTorch >= 0.1.10
  • SciPy >= 0.19.0
  • Visdom >= 0.1

数据集

每一个数据样本都由网格世界中当前状态的 (x, y) 坐标构成,后面跟着一张障碍图像(obstacle image)和一张目标图像(goal image)。

运行试验:训练

网格世界 8×8

python run.py --datafile data/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128

网格世界 16×16

python run.py --datafile data/gridworld_16x16.npz --imsize 16 --lr 0.008 --epochs 30 --k 20 --batch_size 128

网格世界 28×28

python run.py --datafile data/gridworld_28x28.npz --imsize 28 --lr 0.003 --epochs 30 --k 36 --batch_size 128

说明:

  • datafile:数据文件的路径
  • imsize:输入图像的尺寸,从 [8, 16, 28] 中选择
  • lr:使用 RMSProp 优化器的学习率,推荐 [0.01, 0.005, 0.002, 0.001]
  • epochs:训练的 epoch 数量,默认:30
  • k:价值迭代(Value Iterations)的数量,推荐 [10 for 8x8, 20 for 16x16, 36 for 28x28]
  • ch_i:输入层中信道(channel)的数量,默认:2,即障碍图像和目标图像
  • ch_h:第一层卷积层中信道的数量,默认:50,论文中有描述
  • ch_q:VI 模块中 q 层(~actions)中的信道数量,默认:10,论文中有描述
  • batch_size:批大小,默认:128

使用 Visdom 进行可视化

我们将使用 Visdom 来为每次 VI 迭代可视化学习到的奖励图像(reward image)及其对应的价值图像(value image)。

首先启动服务器

python -m visdom.server

在浏览器中打开 Visdom:http://localhost:8097

然后运行以下代码来可视化学习的奖励和价值图像:

python vis.py --datafile learned_rewards_values_28x28.npz

注:如果你想自己产生价值图像的 GIF 动画,可使用下面的命令:

convert -delay 20 -loop 0 *.png value_function.gif

基准

GPU:Titan X

表现:测试精度

注意:这是在测试集上的精度。不同于论文中的表格,其表示了在环境中学习到的策略的 rollout 的成功率。

使用 GPU 的速度

常见问题

问:如何从观察(observation)中获得奖励图像?

答:观察图像有 2 个信道。第一个信道是障碍图像(0:无障碍,1:障碍)。第二个信道是目标图像(0:无目标,10:目标)。比如说,在 8×8 的网格世界中,批大小为 128 的输入张量的形状是 [128, 2, 8, 8]。然后其被馈送到一个带有 [3,3] 滤波器和 150 个特征图卷积层,之后又是另一个带有 [3,3] 滤波器和 1 个特征图的卷积层。输出张量的形状是 [128, 1, 8, 8]。这就是奖励图像。

问:过渡模型(transition model)到底是什么?怎么通过 VI 模块从奖励图像中获取价值图像?

答:让我们假设在 8×8 的网格世界中,批大小为 128。一旦我们获得了形状为 [128, 1, 8, 8] 的奖励图像,那么我们就可以为 VI 模块中的 q 层做卷积层。[3,3] 滤波器表示其过渡概率。存在一个有 10 个滤波器的集合,其中每一个都是为了在 q 层中生成一个特征图。每一个特征图对应于一个「action」。注意这比真实可用的动作(只有 8)大一些。然后我们做一个面向信道的最大池化,以获得形状为 [128, 1, 8, 8] 的价值图像。最后我们将这个价值图像和奖励图像堆叠在一起,以进行新一次的 VI 迭代。

本文为机器之心编译,转载请联系本公众号获得授权。

✄------------------------------------------------

原文发布于微信公众号 - 机器之心(almosthuman2014)

原文发表时间:2017-03-31

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏小小挖掘机

推荐系统遇上深度学习(十一)--神经协同过滤NCF原理及实战

好久没更新该系列了,最近看到了一篇关于神经协同过滤的论文,感觉还不错,跟大家分享下。

7214
来自专栏PPV课数据科学社区

“小数据”的统计学

一、小数据来自哪里? 科技公司的数据科学、关联性分析以及机器学习等方面的活动大多围绕着”大数据”,这些大型数据集包含文档、 用户、 文件、 查询、 歌曲、 图片...

3566
来自专栏应兆康的专栏

20. 偏差和方差

假设你的训练集,开发集和测试集都来自同一分布。那么你会觉得获取更多的训练数据就可以提高性能,对吗? 尽管更多的数据是无害的,但它并不是总会像我们所期望的那样有用...

3559
来自专栏星回的实验室

推荐系统从0到1[三]:排序模型

前文中,我们根据不同召回策略召回了一批文章,并统一根据文章质量排序输出。但实际上,用户的阅读兴趣还会受到很多其他因素的影响。比如用户所处的网络环境,文章点击率、...

6904
来自专栏机器之心

专栏 | 内存带宽与计算能力,谁才是决定深度学习执行性能的关键?

机器之心专栏 作者:李飞 随着深度学习的不断发展,计算能力得到了深度学习社区越来越多的注意。任何深度学习模型,归根到底都是需要跑在设备上的,而模型对设备性能的...

4159
来自专栏AI研习社

NanoNets:数据有限如何应用深度学习?

使用深度学习解决问题的一个常见障碍是训练模型所需的数据量。对大数据的需求是因为模型中有大量参数需要学习。

1442
来自专栏CreateAMind

代码--深度网路场景位置记忆效果惊人-视频-论文

1222
来自专栏新智元

效果惊艳!FAIR提出人体姿势估计新模型,升级版Mask-RCNN

来源:densepose.org 【新智元导读】FAIR和INRIA的合作研究提出一个在Mask-RCNN基础上改进的密集人体姿态评估模型DensePose-R...

43813
来自专栏应兆康的专栏

20. 偏差和方差

1701
来自专栏机器之心

研学社•架构组 | 实时深度学习的推理加速和连续学习

机器之心原创 作者:Yanchen Wang 参与:panda 在本技术分析报告的第一部分《研学社·系统组 | 实时深度学习的推理加速和持续训练》,我们介绍了最...

2966

扫码关注云+社区

领取腾讯云代金券