tf.one_hot(
indices,
depth,
on_value=None,
off_value=None,
axis=None,
dtype=None,
name=None
)
返回一个独热张量。 索引中由索引表示的位置取值on_value,而所有其他位置取值off_value。on_value和off_value必须具有匹配的数据类型。如果还提供了dtype,则它们必须与dtype指定的数据类型相同。如果没有提供on_value,它将默认为值1,类型为dtype。如果没有提供off_value,它将默认值为0,类型为dtype。如果输入索引的秩为N,那么输出的秩为N+1。新轴是在维度轴上创建的(缺省值:新轴附加在末尾)。如果索引是标量,则输出形状将是长度深度向量。如果索引是长度特征向量,则输出形状为:
features x depth if axis == -1
depth x features if axis == 0
如果索引是一个形状为[batch, features]的矩阵(batch),则输出形状为:
batch x features x depth if axis == -1
batch x depth x features if axis == 1
depth x batch x features if axis == 0
如果没有提供dtype,它将尝试假设数据类型为on_value或off_value(如果传入了一个或两个值)。如果不提供on_value、off_value或dtype, dtype将默认为tf.float32。注意:如果需要非数值数据类型输出(tf)。字符串,特遣部队。bool等),on_value和off_value都必须提供给one_hot。
例如:
indices = [0, 1, 2]
depth = 3
tf.one_hot(indices, depth) # output: [3 x 3]
# [[1., 0., 0.],
# [0., 1., 0.],
# [0., 0., 1.]]
indices = [0, 2, -1, 1]
depth = 3
tf.one_hot(indices, depth,
on_value=5.0, off_value=0.0,
axis=-1) # output: [4 x 3]
# [[5.0, 0.0, 0.0], # one_hot(0)
# [0.0, 0.0, 5.0], # one_hot(2)
# [0.0, 0.0, 0.0], # one_hot(-1)
# [0.0, 5.0, 0.0]] # one_hot(1)
indices = [[0, 2], [1, -1]]
depth = 3
tf.one_hot(indices, depth,
on_value=1.0, off_value=0.0,
axis=-1) # output: [2 x 2 x 3]
# [[[1.0, 0.0, 0.0], # one_hot(0)
# [0.0, 0.0, 1.0]], # one_hot(2)
# [[0.0, 1.0, 0.0], # one_hot(1)
# [0.0, 0.0, 0.0]]] # one_hot(-1)
参数:
返回值:
异常:
TypeError
: If dtype of either on_value
or off_value
don't match dtype
TypeError
: If dtype of on_value
and off_value
don't match one another原链接: https://tensorflow.google.cn/versions/r1.9/api_docs/python/tf/one_hot?hl=en