参考 tf.train.Coordinator - 云+社区 - 腾讯云
目录
初始化器类,用于生成具有常量值的张量。生成的张量由 dtype 类型的值填充,这些值由参数 value 给出,并按照新张量期望的 shape 进行排布(见下方示例)。参数 value
可以是一个常量值,也可以是 dtype 类型的值列表。如果 value 是一个列表,那么列表的长度必须小于或等于张量期望形状所对应的元素个数。如果 value 中的元素总数少于张量形状所需的元素数,则用 value 中的最后一个元素填充剩余的元素;如果 value 中的元素总数多于张量形状所需的元素数,初始化器将抛出 ValueError。
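参数:
- value:Python 标量、值的列表或元组,或 N 维 numpy 数组。初始化得到的变量的所有元素都会被设置为 value 中对应的值。
- dtype:数据类型。
- verify_shape:布尔值,用于开启对 value 形状的验证。若为 True,当 value 的形状与被初始化张量的形状不兼容时,初始化器会抛出错误。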
可能产生的异常:
- TypeError:输入的 value 不是预期的类型之一。

示例:下面的示例也可以用 numpy.ndarray 代替值列表来改写,甚至可以先对其进行 reshape,如值列表初始化语句下方的两行注释所示。
import numpy as np
import tensorflow as tf
# Eight values exactly fill a [2, 4] variable.
value = [0, 1, 2, 3, 4, 5, 6, 7]
# Alternatively, pass a numpy.ndarray (optionally reshaped) instead of a list:
# value = np.array(value)
# value = value.reshape([2, 4])
init = tf.constant_initializer(value)

print('fitting shape:')
with tf.Session():
    x = tf.get_variable('x', shape=[2, 4], initializer=init)
    x.initializer.run()
    print(x.eval())
Output:
-------------------
fitting shape:
[[ 0. 1. 2. 3.]
[ 4. 5. 6. 7.]]
-------------------
# Every example creates a variable named 'x'; start from a fresh default graph
# so tf.get_variable does not fail with "Variable x already exists".
tf.reset_default_graph()
print('larger shape:')
with tf.Session():
    # 12 elements are needed but only 8 are given: the last value (7) is
    # repeated to fill the remaining entries.
    x = tf.get_variable('x', shape=[3, 4], initializer=init)
    x.initializer.run()
    print(x.eval())
Output:
-------------------
larger shape:
[[ 0. 1. 2. 3.]
[ 4. 5. 6. 7.]
[ 7. 7. 7. 7.]]
-------------------
# Fresh graph so the expected shape error is not masked by a duplicate-name
# error for variable 'x'.
tf.reset_default_graph()
print('smaller shape:')
with tf.Session():
    # Raises ValueError: only 6 elements fit a [2, 3] tensor, but 8 were given.
    x = tf.get_variable('x', shape=[2, 3], initializer=init)
Error:
-----------------------------------------------------------------------------------------
Traceback (most recent call last):
File "D:/tensorflow_learning/test.py", line 11, in <module>
x = tf.get_variable('x', shape=[2, 3], initializer=init)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1484, in get_variable
aggregation=aggregation)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1234, in get_variable
aggregation=aggregation)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 538, in get_variable
aggregation=aggregation)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 492, in _true_getter
aggregation=aggregation)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 920, in _get_single_variable
aggregation=aggregation)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 145, in __call__
return cls._variable_call(*args, **kwargs)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 141, in _variable_call
aggregation=aggregation)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 120, in <lambda>
previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 2441, in default_variable_creator
expected_shape=expected_shape, import_scope=import_scope)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 147, in __call__
return super(VariableMetaclass, cls).__call__(*args, **kwargs)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 1104, in __init__
constraint=constraint)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 1212, in _init_from_args
initial_value(), name="initial_value", dtype=dtype)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 894, in <lambda>
shape.as_list(), dtype=dtype, partition_info=partition_info)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\init_ops.py", line 219, in __call__
self.value, dtype=dtype, shape=shape, verify_shape=verify_shape)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\constant_op.py", line 207, in constant
value, dtype=dtype, shape=shape, verify_shape=verify_shape))
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\tensor_util.py", line 497, in make_tensor_proto
(shape_size, nparray.size))
ValueError: Too many elements provided. Needed at most 6, but received 8
-----------------------------------------------------------------------------------------
# Fresh graph so the expected shape error is not masked by a duplicate-name
# error for variable 'x'.
tf.reset_default_graph()
print('shape verification:')
init_verify = tf.constant_initializer(value, verify_shape=True)
with tf.Session():
    # Raises TypeError: with verify_shape=True the value's shape (8,) must
    # match the requested shape (3, 4) exactly; no padding is applied.
    x = tf.get_variable('x', shape=[3, 4], initializer=init_verify)
Error:
-----------------------------------------------------------------------------------------
Traceback (most recent call last):
File "D:/tensorflow_learning/test.py", line 12, in <module>
x = tf.get_variable('x', shape=[3, 4], initializer=init_verify)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1484, in get_variable
aggregation=aggregation)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1234, in get_variable
aggregation=aggregation)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 538, in get_variable
aggregation=aggregation)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 492, in _true_getter
aggregation=aggregation)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 920, in _get_single_variable
aggregation=aggregation)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 145, in __call__
return cls._variable_call(*args, **kwargs)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 141, in _variable_call
aggregation=aggregation)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 120, in <lambda>
previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 2441, in default_variable_creator
expected_shape=expected_shape, import_scope=import_scope)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 147, in __call__
return super(VariableMetaclass, cls).__call__(*args, **kwargs)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 1104, in __init__
constraint=constraint)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 1212, in _init_from_args
initial_value(), name="initial_value", dtype=dtype)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 894, in <lambda>
shape.as_list(), dtype=dtype, partition_info=partition_info)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\init_ops.py", line 219, in __call__
self.value, dtype=dtype, shape=shape, verify_shape=verify_shape)
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\constant_op.py", line 207, in constant
value, dtype=dtype, shape=shape, verify_shape=verify_shape))
File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\tensor_util.py", line 492, in make_tensor_proto
(tuple(shape), nparray.shape))
TypeError: Expected Tensor's shape: (3, 4), got (8,).
-----------------------------------------------------------------------------------------
1、__init__
__init__(
value=0,
dtype=tf.float32,
verify_shape=False
)
2、__call__
__call__(
shape,
dtype=None,
partition_info=None,
verify_shape=None
)
3、from_config
from_config(
cls,
config
)
从配置字典实例化初始化器。例子:
# Round-trip an initializer through its config dict. The RandomUniform
# initializer class is exposed in the TF 1.x API as
# tf.random_uniform_initializer; the bare name RandomUniform is not imported.
initializer = tf.random_uniform_initializer(-1, 1)
config = initializer.get_config()
initializer = tf.random_uniform_initializer.from_config(config)
参数:
- config:一个 Python 字典,通常是 get_config 的输出。

返回:
一个初始化器(Initializer)实例。
4、get_config
get_config()