首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >get_collection和get_collection_ref有什么区别?

get_collection和get_collection_ref有什么区别?
EN

Stack Overflow用户
提问于 2017-12-01 02:42:24
回答 2查看 1.3K关注 0票数 2

我已经检查了这两个方法的文档,但是它们看起来是一样的,只不过get_collection可以接受一个额外的作用域参数。

代码语言:javascript
运行
复制
In [11]: aaa = tf.get_collection_ref(tf.GraphKeys.UPDATE_OPS)
In [12]: aaaa = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
In [13]: aaa == aaaa
Out[13]: True
In [14]: aaa is aaaa
Out[14]: False

两者有什么区别,什么时候使用哪一种?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2017-12-01 03:39:23

我看到了这种不同:

代码语言:javascript
运行
复制
In [24]: w = tf.Variable([1,2,3], collections=[tf.GraphKeys.WEIGHTS], dtype=tf.float32)
In [25]: params = tf.get_collection_ref(tf.GraphKeys.WEIGHTS)
In [26]: params
Out[26]: [<tf.Variable 'Variable_1:0' shape=(3,) dtype=float32_ref>]
In [27]: del params[:]
In [28]: tf.get_collection_ref(tf.GraphKeys.WEIGHTS)
Out[28]: []
In [29]: w = tf.Variable([1,2,3], collections=[tf.GraphKeys.WEIGHTS], dtype=tf.float32)
In [30]: params = tf.get_collection(tf.GraphKeys.WEIGHTS)
In [31]: params
Out[31]: [<tf.Variable 'Variable_2:0' shape=(3,) dtype=float32_ref>]
In [32]: del params[:]
In [33]: tf.get_collection_ref(tf.GraphKeys.WEIGHTS)
Out[33]: [<tf.Variable 'Variable_2:0' shape=(3,) dtype=float32_ref>]

因此,get_collection只返回集合的值,但是get_collection_ref返回集合的引用,然后我可以通过删除它引用的返回变量来删除集合。

get_collection中的作用域参数用于根据作用域名称过滤变量。但是get_collection_ref不提供这样的功能。

票数 0
EN

Stack Overflow用户

发布于 2017-12-01 03:39:11

如果您不使用tf.get_collection,中的作用域参数,那么这两个方法在计算图中返回相同的集合。

get_collection,不带作用域,在不应用任何筛选操作的情况下获取集合中的每个值。

当指定作用域参数时,集合的每个元素都会被该作用域过滤。

考虑下面的示例代码,它返回与示例代码相同的内容(这里,我使用tf.GraphKeys.TRAINABLE_VARIABLES作为演示的目的)。

代码语言:javascript
运行
复制
with tf.variable_scope("foo1"):
    v1 = tf.get_variable("v1", [1])
with tf.variable_scope("foo2"):
    v2 = tf.get_variable("v2", [1])

aaa = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
aaaa = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

print(aaa == aaaa) #True
print(aaa is aaaa) #False

由于没有指定作用域,所以变量所引用的aaaaaa对象是相等的。

但是,如果使用指定的作用域运行以下示例代码,

代码语言:javascript
运行
复制
with tf.variable_scope("foo1"):
    v1 = tf.get_variable("v1", [1])
with tf.variable_scope("foo2"):
    v2 = tf.get_variable("v2", [1])

aaa = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
aaaa = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,'foo1')

print(aaa == aaaa) # False
print(aaa is aaaa) # False

由于指定了作用域,因此变量引用的aaaaaa对象并不相等。

此外,在这两种情况下,aaaaaa都没有指向同一个对象。因此,aaa是aaa,在两种情况下都是False ( in Python?)。

希望这能有所帮助。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/47585864

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档