【干货】TensorFlow 高阶技巧:常见陷阱、调试和性能优化

【新智元导读】文本将介绍一些 TensorFlow 的操作技巧,旨在提高你的模型性能和训练水平。文章将从预处理和输入管道开始,覆盖图、调试和性能优化的问题。

预处理和输入管道

保持预处理干净简洁

训练一个相对简单的模型也需要很长时间?检查一下你的预处理!任何麻烦的预处理(比如将数据转换成神经网络的输入),都会显著降低你的推理速度。对于我个人来说,我会创建所谓的“距离地图”(distant map),也就是用于“深层交互对象选择”的灰度图像作为附加输入,使用自定义python函数。我的训练速度最高是每秒大约处理 2.4 幅图像,切换到更强大的GTX 1080 后也没有提升。后来我注意到这个瓶颈,修复后训练速度就变成每秒50幅图像。

当你注意到这样的瓶颈时,一般首先会想到优化代码。但是,将计算时间从你的训练管道中去除还有一个更有效的方法,那就是将预处理移动到生成TFRecord文件的一次性操作当中。繁重的预处理只需执行一次,就能为所有的训练数据创建 TFRecords,你的管道本质上做的也就是加载记录。就算你想引入某种随机性来增强数据,一次创建不同的版本,而不是让你的管道变得庞大臃肿也是值得考虑的,不是吗?

注意队列

有一种发现昂贵的预处理管道的方法是查看 Tensorboard 的队列图。如果你使用框架 QueueRunners并将摘要存储在文件中,这些图都是自动生成的。这些图会显示你的计算机是否能够保持队列处在排满的状态。如果你发现图当中出现了负峰值,则系统无法在计算机要处理一个批次的时间内生成新的数据。其中的一个原因上面已经说过了。根据我的经验,最常见的原因是 min_after_dequeue 值很大。如果队列试图在内存中保留大量记录,你的容量很容易就饱和了,这会导致交换(swapping),并且显著降低队列的速度。其他的原因还包括硬盘问题(例如磁盘速度慢),以及单纯的是数据大,大过了你系统可以处理的程度。无论原因为何,修复这个问题都会加快你的训练过程。

图(graph)的构建和训练

把图固定

TensorFlows把图的构建和图的计算模型分开处理,这在日常编程中是非常罕见的,可能会导致初学者产生一些混乱。例如调试和发送错误消息,可能最初构建图的时候在代码里出现一次,然后在实际评估的时候又出现一次,当你习惯于代码只被评估一次后,这就有些别扭。

另一个问题是图的构建是和训练回路(loop)结合在一起的。这些循环通常是“标准”的python循环,因此可以改变图并向其中添加新的操作。在连续评估图的过程中对图进行改动,会产生重大的性能损失,但这一点在最开始的时候很难注意到。幸运的是这很容易解决。只需要在开始训练循环之前,把图固定(finalize)就行——调用tf.getDefaultGraph().finalize() 把图锁定,之后想要添加任何新的操作都会产生错误。看吧,问题解决了。

彻底分析图

实际上 TensorFlow 的分析功能是很强的,不过似乎没有得到那么多宣传。TensorFlow 里有一种机制,可以记录图操作的运行时间和内存消耗。如果你正在寻找瓶颈在哪里,或者需要弄清你的机器不更换硬盘驱动器的话能不能运行一个模型,这个功能就可以派上用场了。

要生成分析数据,你需要在启用跟踪的情况下把图整个跑一遍:

之后,一个 timeline.json 文件会被保存到当前文件夹,跟踪数据可以在 Tensorboard 找到。现在,你可以很容易地看到一个操作花了多长时间来计算,以及这个操作消耗了多少内存。打开Tensorboard的图视图,选择左侧的最新运行,你就能在右边看到性能的详细信息。一方面,这方便你调整模型,尽可能多地使用机器;另一方面,这方便你在训练管道中发现瓶颈。如果你更喜欢时间轴视图,在 Google Chromes 跟踪事件分析工具(Trace Event Profiling Tool)中加载timeline.json 文件就行了。

另一个不错的工具是 tfprof,tfprof 使用相同的功能做内存和执行时间分析,不过提供了更多的便利功能(feature)。额外的统计信息需要更改代码。

注意内存

就像上一节说的那样,分析可以让你了解特定操作的内存使用情况。但是,观察整个模型的内存消耗更加重要。你必须确保不会超过你机器的内存,因为 swapping 绝对会降低你输入管道的速度,这样 GPU 就会等着处理新的数据。要检测这种行为,用简单的 top 或者 Tensorboard 队列图应该足够了。要详细研究可以参照前面说的方法。

调试

善用打印

在调试问题时,比如停滞丢失或产生了奇怪的输出,我主要使用的工具是 tf.Print。考虑到神经网络的性质,看你的模型里面张量的原始值一般没有什么意义。没有人能看懂数百万的浮点数,看出什么地方错了。但是,有些方法,尤其是把形状或平均值打印出来,就能提供很多的信息。如果你要实现一些现有的模型,把东西打印出来能让你把模型的值和论文或文章里的值进行比较,还能帮助你解决一些棘手的问题,或者论文里的拼写错误。

TensorFlow 1.0 推出了新的 TFDebugger,看起来很有用。我现在还没有使用这个功能,但接下来几个星期肯定会用。

设置一个操作执行超时

好,现在你已经实现了你的模型,session 也启动了,但没有事情都没有什么发生?这通常是由空队列引起的。但是,如果你不知道是哪一个队列导致的,那么有一个简单的修复方法:只需在创建会话时启用一个操作执行超时,这样当操作超过限制时,脚本就会崩溃:

使用堆栈跟踪,你就可以找出是哪个操作产生了问题,修复错误,继续训练吧。

希望这篇文章对同样使用 TensorFlow 的你有用。如果你发现了错误,或者有建议或意见,欢迎在评论里和大家分享哦~~

编译来源:

http://www.deeplearningweekly.com/blog/tensorflow-quick-tips

原文发布于微信公众号 - 新智元(AI_era)

原文发表时间:2017-02-21

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏CSDN技术头条

NNabla:索尼开源的一款神经网络框架

NNabla是一款用于研究、开发和生产的深度学习框架。NNabla的目标是要能在台式电脑、HPC集群、嵌入式设备和生产服务器上都能运行。 安装 安装NNabla...

24060
来自专栏北京马哥教育

3行Python代码完成人脸识别

Face Recognition软件包 这是世界上最简单的人脸识别库了。你可以通过Python引用或者命令行的形式使用它,来管理和识别人脸。 该软件包使用dli...

59770
来自专栏大数据智能实战

微软开源认知服务CNTK的测试(语音训练)

前段时间,微软开源了认知服务的工具箱,直到近期才有时间进行测试。 看了文档,这个CNTK工具包还是非常厉害的,可以支持语音识别,图像分类,机器翻译等多种任务。里...

31350
来自专栏PHP实战技术

手把手教你玩转12306验证码的秘密!

12306相信对很多小伙伴都不陌生,假如问你对这个网站的印象的时候,你不是会立即想起那个坑爹的验证码,而正是这个验证码,也一时间成为小伙伴们讨论的话题,今天思梦...

29580
来自专栏Python小屋

两行Python代码实现电影打分与推荐

代码采用基于用户的协同过滤算法,也就是根据用户喜好来确定与当前用户最相似的用户,然后再根据最相似用户的喜好为当前用户进行推荐。 代码采用字典来存放数据,格式为{...

36570
来自专栏简书专栏

基于tensorflow+CNN的报警信息短文本分类

tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流。 CNN是convolutional neural netwo...

25220
来自专栏机器之心

教程 | 从零开始:TensorFlow机器学习模型快速部署指南

33950
来自专栏算法+

pytorch 移动端框架 thnets 附c示例代码

前年年前做一个手机移动端图像识别项目的时候, 先后尝试了mxnet,thnets,caffe,tensorflow. 当时的情况是,mxnet内存管理奇差,内存...

50670
来自专栏贾志刚-OpenCV学堂

Windows下TensorFlow安装与代码测试

Windows下TensorFlow安装与代码测试 一:Tensorflow介绍 TensorFlow是谷歌的深度学习应用开发框架,其思想基于数据流图与节点图实...

53280
来自专栏weixuqin 的专栏

深度学习之 TensorFlow(一):基础库包的安装

 1.TensorFlow 简介:TensorFlow 是谷歌公司开发的深度学习框架,也是目前深度学习的主流框架之一。  2.TensorFlow 环境的准...

31070

扫码关注云+社区

领取腾讯云代金券