前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Tensorflow小技巧整理:

Tensorflow小技巧整理:

作者头像
狼啸风云
修改2022-09-02 20:47:44
9100
修改2022-09-02 20:47:44
举报

tf.trainable_variables(), tf.all_variables(), tf.global_variables()查看变量

在使用tensorflow搭建模型时,需要定义许多变量,例如一个映射层就需要权重与偏置。当网络结果越来越复杂,变量越来越多的时候,就需要一个查看管理变量的函数,在tensorflow中,tf.trainable_variables(), tf.all_variables(),和tf.global_variables()可以来满足查看变量的要求,来简单说一下他们的不同。

tf.trainable_variables()

顾名思义,这个函数可以也仅可以查看可训练的变量,在我们生成变量时,无论是使用tf.Variable()还是tf.get_variable()生成变量,都会涉及一个参数trainable,其默认为True。以tf.Variable()为例:

代码语言:javascript
复制
__init__(
    initial_value=None,
    trainable=True,
    collections=None,
    validate_shape=True,
   ...
)

对于一些我们不需要训练的变量,比较典型的例如学习率或者计步器这些变量,我们都需要将trainable设置为False,这时tf.trainable_variables() 就不会打印这些变量。举个简单的例子,在下图中共定义了4个变量,分别是一个权重矩阵,一个偏置向量,一个学习率和计步器,其中前两项是需要训练的而后两项则不需要。

这个时候tf.trainable_variables()效果如下:

另一个问题就是,如果变量定义在scope域中,是否会有不同。实际上,tf.trainable_variables()是可以通过参数选定域名的,如下图所示:

我们重新声明了两个新变量,其中w2是在‘var’中的,如果我们直接使用tf.trainable_variables(),结果如下:

但如果我们只希望查看‘var’域中的变量,我们可以通过加入scope参数的方式实现:

可以看到,只有w2被打印出来。

tf.global_variables()

回到第一个例子,如果我希望查看全部变量,包括我的学习率等信息,可以通过tf.global_variables()来实现。效果如下:

可以看到,这时候打印出来了4个变量,其中后两个即为trainable=False的学习率和计步器。与tf.trainable_variables()一样,tf.global_variables()也可以通过scope的参数来选定域中的变量。

tf.all_variables()

与tf.global_variables()作用拥有相似的功能,只是版本问题,可以看到:

运行时会有warning的提示。还有一点需要注意的是,tf.all_variables()似乎是没有scope输入参数的,这点作用性不如前两个那么强。

应用中

在实际代码中,我们可以在定义model的时候,定义一个内部函数用来查看模型中的变量,在训练过程中,可以在开始的时候调用一次,来看一下变量名称及其阶数,对模型控制性更强,了解更加明确。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2020-09-04 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • tf.trainable_variables()
  • tf.global_variables()
  • tf.all_variables()
  • 应用中
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档