前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TensorFlow修炼之道(2)——变量(Variable)

TensorFlow修炼之道(2)——变量(Variable)

作者头像
abs_zero
修改2018-05-26 14:46:54
1.1K0
修改2018-05-26 14:46:54
举报
文章被收录于专栏:AI派AI派

文章内容:TensorFlow 变量

变量

变量(Variable)是 TensorFlow 中程序处理的共享持久状态的最佳方法。与常量不同的时,常量创建后,值便无法更改,但是变量创建后 可以修改。并且修改后的值在多个Session中都是可以看见的。

训练模型时,需要使用变量(Variable)保存和更新参数。变量是包含张量(tensor)的内存缓冲。变量必须要先被 初始化(initialize) ,而且可以在训练时和训练后保存(save)到磁盘中。之后可以再恢复(restore)保存的变量值来训练和测试模型。

创建变量

创建变量有两种方式,一种是使用 tf.Variable 来创建一个新的变量,另一种是使用 tf.get_variable 来获取一个已经存在的变量或者创建一个新的变量。

tf.Variable 需要接收一个 Tensor 给构造函数,也可以自定义结点名称和数据类型。这里使用 tf.random_normal 来生成一个均值为1,标准差0.2,形状为(2, 5)的张量。使用 tf.Variable 时,如果检测到命名冲突,系统会自动处理。

代码语言:javascript
复制
import tensorflow as tf
w1 = tf.Variable(tf.random_normal((2, 5), mean=1, stddev=0.2), name="w1")
w2 = tf.Variable(tf.random_normal((2, 5), mean=1, stddev=0.2), name="w1")
print("w1.name: %s, w2.name: %s" % (w1.name, w2.name))
代码语言:javascript
复制
w1.name: w1:0, w2.name: w1_1:0

可以看出,当已经存在一个相同结点的名称后,tf.Variable 会自动添加“_1”等后缀来做区分。使用 tf.get_variable来创建变量时,结合 tf.variable_scope 可以实现共享变量。

代码语言:javascript
复制
代码语言:javascript
复制
with tf.variable_scope("scope"):
    # 这里创建的变量名将命名为 "scope/b"
    b1 = tf.get_variable(name="b", shape=[2, 5], initializer=tf.constant_initializer(1.0))
    print(b1)
代码语言:javascript
复制
<tf.Variable 'scope/b:0' shape=(2, 5) dtype=float32_ref>

接下来使用 withtf.variable_scope 来生成一个上下文管理器,需要注意的是,在 tf.variable_scope 中,需要指定 reuse=True ,否则会出错。

代码语言:javascript
复制
代码语言:javascript
复制
with tf.variable_scope("scope", reuse=True):
    b2 = tf.get_variable(name="b", shape=[2, 5])
    print(b2)
print(b1 is b2)
代码语言:javascript
复制
<tf.Variable 'scope/b:0' shape=(2, 5) dtype=float32_ref>
True

可以看到,b1b2 是同一个变量。

设备放置

像任何其它TensorFlow操作一样,你可以将变量放置到特定的设备上。

语法结构为:with tf.device(…): block,下面创建一个名为v的变量,并将其放在第一个GPU设备上

代码语言:javascript
复制
代码语言:javascript
复制
with tf.device("/gpu:0"):
    v = tf.get_variable("v", [1])

变量集合

TensorFlow 支持将变量存放在集合(collection)中,以便于在不同地方使用。TensorFlow 中每个集合都是一个列表,并且有一个名称(可以是任何字符串)。可以通过 tf.get_collection 方法来获取不同名称的集合。

默认情况下,每个变量会被存放在 tf.GraphKeys.GLOBAL_VARIABLEStf.GraphKeys.TRAINABLE_VARIABLES 这两个集合中。此外,也可以通过 tf.add_to_collection 手动添加变量到集合中。

代码语言:javascript
复制
print("%s: \n%s\n" % (tf.GraphKeys.GLOBAL_VARIABLES, tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
print("%s: \n%s\n" % (tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)))

# 手动添加变量 w1/b1 到集合 my_col 中
tf.add_to_collection("my_col", w1)
tf.add_to_collection("my_col", b1)
print("my_col: \n%s" % tf.get_collection("my_col"))
代码语言:javascript
复制
variables: 
[<tf.Variable 'w1:0' shape=(2, 5) dtype=float32_ref>, <tf.Variable 'w1_1:0' shape=(2, 5) dtype=float32_ref>, <tf.Variable 'scope/b:0' shape=(2, 5) dtype=float32_ref>, <tf.Variable 'v:0' shape=(1,) dtype=float32_ref>]

trainable_variables: 
[<tf.Variable 'w1:0' shape=(2, 5) dtype=float32_ref>, <tf.Variable 'w1_1:0' shape=(2, 5) dtype=float32_ref>, <tf.Variable 'scope/b:0' shape=(2, 5) dtype=float32_ref>, <tf.Variable 'v:0' shape=(1,) dtype=float32_ref>]

my_col: 
[<tf.Variable 'w1:0' shape=(2, 5) dtype=float32_ref>, <tf.Variable 'scope/b:0' shape=(2, 5) dtype=float32_ref>]

初始化变量

在使用变量之前,它必须被初始化。在低级TensorFlow API中编程(需要自己明确地创建图和会话),必须显式初始化变量。大多数高级框架,如tf.contrib.slim、tf.estimator.Estimator和Keras在训练模型之前自动初始化变量。

要在训练开始前一次初始化所有可训练的变量,可以调用 tf.global_variables_initializer() 来完成。如果只想初始化某个变量,可以调用变量的 .initializer属性。在初始化变量之前,可以使用 tf.report_uninitialized_variables() 来查看尚未被初始化的变量的名称。

代码语言:javascript
复制
with tf.Session() as sess:
    # 查看当前未初始化的变量名称
    print(sess.run(tf.report_uninitialized_variables()))
    # 初始化变量 w1
    sess.run(w1.initializer)
    print(sess.run(tf.report_uninitialized_variables()))
    # 初始化所有变量
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.report_uninitialized_variables()))
代码语言:javascript
复制
[b'w1' b'w1_1' b'scope/b' b'v']
[b'w1_1' b'scope/b' b'v']
[]

可以看到,调用变量的 initializer 属性只会初始化该变量,调用 tf.global_variables_initializer() 会初始化所有变量。

使用变量

在 TensorFlow 使用变量时,只需要像对待普通的张量(Tensor)来对待它就可以了。对变量进行操作后,生成的结果会是一个张量。

代码语言:javascript
复制
o1 = w1 + b1

with tf.Session() as sess:
    # 使用变量前需要进行初始化,这里可以不用进行初始化,
    # 因为在上一节的 "初始化变量" 时已经初始过了,这里只是为了保证流程完整,所以加上了。
    sess.run(tf.global_variables_initializer())
    print(sess.run(o1))
代码语言:javascript
复制
[[1.9957633 2.0741634 1.6809903 1.7901803 2.1854873]
 [1.8982319 2.2382631 1.7602906 1.9434371 2.0995731]]

作者:无邪,个人博客:脑洞大开,专注于机器学习研究。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2018-02-23,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 脑洞科技栈 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 变量
    • 创建变量
      • 设备放置
        • 变量集合
          • 初始化变量
            • 使用变量
            领券
            问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档