tf.get_variable

获取一个已经存在的变量或者创建一个新的变量

get_variable(
    name,
    shape=None,
    dtype=None,
    initializer=None,
    regularizer=None,
    trainable=True,
    collections=None,
    caching_device=None,
    partitioner=None,
    validate_shape=True,
    use_resource=None,
    custom_getter=None,
    constraint=None
)

Args:

  • name:新变量或现有变量的名称。
  • shape:新变量或现有变量的形状。
  • dtype:新变量或现有变量的类型(默认为DT_FLOAT)。
  • ininializer:如果创建了则用它来初始化变量。 initializer是变量初始化的方式,初始化的方式有以下几种:
    • tf.constant_initializer:常量初始化函数
    • tf.random_normal_initializer:正态分布
    • tf.truncated_normal_initializer:截取的正态分布
    • tf.random_uniform_initializer:均匀分布
    • tf.zeros_initializer:全部是0
    • tf.ones_initializer:全是1
    • tf.uniform_unit_scaling_initializer:满足均匀分布,但不影响输出数量级的随机值
  • regularizer:A(Tensor - > Tensor或None)函数;将它应用于新创建的变量的结果将添加到集合tf.GraphKeys.REGULARIZATION_LOSSES中,并可用于正则化。
  • trainable:如果为True,还将变量添加到图形集合GraphKeys.TRAINABLE_VARIABLES(参见tf.Variable)。
  • collections:要将变量添加到的图表集合列表。默认为[GraphKeys.GLOBAL_VARIABLES](参见tf.Variable)。
  • caching_device:可选的设备字符串或函数,描述变量应被缓存以供读取的位置。默认为Variable的设备。如果不是None,则在另一台设备上缓存。典型用法是在使用变量驻留的Ops的设备上进行缓存,以通过Switch和其他条件语句进行重复数据删除。
  • partitioner:可选callable,接受完全定义的TensorShape和要创建的Variable的dtype,并返回每个轴的分区列表(当前只能对一个轴进行分区)。
  • validate_shape:如果为False,则允许使用未知形状的值初始化变量。如果为True,则默认为initial_value的形状必须已知。
  • use_resource:如果为False,则创建常规变量。如果为true,则使用定义良好的语义创建实验性ResourceVariable。默认为False(稍后将更改为True)。在Eager模式下,此参数始终强制为True。
  • custom_getter:Callable,它将第一个参数作为true getter,并允许覆盖内部get_variable方法。 custom_getter的签名应与此方法的签名相匹配,但最适合未来的版本将允许更改:def custom_getter(getter,* args,** kwargs)。也允许直接访问所有get_variable参数:def custom_getter(getter,name,* args,** kwargs)。一个简单的身份自定义getter只需创建具有修改名称的变量是:python def custom_getter(getter,name,* args,** kwargs):return getter(name +’_suffix’,* args,** kwargs)

如果initializer初始化方法是None(默认值),则会使用variable_scope()中定义的initializer,如果也为None,则默认使用glorot_uniform_initializer,也可以使用其他的tensor来初始化,value,和shape与此tensor相同

正则化方法默认是None,如果不指定,只会使用variable_scope()中的正则化方式,如果也为None,则不使用正则化;

附: tf.truncated_narmal()和tf.truncated_naomal__initializer()的区别

  • tf.truncated_narmal(shape=[],mean=0,stddev=0.5)使用时必须制定shape,返回值是在截断的正态分布随机生成的指定shape的tensor
  • tf.truncated_normal_initializer(mean=0.stddev=0.5)调用返回一个initializer 类的一个实例(就是一个初始化器),不可指定shape,
import tensorflow as tf;  
import numpy as np;  
import matplotlib.pyplot as plt;  
  
a1 = tf.get_variable(name='a1', shape=[2,3], initializer=tf.random_normal_initializer(mean=0, stddev=1))
a2 = tf.get_variable(name='a2', shape=[1], initializer=tf.constant_initializer(1))
a3 = tf.get_variable(name='a3', shape=[2,3], initializer=tf.ones_initializer())
 
with tf.Session() as sess:
	sess.run(tf.initialize_all_variables())
	print(sess.run(a1))
	print(sess.run(a2))
	print(sess.run(a3))

输出:

[[ 0.42299312 -0.25459203 -0.88605702]
 [ 0.22410156  1.34326422 -0.39722782]]
[ 1.]
[[ 1.  1.  1.]
 [ 1.  1.  1.]]

注意:不同的变量之间不能有相同的名字,除非你定义了variable_scope,这样才可以有相同的名字。

tf.Variable() 和tf.get_variable()区别

1、使用tf.Variable时,如果检测到命名冲突,系统会自己处理。使用tf.get_variable()时,系统不会处理冲突,而会报错

import tensorflow as tf
w_1 = tf.Variable(3,name="w_1")
w_2 = tf.Variable(1,name="w_1")
print w_1.name
print w_2.name
#输出
#w_1:0
#w_1_1:0
import tensorflow as tf

w_1 = tf.get_variable(name="w_1",initializer=1)
w_2 = tf.get_variable(name="w_1",initializer=2)
#错误信息
#ValueError: Variable w_1 already exists, disallowed. Did
#you mean to set reuse=True in VarScope?

2、基于这两个函数的特性,当我们需要共享变量的时候,需要使用tf.get_variable()。在其他情况下,这两个的用法是一样的

import tensorflow as tf

with tf.variable_scope("scope1"):
    w1 = tf.get_variable("w1", shape=[])
    w2 = tf.Variable(0.0, name="w2")
with tf.variable_scope("scope1", reuse=True):
    w1_p = tf.get_variable("w1", shape=[])
    w2_p = tf.Variable(1.0, name="w2")

print(w1 is w1_p, w2 is w2_p)
#输出
#True  False

由于tf.Variable() 每次都在创建新对象,所有reuse=True 和它并没有什么关系。对于get_variable(),来说,如果已经创建的变量对象,就把那个对象返回,如果没有创建变量对象的话,就创建一个新的。

参考:https://blog.csdn.net/MrR1ght/article/details/81228087 https://blog.csdn.net/UESTC_C2_403/article/details/72327321 https://www.jianshu.com/p/2061b221cd8f https://www.w3cschool.cn/tensorflow_python/tensorflow_python-st6f2ez1.html https://blog.csdn.net/zz2230633069/article/details/81414330

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • tf.cast()

    参考:https://www.cnblogs.com/hezhiyao/p/8196587.html

    周小董
  • Tensorflow模型保存和读取tf.train.Saver

    然后,在训练循环中,定期调用 saver.save() 方法,向文件夹中写入包含了当前模型中所有可训练变量的 checkpoint 文件。

    周小董
  • tf.Variable()函数

    tf.Variable(initializer,name),参数initializer是初始化参数,name是可自定义的变量名称,用法如下:

    周小董
  • tf API 研读2:math

    TF API数学计算 tf...... :math (1)刚开始先给一个运行实例。         tf是基于图(Graph)的计算系统。而图的节点则是由操作(...

    MachineLP
  • TensorFlow指南(一)——上手TensorFlow

    http://blog.csdn.net/u011239443/article/details/79066094 TensorFlow是谷歌开源的深度学习库...

    用户1621453
  • tensorflow编程: Constants, Sequences, and Random Values

      注意: start 和 stop 参数都必须是 浮点型;     取值范围也包括了 stop; tf.lin_space 等同于 tf.lins...

    JNingWei
  • 深入理解TensorFlow中的tf.metrics算子

    本文翻译自Avoiding headaches with tf.metrics,原作者保留版权。

    机器学习算法工程师
  • [TensorFlow深度学习入门]实战六·用CNN做Kaggle比赛手写数字识别准确率99%+

    参考博客地址 本博客采用Lenet5实现,也包含TensorFlow模型参数保存与加载参考我的博文,实用性比较好。在训练集准确率99.85%,测试训练集准确率...

    小宋是呢
  • 深度学习在花椒直播中的应用—神经网络与协同过滤篇

    协同过滤(collaborative filtering)算法一经发明便在推荐系统中取得了非凡的成果。许多知名的系统早期都采用了协同过滤算法,例如Google ...

    石晓文
  • 独家 | 一文读懂TensorFlow基础

    本文长度为7196字,建议阅读10分钟 本文为你讲解如何使用Tensorflow进行机器学习和深度学习。 1. 前言 深度学习算法的成功使人工智能的研究和应用取...

    数据派THU

扫码关注云+社区

领取腾讯云代金券