TensorFlow中的Nan值的陷阱

之前在TensorFlow中实现不同的神经网络,作为新手,发现经常会出现计算的loss中,出现Nan值的情况,总的来说,TensorFlow中出现Nan值的情况有两种,一种是在loss中计算后得到了Nan值,另一种是在更新网络权重等等数据的时候出现了Nan值,本文接下来,首先解决计算loss中得到Nan值的问题,随后介绍更新网络时,出现Nan值的情况。

01 Loss计算中出现Nan值

在搜索以后,找到StackOverflow上找到大致的一个解决办法(原文地址:这里),大致的解决办法就是,在出现Nan值的loss中一般是使用的TensorFlow的log函数,然后计算得到的Nan,一般是输入的值中出现了负数值或者0值,在TensorFlow的官网上的教程中,使用其调试器调试Nan值的出现,也是查到了计算log的传参为0;而解决的办法也很简单,假设传参给log的参数为y,那么在调用log前,进行一次数值剪切,修改调用如下:

loss = tf.log(tf.clip_by_value(y,1e-8,1.0))

这样,y的最小值为0的情况就被替换成了一个极小值,1e-8,这样就不会出现Nan值了,StackOverflow上也给出了相同的解决方案。于是,我就采用了上述的解决方案对于log的参数进行数值限制,但是我更加复杂化了这个限制。

tf.clip_by_value这个函数,是将第一个参数,限制在第二、三个参数指定的范围之内,使用这个函数的原意是要避免0值,并没有限制最大值,因而我将限制的调用修改如下:

loss = tf.log(tf.clip_by_value(y,1e-8,tf.reduce_max(y)))

这样就确保了对于y值的剪切,不会影响到其数值的上限。但是在实际的神经网络中使用的时候,我发现这样修改后,虽然loss的数值一直在变化,可是优化后的结果几乎是保持不变的,这就存在问题了。

经过检查,其实并不能这么简单的为了持续训练,而修改计算损失函数时的输入值。这样修改后,loss的数值很可能(存在0的话确定就是)假的数值,会对优化器优化的过程造成一定的影响,导致优化器并不能正常的工作。

要解决这个假的loss的方法很简单,就是人为的改造神经网络,来控制输出的结果,不会存在0。这就需要设计好最后一层输出层的激活函数,每个激活函数都是存在值域的,详情请见这篇博客,比如要给一个在(0,1)之间的输出(不包含0),那么显然sigmoid是最好的选择。不过需要注意的是,在TensorFlow中,tf.nn.sigmoid函数,在输出的参数非常大,或者非常小的情况下,会给出边界值1或者0的输出,这就意味着,改造神经网络的过程,并不只是最后一层输出层的激活函数,你必须确保自己大致知道每一层的输出的一个范围,这样才能彻底的解决Nan值的出现。

举例说明就是TensorFlow的官网给的教程,其输出层使用的是softmax激活函数,其数值在[0,1],这在设计的时候,基本就确定了会出现Nan值的情况,只是发生的时间罢了。

02 更新网络时出现Nan值

更新网络中出现Nan值很难发现,但是一般调试程序的时候,会用summary去观测权重等网络中的值的更新,因而,此时出现Nan值的话,会报错类似如下:

InvalidArgumentError (see above for traceback): Nan in summary histogram for: weight_1

这样的情况,一般是由于优化器的学习率设置不当导致的,而且一般是学习率设置过高导致的,因而此时可以尝试使用更小的学习率进行训练来解决这样的问题。

原文发布于微信公众号 - 人工智能LeadAI(atleadai)

原文发表时间:2018-01-23

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏大数据挖掘DT机器学习

阿里天池大数据竞赛实战:RF&GBRT 完成过程

一点比赛心得,供不太熟悉Xlab RF和GBRT调用的同学参考,不喜勿喷,大神绕道---------- 6月初的时候LR 做到4.9后一直上不去,...

37211
来自专栏腾讯Bugly的专栏

机器学习入门之HelloWorld(Tensorflow)

1 环境搭建 (Windows) 安装虚拟环境 Anaconda,方便python包管理和环境隔离。 Anaconda3 4.2 http://mirrors...

4508
来自专栏磐创AI技术团队的专栏

ChatGirl 一个基于 TensorFlow Seq2Seq 模型的聊天机器人

简介 ? 还在开发中,它工作的效果还不好。但是你可以直接训练,并且运行。 包含预处理过的 twitter 英文数据集,训练,运行,工具代码,可以运行但是效果有待...

3838
来自专栏深度学习那些事儿

深度学习-TF、keras两种padding方式:vaild和same

在使用Keras的时候会遇到这样的代码x = Conv2D(filters, kernel_size=5, strides=2, padding='same')...

2045
来自专栏人工智能LeadAI

解析Tensorflow官方PTB模型的demo

01 seq2seq代码案例解读 RNN 模型作为一个可以学习时间序列的模型被认为是深度学习中比较重要的一类模型。在Tensorflow的官方教程中,有两个与...

4658
来自专栏AI深度学习求索

CAM实践:基于pytorch的使用方法

注意:如果为了快一点,不使用网络的图片以及文件的话,记得更改图片地址和已下载文件地址哦

1943
来自专栏编程

大神级Python工程师是怎么P图的,带你用Python玩转P图

? 1.PIL:Python影像库 PIL或者Python Imaging Library是一个包含许多函数来处理来自Python脚本的图像的包。PIL官方网...

3528
来自专栏ATYUN订阅号

深度学习图像识别项目(中):Keras和卷积神经网络(CNN)

在下篇文章中,我还会演示如何将训练好的Keras模型,通过几行代码将其部署到智能手机上。

1.4K5
来自专栏贾志刚-OpenCV学堂

OpenCV中图像算术操作与逻辑操作

在图像处理中有两类最重要的基础操作分别是图像点操作与块操作,简单点说图像点操作就是图像每个像素点的相关逻辑与几何运算、块操作最常见就是基于卷积算子的各种操作、实...

41810
来自专栏Small Code

sklearn中Logistics Regression的coef_和intercept_的具体意义

使用sklearn库可以很方便的实现各种基本的机器学习算法,例如今天说的逻辑斯谛回归(Logistic Regression),我在实现完之后,可能陷入代码太久...

3096

扫码关注云+社区