首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >tf.get_variable()函数

tf.get_variable()函数

作者头像
狼啸风云
修改2022-09-04 21:41:51
修改2022-09-04 21:41:51
5.7K0
举报

如果你定义的变量名称在之前已被定义过,则TensorFlow 会引发异常。可使用tf.get_variable( ) 函数代替tf.Variable( )。如果变量存在,函数tf.get_variable( ) 会返回现有的变量。如果变量不存在,会根据给定形状和初始值创建变量。

代码语言:javascript
复制
tf.get_variable(
    name,
    shape=None,
    dtype=None,
    initializer=None,
    regularizer=None,
    trainable=None,
    collections=None,
    caching_device=None,
    partitioner=None,
    validate_shape=True,
    use_resource=None,
    custom_getter=None,
    constraint=None,
    synchronization=tf.VariableSynchronization.AUTO,
    aggregation=tf.VariableAggregation.NONE
)

下面是一个基本的例子:

代码语言:javascript
复制
def foo():
  with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
    v = tf.get_variable("v", [1])
  return v

v1 = foo()  # Creates v.
v2 = foo()  # Gets the same, existing v.
assert v1 == v2

如果初始化器为None(缺省值),则将使用在变量范围中传递的缺省初始化器。如果没有,则使用glorot_uniform_initializer。初始化器也可以是一个张量,在这种情况下,变量初始化为这个值和形状。类似地,如果正则化器为None(默认值),则将使用在变量范围中传递的默认正则化器(如果也是None,则默认情况下不执行正则化)。如果提供了分区程序,则返回一个PartitionedVariable。以张量的形式访问这个对象,返回沿分区轴连接的切分。可以使用一些有用的分区器。参见,例如,variable_axis_size_partitioner和min_max_variable_partitioner。

参数:

  • name:新变量或现有变量的名称。
  • shape:新变量或现有变量的形状。
  • dtype:新变量或现有变量的类型(默认为DT_FLOAT)。
  • initializer:如果创建了变量的初始化器。可以是初始化器对象,也可以是张量。如果它是一个张量,它的形状必须是已知的,除非validate_shape是假的。
  • regularizer:A(张量->张量或无)函数;将其应用于新创建的变量的结果将添加到集合tf.GraphKeys中。正则化-损耗,可用于正则化。
  • trainable:如果为真,也将变量添加到图形集合GraphKeys中。TRAINABLE_VARIABLES(见tf.Variable)。
  • collections:要向其中添加变量的图形集合键的列表。默认为[GraphKeys.GLOBAL_VARIABLES](见tf.Variable)。
  • caching_device:可选的设备字符串或函数,描述变量应该缓存到什么地方以便读取。变量的设备的默认值。如果没有,则缓存到另一个设备上。典型的用途是在使用该变量的操作系统所在的设备上缓存,通过Switch和其他条件语句来重复复制。
  • partitioner:可选的callable,它接受要创建的变量的完全定义的TensorShape和dtype,并返回每个轴的分区列表(目前只能分区一个轴)。
  • validate_shape:如果为False,则允许用一个未知形状的值初始化变量。如果为真,默认情况下,initial_value的形状必须是已知的。要使用它,初始化器必须是一个张量,而不是初始化器对象。
  • use_resource:如果为False,则创建一个常规变量。如果为真,则创建一个具有定义良好语义的实验性资源变量。默认值为False(稍后将更改为True)。当启用紧急执行时,该参数总是强制为真。
  • custom_getter: Callable,它将true getter作为第一个参数,并允许覆盖内部get_variable方法。custom_getter的签名应该与这个方法的签名相匹配,但是未来最可靠的版本将允许更改:def custom_getter(getter、*args、**kwargs)。还允许直接访问所有get_variable参数:def custom_getter(getter、name、*args、**kwargs)。一个简单的身份自定义getter,简单地创建变量与修改的名称是:
  • constraint:优化器更新后应用于变量的可选投影函数(例如,用于为层权重实现规范约束或值约束)。函数必须将表示变量值的未投影张量作为输入,并返回投影值的张量(其形状必须相同)。在进行异步分布式培训时使用约束并不安全。
  • synchronization:指示何时聚合分布式变量。可接受的值是在tf.VariableSynchronization类中定义的常量。默认情况下,同步设置为AUTO,当前分发策略选择何时同步。如果同步设置为ON_READ,则不能将trainable设置为True。
  • aggregation:指示如何聚合分布式变量。可接受的值是在tf.VariableAggregation类中定义的常量。

返回值:

  • 创建的或现有的变量(或PartitionedVariable,如果使用了分区器)。

可能产生的异常:

  • ValueError: when creating a new variable and shape is not declared, when violating reuse during variable creation, or when initializer dtype and dtype don't match. Reuse is set inside
  • variable_scope
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2019年07月29日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档