首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >理解variable_scope和name_scope在流量与变量共享中的作用

理解variable_scope和name_scope在流量与变量共享中的作用
EN

Stack Overflow用户
提问于 2016-03-26 15:58:43
回答 1查看 2.1K关注 0票数 0

我想在两个子图之间共享变量。更准确地说,我想要做一个fowolling操作:给定4个张量abcd和一个权重变量w,计算W*aW*bW*cW*d,但是在不同的子图中。我的代码如下:

代码语言:javascript
复制
def forward(inputs):
  w = tf.get_variable("weights", ...)
  return tf.matmult(w, inputs)

with tf.name_scope("group_1"):
  a = tf.placeholder(...)
  b = tf.placeholder(...)
  c = tf.placeholder(...)

  aa = forward(a)
  bb = forward(b)
  cc = forward(c)

with tf.name_scope("group_2):
  d = tf.placeholder(...)

  tf.get_variable_scope().reuse_variable()
  dd = forward(d)

这个例子似乎正在运行,但我不确定变量W是否被重用,特别是在group_1中,当我添加tf.get_variable_scope.reuse_variable()时,我看到一个错误,说明没有变量可以共享。当我在张图中可视化图形时,我确实在group_1子图中有几个group_1

EN

Stack Overflow用户

发布于 2016-09-10 14:12:26

下面的代码可以实现您想要的结果:

代码语言:javascript
复制
import tensorflow as tf

def forward(inputs):
    init = tf.random_normal_initializer()
    w = tf.get_variable("weights", shape=(3,2), initializer=init)
    return tf.matmul(w, inputs)

with tf.name_scope("group_1"):
    a = tf.placeholder(tf.float32, shape=(2, 3), name="a")
    b = tf.placeholder(tf.float32, shape=(2, 3), name="b")
    c = tf.placeholder(tf.float32, shape=(2, 3), name="c")
    with tf.variable_scope("foo", reuse=False):
        aa = forward(a)
    with tf.variable_scope("foo", reuse=True):
        bb = forward(b)
        cc = forward(c)

with tf.name_scope("group_2"):
    d = tf.placeholder(tf.float32, shape=(2, 3), name="d")
    with tf.variable_scope("foo", reuse=True):
        dd = forward(d)

init = tf.initialize_all_variables()

with tf.Session() as sess:
    sess.run(init)
    print(bb.eval(feed_dict={b: np.array([[1,2,3],[4,5,6]])}))
    for var in tf.all_variables():
        print(var.name)
        print(var.eval())

有几件重要的事情需要理解:

  • 除了用get_variable().创建的变量外,name_scope()还会影响所有操作系统
  • 要在作用域中放置变量,需要使用variable_scope()。例如,占位符abc实际上名为"group_1/a""group_1/b""group_1/c""group_1/d",但weights变量名为"foo/weights"。因此,名称范围中的get_variable("weights")"group_1"和变量作用域"foo"实际上都在查找"foo/weights"

如果您不确定存在哪些变量以及这些变量是如何命名的,则all_variables()函数非常有用。

票数 1
EN
查看全部 1 条回答
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/36237427

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档