[编程经验] Tensorflow中的共享变量机制小结

今天说一下tensorflow的变量共享机制,首先为什么会有变量共享机制? 这个还是要扯一下生成对抗网络GAN,我们知道GAN由两个网络组成,一个是生成器网络G,一个是判别器网络D。G的任务是由输入的隐变量z生成一张图像G(z)出来,D的任务是区分G(z)和训练数据中的真实的图像(real images)。所以这里D的输入就有2个,但是这两个输入是共享D网络的参数的,简单说,也就是权重和偏置。而TensorFlow的变量共享机制,正好可以解决这个问题。但是我现在不能确定,TF的这个机制是不是因为GAN的提出才有的,还是本身就存在。

所以变量共享的目的就是为了在对网络第二次使用的时候,可以使用同一套模型参数。TF中是由Variable_scope来实现的,下面我通过几个栗子,彻底弄明白到底该怎么使用,以及使用中会出现的错误。栗子来源于文档,然后我写了不同的情况,希望能帮到你。

# - * - coding:utf-8 - * -
import tensorflow as tf
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


def fc_variable():
    v1 = tf.Variable(
        initial_value=tf.random_normal(
            shape=[2, 3], mean=0., stddev=1.),
        dtype=tf.float32,
        name='variable_1')
    print v1
    print "- v1 - * " * 5
    return v1

"""
<tf.Variable 'variable_1:0' shape=(2, 3) dtype=float32_ref>
- v1 - * - v1 - * - v1 - * - v1 - * - v1 - * 
"""

def variable_value(variables):
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # 如果没有这句会报错,所以tf在调用变量之前主要
        # 先初始化
        """
        tensorflow.python.framework.errors_impl.
        FailedPreconditionError: Attempting to use
         uninitialized value variable_1
        """
        print '- * - value: - * - ' * 3
        print sess.run(variables)
        """
        [[ 0.00556329  0.20311342 -0.79569227]
         [ 0.1700473   0.9499892  -0.46801034]]
        """


def fc_variable_scope():
    with tf.variable_scope("foo"):
        v = tf.get_variable("v", [1])
        print v.name
        w = tf.get_variable("w", [1])
        print w.name

    with tf.variable_scope("foo", reuse=True):
        v1 = tf.get_variable("v")
        print v1.name

"""
foo/v:0
foo/w:0
foo/v:0
"""
# 解释:
# 这里说明v1和v的相同的,还有这里用的是
# get_variable定义的变量,这个和Variable
# 定义变量的区别是,如果变量存在get_variable
# 会获得他的值,如果不存在则创建变量


def fc_variable_scope_v2():
    with tf.variable_scope("foo"):
        v = tf.get_variable("v", [1])
        print v.name
        w = tf.get_variable("w", [1])
        print w.name

    with tf.variable_scope("foo", reuse=False):
        v1 = tf.get_variable("v")
        print v1.name


"""
ValueError: Variable foo/v already exists, disallowed. 
Did you mean to set reuse=True in VarScope? Originally
 defined at:
"""
# 解释:
# 当reuse为False的时候由于v1在'fool'这个scope里面,
# 所以和v的name是一样的,而reuse为False,变量命名就起了冲突。


def fc_variable_scope_v3():
    with tf.variable_scope("foo"):
        v = tf.get_variable("v", [1])
        print v.name
        w = tf.get_variable("w", [1])
        print w.name

    with tf.variable_scope("foo", reuse=True):
        v1 = tf.get_variable("u", [1])
        print v1.name


"""
ValueError: Variable foo/u does not exist, 
or was not created with tf.get_variable().
 Did you mean to set reuse=None in VarScope?
"""
# 解释:
# 当reuse为True时时候,而这里定义了新变量u,
# 之前不存在,这样也无法reuse。


def fc_variable_scope_v4():
    with tf.variable_scope("foo"):
        v = tf.get_variable("v", [1])
        print v.name
        w = tf.get_variable("w", [1])
        print w.name

    with tf.variable_scope("foo", reuse=False):
        v1 = tf.get_variable("u")
        print v1.name

"""
ValueError: Shape of a new variable (foo/u)
 must be fully defined, but instead was <unknown>.

"""
# 解释:
# 这里reuse为Flase,但是定义新变量的时候,
# 必须define fully变量,也就是要指定变量
# 的shape或者初始值等。


def fc_variable_scope_v5():
    with tf.variable_scope("foo"):
        v = tf.get_variable("v", [1])
        print dir(v)
        print v.name
        w = tf.get_variable("w", [1])
        print w.name

    with tf.variable_scope("foo", reuse=False):
        v1 = tf.get_variable("u", [1])
        print v1.name


"""
foo/v:0
foo/w:0
foo/u:0
"""
# 这样就没错了


def fc_variable_scope_v6():
    with tf.variable_scope("foo"):
        v1 = tf.Variable(tf.random_normal(
            shape=[2, 3], mean=0., stddev=1.),
            dtype=tf.float32, name='v1')
        print v1.name
        v2 = tf.get_variable("v2", [1])
        print v2.name

    with tf.variable_scope("foo", reuse=True):
        v3 = tf.get_variable('v2')
        print v3.name
        v4 = tf.get_variable('v1')
        print v4.name


"""
foo/v1:0
foo/v2:0
foo/v2:0

ValueError: Variable foo/v1 does not exist, or
 was not created with tf.get_variable(). Did 
 you mean to set reuse=None in VarScope?

"""

# 解释:
# 这里虽然reuse为True,但是v1是由Variable定义的,
# 不能被get。


def compare_name_and_variable_scope():
    with tf.name_scope("hello") as ns:
        arr1 = tf.get_variable(
            "arr1", shape=[2, 10], dtype=tf.float32)
        print (arr1.name)

    print " - * -" * 5
    with tf.variable_scope("hello") as vs:
        arr1 = tf.get_variable(
            "arr1", shape=[2, 10], dtype=tf.float32)
        print (arr1.name)

"""
arr1:0
 - * - - * - - * - - * - - * -
hello/arr1:0
"""
#解释:
# 这里除了name_scope和variable_scope不同,
# 其他都相同,但是从他们的name,也能看出来区别了。

if __name__ == "__main__":
    fc_variable_scope_v6()
    # # 需要测试那个函数,直接写在这里。

简单总结一下,今天的内容主要是变量定义的两种方法,Variable个get_variable,还有变量的范围以及reuse是什么鬼。通过几个栗子,应该明白了。

明天要说的是用TensorFlow实现Kmeans聚类,欢迎关注~

============End============

原文发布于微信公众号 - 机器学习和数学(ML_And_Maths)

原文发表时间:2017-07-22

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏专注研发

poj-3185-开关问题

牛一行20他们喝的水碗。碗可以那么(面向正确的为清凉水)或颠倒的(一个位置而没有水)。他们希望所有20个水碗那么,因此用宽鼻子翻碗。

11530
来自专栏Hadoop数据仓库

HAWQ + MADlib 玩转数据挖掘之(三)——向量

一、定义         这里不讨论向量严格的数学定义。在Madlib中,可以把向量简单理解为矩阵。矩阵是Madlib中数据的基本格式,当矩阵只有一维时,就是向...

247100
来自专栏崔庆才的专栏

Attention原理及TensorFlow AttentionWrapper源码解析

3.3K40
来自专栏Java 源码分析

平衡搜索树

2-3树 ​ 其实仔细来看2-3树好像是 B 树的一个特例,它规定了一个节点要么有一个 key 要么有两个 key。 如果有一个 key 那么他就有两个子...

32090
来自专栏小樱的经验随笔

Vijos P1497 立体图【模拟】

立体图 描述 小渊是个聪明的孩子,他经常会给周围的小朋友们讲些自己认为有趣的内容。最近,他准备给小朋友讲解立体图,请你帮他画出立体图。 小渊有一块面积为m*n的...

38160
来自专栏数据结构与算法

洛谷P2503 [HAOI2006]均分数据(模拟退火)

19900
来自专栏数据结构与算法

07:矩阵归零消减序列和

07:矩阵归零消减序列和 总时间限制: 1000ms 内存限制: 65536kB描述 给定一个n*n的矩阵(3 <= n <= 100,元素的值都是非负整数...

40060
来自专栏林冠宏的技术文章

opencv 之 icvCreateHidHaarClassifierCascade 分类器信息初始化函数部分详细代码注释。

请看注释。这个函数,是人脸识别主函数,里面出现过的函数之一,作用是初始化分类器的数据,就是一个xml文件的数据初始化。 1 static CvHidHaar...

239100
来自专栏来自地球男人的部落格

浅谈Attention-based Model【源码篇】

源码不可能每一条都详尽解释,主要在一些关键步骤上加了一些注释和少许个人理解,如有不足之处,请予指正。 计划分为三个部分: 浅谈Attention-bas...

333100
来自专栏机器学习算法原理与实践

用scikit-learn学习谱聚类

    在谱聚类(spectral clustering)原理总结中,我们对谱聚类的原理做了总结。这里我们就对scikit-learn中谱聚类的使用做一个总结。

27240

扫码关注云+社区

领取腾讯云代金券