前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TensorFlow调试技巧

TensorFlow调试技巧

作者头像
用户7164815
发布2020-04-08 11:15:40
1.3K0
发布2020-04-08 11:15:40
举报

TensorFlow从诞生以来就一直在深度学习框架中稳居老大的位置,虽然自从2018年12月PyTorch 1.0 stable版本正式发布以来,很快减小了差距,但是也难以超越。

TensorFlow的强项在于部署(包括TensorFlow Lite在移动端部署)和运行效率,另外对各种operation的支持特别齐全,基本上你能想到的算子都已经实现好了,直接调用就好。除此之外,Google Brain的各项前沿研究,以及现在DeepMind的很多研究,开源代码肯定都是基于TensorFlow,比如现在很火的AutoML技术等等,所以成为No.1也是自然而然。

但是又不得不吐槽其调试功能,真是太难用了。这也直接导致了TensorFlow的学习曲线异常之陡,和vim的类似,学起来很难很痛苦,但是学好之后,那是相当地爽。

那么,TensorFlow怎么调试呢?使用断点还是print?亦或是高大上的tfdbg?都不是。

由于TensorFlow静态图的设计(eager模式除外,这个后面单独讨论),设置断点根本无法获取实际tensor的值,具体取值都在后台以C++的方式执行。那print呢?也只能打印出tensor的shape信息。tfdbg,这个官方开发的专用工具该行了吧?不过我建议还是不要尝试了,不仅要一点一点敲命令,我在debug大型程序的时候,直接卡死。

对了,还有一种暴力方法,我最开始的时候在使用,就是把tensor拉出来sess.run一把,这样的确可以得到tensor运行的具体值,但是每次要手动改,很麻烦。

好了,神器要出来了:tf.Print. 在老版本的TensorFlow中可以这么用,非常方便:

x = tf.Print(x,[x, x,shape, x[0], …], message=“x debug info”, summarize=100)

其中,x是需要打印的tensor,注意第一个输入是x和输出相同,但其实也可以不同,做一些操作,但一般debug不需要,所以等式左边的输出也是x.

第二个输入在方括号内表示需要打印的东西,可以是tensor x的具体值,或者是其shape,slice,甚至是函数。

第三个输入message用来标识这一处打印,可以自定义字符串。

最后的summarize控制输出元素的数量,比如100就输出x的前100个元素。

对于新版的TensorFlow,使用tf.print,语法如下:

print_op = tf.print(x)

withtf.control_dependencies([print_op]):

out = tf.add(x, x)

sess.run(out)

很方便吧?

虽然不如直接在PyCharm中设置断点方便,但能把tensor打印出来定位问题也就容易多了。当然,如果是学习代码,想单步跟踪,建议使用eager模式,这就和PyTorch的方式非常相近了,当然,牺牲的是运行效率。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-07-30,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI人工智能与大数据 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档