共享变量 tensorflow解读

正文共3707个字,预计阅读时间13分钟。

你可以在怎么使用变量中所描述的方式来创建,初始化,保存及加载单一的变量.但是当创建复杂的模块时,通常你需要共享大量变量集并且如果你还想在同一个地方初始化这所有的变量,我们又该怎么做呢.本教程就是演示如何使用tf.variable_scope() 和tf.get_variable()两个方法来实现这一点.

问题

假设你为图片过滤器创建了一个简单的模块,和我们的卷积神经网络教程模块相似,但是这里包括两个卷积(为了简化实例这里只有两个).如果你仅使用tf.Variable变量,那么你的模块就如怎么使用变量里面所解释的是一样的模块.

 1def my_image_filter(input_images):
 2conv1_weights = tf.Variable(tf.random_normal([5, 5, 32, 32]),
 3name="conv1_weights")
 4conv1_biases = tf.Variable(tf.zeros([32]), name="conv1_biases")
 5conv1 = tf.nn.conv2d(input_images, conv1_weights,
 6strides=[1, 1, 1, 1], padding='SAME')
 7relu1 = tf.nn.relu(conv1 + conv1_biases)
 8conv2_weights = tf.Variable(tf.random_normal([5, 5, 32, 32]),
 9name="conv2_weights")
10conv2_biases = tf.Variable(tf.zeros([32]), name="conv2_biases")
11conv2 = tf.nn.conv2d(relu1, conv2_weights,
12strides=[1, 1, 1, 1], padding='SAME')
13return tf.nn.relu(conv2 + conv2_biases)

你很容易想到,模块集很快就比一个模块变得更为复杂,仅在这里我们就有了四个不同的变量:conv1_weights,conv1_biases, conv2_weights, 和conv2_biases. 当我们想重用这个模块时问题还在增多.假设你想把你的图片过滤器运用到两张不同的图片, image1和image2.你想通过拥有同一个参数的同一个过滤器来过滤两张图片,你可以调用my_image_filter()两次,但是这会产生两组变量.

1# First call creates one set of variables.
2result1 = my_image_filter(image1)
3# Another set is created in the second call.
4result2 = my_image_filter(image2)

通常共享变量的方法就是在单独的代码块中来创建他们并且通过使用他们的函数.如使用字典的例子:

 1 variables_dict = {
 2"conv1_weights": tf.Variable(tf.random_normal([5, 5, 32, 32]),
 3name="conv1_weights")
 4"conv1_biases": tf.Variable(tf.zeros([32]), name="conv1_biases")
 5... etc. ...
 6}
 7def my_image_filter(input_images, variables_dict):
 8conv1 = tf.nn.conv2d(input_images, variables_dict["conv1_weights"],
 9strides=[1, 1, 1, 1], padding='SAME')
10relu1 = tf.nn.relu(conv1 + variables_dict["conv1_biases"])
11conv2 = tf.nn.conv2d(relu1, variables_dict["conv2_weights"],
12strides=[1, 1, 1, 1], padding='SAME')
13return tf.nn.relu(conv2 + variables_dict["conv2_biases"])
14# The 2 calls to my_image_filter() now use the same variables
15result1 = my_image_filter(image1, variables_dict)
16result2 = my_image_filter(image2, variables_dict)

虽然使用上面的方式创建变量是很方便的,但是在这个模块代码之外却破坏了其封装性:

  • 在构建试图的代码中标明变量的名字,类型,形状来创建.
  • 当代码改变了,调用的地方也许就会产生或多或少或不同类型的变量.

解决此类问题的方法之一就是使用类来创建模块,在需要的地方使用类来小心地管理他们需要的变量. 一个更高明的做法,不用调用类,而是利用TensorFlow 提供了变量作用域 机制,当构建一个视图时,很容易就可以共享命名过的变量.

变量作用域实例

变量作用域机制在TensorFlow中主要由两部分组成:

  • tf.get_variable(<name>, <shape>, <initializer>): 通过所给的名字创建或是返回一个变量.
  • tf.variable_scope(<scope_name>): 通过 tf.get_variable()为变量名指定命名空间.

方法 tf.get_variable() 用来获取或创建一个变量,而不是直接调用tf.Variable.它采用的不是像`tf.Variable这样直接获取值来初始化的方法.一个初始化就是一个方法,创建其形状并且为这个形状提供一个张量.这里有一些在TensorFlow中使用的初始化变量:

  • tf.constant_initializer(value) 初始化一切所提供的值,
  • tf.random_uniform_initializer(a, b)从a到b均匀初始化,
  • tf.random_normal_initializer(mean, stddev) 用所给平均值和标准差初始化均匀分布.

为了了解tf.get_variable()怎么解决前面所讨论的问题,让我们在单独的方法里面创建一个卷积来重构一下代码,命名为conv_relu:

 1def conv_relu(input, kernel_shape, bias_shape):
 2# Create variable named "weights".
 3weights = tf.get_variable("weights", kernel_shape,
 4initializer=tf.random_normal_initializer())
 5# Create variable named "biases".
 6biases = tf.get_variable("biases", bias_shape,
 7initializer=tf.constant_intializer(0.0))
 8conv = tf.nn.conv2d(input, weights,
 9strides=[1, 1, 1, 1], padding='SAME')
10return tf.nn.relu(conv + biases)

这个方法中用了"weights" 和"biases"两个简称.而我们更偏向于用conv1 和 conv2这两个变量的写法,但是不同的变量需要不同的名字.这就是tf.variable_scope() 变量起作用的地方.他为变量指定了相应的命名空间.

1def my_image_filter(input_images):
2with tf.variable_scope("conv1"):
3# Variables created here will be named "conv1/weights", "conv1/biases".
4relu1 = conv_relu(input_images, [5, 5, 32, 32], [32])
5with tf.variable_scope("conv2"):
6# Variables created here will be named "conv2/weights", "conv2/biases".
7return conv_relu(relu1, [5, 5, 32, 32], [32])

现在,让我们看看当我们调用 my_image_filter() 两次时究竟会发生了什么.

1result1 = my_image_filter(image1)
2result2 = my_image_filter(image2)
3# Raises ValueError(... conv1/weights already exists ...)

就像你看见的一样,tf.get_variable()会检测已经存在的变量是否已经共享.如果你想共享他们,你需要像下面使用的一样,通过reuse_variables()这个方法来指定.

1with tf.variable_scope("image_filters") as scope:
2result1 = my_image_filter(image1)
3scope.reuse_variables()
4result2 = my_image_filter(image2)

用这种方式来共享变量是非常好的,轻量级而且安全.

变量作用域是怎么工作的?

理解 tf.get_variable()

为了理解变量作用域,首先完全理解tf.get_variable()是怎么工作的是很有必要的. 通常我们就是这样调用tf.get_variable 的.

1v = tf.get_variable(name, shape, dtype, initializer)

此调用做了有关作用域的两件事中的其中之一,方法调入.总的有两种情况.

  • 情况1:当tf.get_variable_scope().reuse == False时,作用域就是为创建新变量所设置的.

这种情况下,v将通过tf.Variable所提供的形状和数据类型来重新创建.创建变量的全称将会由当前变量作用域名+所提供的名字所组成,并且还会检查来确保没有任何变量使用这个全称.如果这个全称已经有一个变量使用了,那么方法将会抛出ValueError错误.如果一个变量被创建,他将会用initializer(shape)进行初始化.比如:

1with tf.variable_scope("foo"):
2v = tf.get_variable("v", [1])
3assert v.name == "foo/v:0"
  • 情况1:当tf.get_variable_scope().reuse == True时,作用域是为重用变量所设置

这种情况下,调用就会搜索一个已经存在的变量,他的全称和当前变量的作用域名+所提供的名字是否相等.如果不存在相应的变量,就会抛出ValueError 错误.如果变量找到了,就返回这个变量.如下:

1with tf.variable_scope("foo"):
2v = tf.get_variable("v", [1])
3with tf.variable_scope("foo", reuse=True):
4v1 = tf.get_variable("v", [1])
5assert v1 == v

tf.variable_scope() 基础

知道tf.get_variable()是怎么工作的,使得理解变量作用域变得很容易.变量作用域的主方法带有一个名称,它将会作为前缀用于变量名,并且带有一个重用标签来区分以上的两种情况.嵌套的作用域附加名字所用的规则和文件目录的规则很类似:

1with tf.variable_scope("foo"):
2with tf.variable_scope("bar"):
3v = tf.get_variable("v", [1])
4assert v.name == "foo/bar/v:0"

当前变量作用域可以用tf.get_variable_scope()进行检索并且reuse 标签可以通过调用tf.get_variable_scope().reuse_variables()设置为True .

1with tf.variable_scope("foo"):
2v = tf.get_variable("v", [1])
3tf.get_variable_scope().reuse_variables()
4v1 = tf.get_variable("v", [1])
5assert v1 == v

注意你不能设置reuse标签为False.其中的原因就是允许改写创建模块的方法.想一下你前面写得方法my_image_filter(inputs).有人在变量作用域内调用reuse=True 是希望所有内部变量都被重用.如果允许在方法体内强制执行reuse=False,将会打破内部结构并且用这种方法使得很难再共享参数.

即使你不能直接设置 reuse 为 False ,但是你可以输入一个重用变量作用域,然后就释放掉,就成为非重用的变量.当打开一个变量作用域时,使用reuse=True 作为参数是可以的.但也要注意,同一个原因,reuse 参数是不可继承.所以当你打开一个重用变量作用域,那么所有的子作用域也将会被重用.

 1with tf.variable_scope("root"):
 2# At start, the scope is not reusing.
 3assert tf.get_variable_scope().reuse == False
 4with tf.variable_scope("foo"):
 5# Opened a sub-scope, still not reusing.
 6assert tf.get_variable_scope().reuse == False
 7with tf.variable_scope("foo", reuse=True):
 8# Explicitly opened a reusing scope.
 9assert tf.get_variable_scope().reuse == True
10with tf.variable_scope("bar"):
11    # Now sub-scope inherits the reuse flag.
12    assert tf.get_variable_scope().reuse == True
13# Exited the reusing scope, back to a non-reusing one.
14assert tf.get_variable_scope().reuse == False

获取变量作用域

在上面的所有例子中,我们共享参数只因为他们的名字是一致的,那是因为我们开启一个变量作用域重用时刚好用了同一个字符串.在更复杂的情况,他可以通过变量作用域对象来使用,而不是通过依赖于右边的名字来使用.为此,变量作用域可以被获取并使用,而不是仅作为当开启一个新的变量作用域的名字.

1with tf.variable_scope("foo") as foo_scope:
2v = tf.get_variable("v", [1])
3with tf.variable_scope(foo_scope)
4w = tf.get_variable("w", [1])
5with tf.variable_scope(foo_scope, reuse=True)
6v1 = tf.get_variable("v", [1])
7w1 = tf.get_variable("w", [1])
8assert v1 == v
9assert w1 == w

当开启一个变量作用域,使用一个预先已经存在的作用域时,我们会跳过当前变量作用域的前缀而直接成为一个完全不同的作用域.这就是我们做得完全独立的地方.

1with tf.variable_scope("foo") as foo_scope:
2assert foo_scope.name == "foo"
3with tf.variable_scope("bar")
4with tf.variable_scope("baz") as other_scope:
5assert other_scope.name == "bar/baz"
6with tf.variable_scope(foo_scope) as foo_scope2:
7    assert foo_scope2.name == "foo"  # Not changed.

变量作用域中的初始化器

使用tf.get_variable()允许你重写方法来创建或者重用变量,并且可以被外部透明调用.但是如果我们想改变创建变量的初始化器那要怎么做呢?是否我们需要为所有的创建变量方法传递一个额外的参数呢?那在大多数情况下,当我们想在一个地方并且为所有的方法的所有的变量设置一个默认初始化器,那又改怎么做呢?为了解决这些问题,变量作用域可以携带一个默认的初始化器.他可以被子作用域继承并传递给tf.get_variable() 调用.但是如果其他初始化器被明确地指定,那么他将会被重写.

 1with tf.variable_scope("foo", initializer=tf.constant_initializer(0.4)):
 2v = tf.get_variable("v", [1])
 3assert v.eval() == 0.4  # Default initializer as set above.
 4w = tf.get_variable("w", [1], initializer=tf.constant_initializer(0.3)):
 5assert w.eval() == 0.3  # Specific initializer overrides the default.
 6with tf.variable_scope("bar"):
 7v = tf.get_variable("v", [1])
 8assert v.eval() == 0.4  # Inherited default initializer.
 9with tf.variable_scope("baz", initializer=tf.constant_initializer(0.2)):
10v = tf.get_variable("v", [1])
11assert v.eval() == 0.2  # Changed default initializer.

在tf.variable_scope()中ops的名称

我们讨论 tf.variable_scope 怎么处理变量的名字.但是又是如何在作用域中影响到 其他ops的名字的呢?ops在一个变量作用域的内部创建,那么他应该是共享他的名字,这是很自然的想法.出于这样的原因,当我们用with tf.variable_scope("name")时,这就间接地开启了一个tf.name_scope("name").比如:

1with tf.variable_scope("foo"):
2x = 1.0 + tf.get_variable("v", [1])
3assert x.op.name == "foo/add"

名称作用域可以被开启并添加到一个变量作用域中,然后他们只会影响到ops的名称,而不会影响到变量.

1with tf.variable_scope("foo"):
2with tf.name_scope("bar"):
3v = tf.get_variable("v", [1])
4x = 1.0 + v
5assert v.name == "foo/v:0"
6assert x.op.name == "foo/bar/add"

当用一个引用对象而不是一个字符串去开启一个变量作用域时,我们就不会为ops改变当前的名称作用域.

使用实例

这里有一些指向怎么使用变量作用域的文件.特别是,他被大量用于 时间递归神经网络和sequence-to-sequence模型,

File

What's in it?

models/image/cifar10.py

图像中检测对象的模型.

models/rnn/rnn_cell.py

时间递归神经网络的元方法集.

models/rnn/seq2seq.py

为创建sequence-to-sequence模型的方法集.

原文:Sharing Variables 翻译:nb312校对:Wiki

原文链接:https://www.jianshu.com/u/74b632e3297c

原文发布于微信公众号 - 人工智能LeadAI(atleadai)

原文发表时间:2018-06-01

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏java架构师

【SQL Server】系统学习之三:逻辑查询处理阶段-六段式

一、From阶段 针对连接说明: 1、笛卡尔积 2、on筛选器 插播:unknown=not unknuwn 缺失的值; 筛选器(on where having...

371110
来自专栏海纳周报

Python的三个问题

第一,以下程序的执行结果是什么? def foo(a = []): a.append(1) print a foo()foo() 第二,以下...

29670
来自专栏尾尾部落

[剑指offer] 旋转数组的最小数字

把一个数组最开始的若干个元素搬到数组的末尾,我们称之为数组的旋转。 输入一个非减排序的数组的一个旋转,输出旋转数组的最小元素。 例如数组{3,4,5,1,2}为...

10520
来自专栏Objective-C

Swift 基本语法03-"if let"和"guard let"

45140
来自专栏我的技术专栏

C++ 顺序容器基础知识总结

22650
来自专栏数据结构与算法

洛谷P3812 【模板】线性基

1 \leq n \leq 50, 0 \leq S_i \leq 2 ^ {50}1≤n≤50,0≤Si​≤250

7620
来自专栏瓜大三哥

HLS Lesson6-数据类型转换

1.整数数据类型 传统的C语言可以采用:数据类型 数据变量 赋值 int var = -1; ap_int<6> a_6bit_var_c = -22;//复制...

475100
来自专栏mathor

TRIE(4)

 这道题的大意是我们有一个网站,然后要配置规则,决定哪些IP能访问,哪些IP不能。这些规则大概长这个样子:

11340
来自专栏C语言及其他语言

C语言逆向之表达式短路分析及应用

大家在学习C语言过程中,可能会见到过一些这样的题,就是表达式短路,表达式短路主要体现在C语言中逻辑运算符&&和||。今天将对表达式短路的做逆向分析,来深入理解它...

25740
来自专栏禅林阆苑

LESS 学习demo 【原创】

LESS 学习demo Write By CS逍遥剑仙 我的主页: www.csxiaoyao.com GitHub: github.com/...

41290

扫码关注云+社区

领取腾讯云代金券