前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >tf.constant_initializer

tf.constant_initializer

作者头像
狼啸风云
发布2022-10-31 16:28:05
4460
发布2022-10-31 16:28:05
举报
文章被收录于专栏:计算机视觉理论及其实现

参考  tf.train.Coordinator - 云+社区 - 腾讯云

目录

一、使用方法

二、类中的函数

1、__init__

2、__call__

3、from_config

4、get_config


一、使用方法

一个类,初始化器,它生成具有常量值的张量。由新张量的期望shape后面的参数value指定。参数value可以是常量值,也可以是类型为dtype的值列表。如果value是一个列表,那么列表的长度必须小于或等于由张量的期望形状所暗示的元素的数量。如果值中的元素总数小于张量形状所需的元素数,则值中的最后一个元素将用于填充剩余的元素。如果值中元素的总数大于张量形状所需元素的总数,初始化器将产生一个ValueError。

参数:

  • value: Python标量、值列表或元组,或n维Numpy数组。初始化变量的所有元素将在value参数中设置为对应的值。
  • dtype: 数据类型。
  • verify_shape: 布尔值,用于验证value的形状。如果为真,如果value的形状与初始化张量的形状不兼容,初始化器将抛出错误。

可能产生的异常:

  • TypeError: If the input value is not one of the expected types.

示例:下面的示例可以使用numpy重写。ndarray代替了值列表,甚至重新构造了值列表,如值列表初始化下面的两行注释所示。

代码语言:javascript
复制
import numpy as np
import tensorflow as tf
value = [0, 1, 2, 3, 4, 5, 6, 7]
# value = np.array(value)
# value = value.reshape([2, 4])
init = tf.constant_initializer(value)

print('fitting shape:')
with tf.Session():
    x = tf.get_variable('x', shape=[2, 4], initializer=init)
    x.initializer.run()
    print(x.eval())

Output:
-------------------
fitting shape:
[[ 0.  1.  2.  3.]
[ 4.  5.  6.  7.]]
-------------------


print('larger shape:')
with tf.Session():
   x = tf.get_variable('x', shape=[3, 4], initializer=init)
   x.initializer.run()
   print(x.eval())

Output:
-------------------
larger shape:
[[ 0.  1.  2.  3.]
[ 4.  5.  6.  7.]
[ 7.  7.  7.  7.]]
-------------------


print('smaller shape:')
with tf.Session():
    x = tf.get_variable('x', shape=[2, 3], initializer=init)


Error:
-----------------------------------------------------------------------------------------
Traceback (most recent call last):
  File "D:/tensorflow_learning/test.py", line 11, in <module>
    x = tf.get_variable('x', shape=[2, 3], initializer=init)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1484, in get_variable
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1234, in get_variable
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 538, in get_variable
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 492, in _true_getter
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 920, in _get_single_variable
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 145, in __call__
    return cls._variable_call(*args, **kwargs)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 141, in _variable_call
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 120, in <lambda>
    previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 2441, in default_variable_creator
    expected_shape=expected_shape, import_scope=import_scope)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 147, in __call__
    return super(VariableMetaclass, cls).__call__(*args, **kwargs)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 1104, in __init__
    constraint=constraint)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 1212, in _init_from_args
    initial_value(), name="initial_value", dtype=dtype)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 894, in <lambda>
    shape.as_list(), dtype=dtype, partition_info=partition_info)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\init_ops.py", line 219, in __call__
    self.value, dtype=dtype, shape=shape, verify_shape=verify_shape)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\constant_op.py", line 207, in constant
    value, dtype=dtype, shape=shape, verify_shape=verify_shape))
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\tensor_util.py", line 497, in make_tensor_proto
    (shape_size, nparray.size))
ValueError: Too many elements provided. Needed at most 6, but received 8
-----------------------------------------------------------------------------------------


print('shape verification:')
init_verify = tf.constant_initializer(value, verify_shape=True)
with tf.Session():
    x = tf.get_variable('x', shape=[3, 4], initializer=init_verify)

Error:
-----------------------------------------------------------------------------------------
Traceback (most recent call last):
  File "D:/tensorflow_learning/test.py", line 12, in <module>
    x = tf.get_variable('x', shape=[3, 4], initializer=init_verify)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1484, in get_variable
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1234, in get_variable
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 538, in get_variable
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 492, in _true_getter
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 920, in _get_single_variable
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 145, in __call__
    return cls._variable_call(*args, **kwargs)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 141, in _variable_call
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 120, in <lambda>
    previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 2441, in default_variable_creator
    expected_shape=expected_shape, import_scope=import_scope)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 147, in __call__
    return super(VariableMetaclass, cls).__call__(*args, **kwargs)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 1104, in __init__
    constraint=constraint)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 1212, in _init_from_args
    initial_value(), name="initial_value", dtype=dtype)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 894, in <lambda>
    shape.as_list(), dtype=dtype, partition_info=partition_info)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\init_ops.py", line 219, in __call__
    self.value, dtype=dtype, shape=shape, verify_shape=verify_shape)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\constant_op.py", line 207, in constant
    value, dtype=dtype, shape=shape, verify_shape=verify_shape))
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\tensor_util.py", line 492, in make_tensor_proto
    (tuple(shape), nparray.shape))
TypeError: Expected Tensor's shape: (3, 4), got (8,).
-----------------------------------------------------------------------------------------

二、类中的函数

1、__init__

代码语言:javascript
复制
__init__(
    value=0,
    dtype=tf.float32,
    verify_shape=False
)

2、__call__

代码语言:javascript
复制
__call__(
    shape,
    dtype=None,
    partition_info=None,
    verify_shape=None
)

3、from_config

代码语言:javascript
复制
from_config(
    cls,
    config
)

从配置字典实例化初始化器。例子:

代码语言:javascript
复制
initializer = RandomUniform(-1, 1)
config = initializer.get_config()
initializer = RandomUniform.from_config(config)

参数:

  • config: 一个Python字典。它通常是get_config的输出。

返回:

  • 一个初始化后的实例。

4、get_config

代码语言:javascript
复制
get_config()
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022-10-18,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、使用方法
  • 二、类中的函数
    • 1、__init__
      • 2、__call__
        • 3、from_config
          • 4、get_config
          领券
          问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档