首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >展平tensorflow中的批次

展平tensorflow中的批次
EN

Stack Overflow用户
提问于 2016-04-17 03:13:09
回答 2查看 26.8K关注 0票数 29

我有一个形状为[None, 9, 2]的tensorflow的输入(其中的None是批处理)。

要对其执行进一步的操作(例如matmul),我需要将其转换为[None, 18]形状。该怎么做呢?

EN

回答 2

Stack Overflow用户

发布于 2017-02-23 19:53:45

代码语言:javascript
复制
flat_inputs = tf.layers.flatten(inputs)
票数 18
EN

Stack Overflow用户

发布于 2016-04-17 07:06:10

您可以使用动态整形在运行时通过tf.batch获取批量维度的值,并将整个新维度集计算到tf.reshape中。这是一个在不知道列表长度的情况下将平面列表重塑为方阵的示例。

代码语言:javascript
复制
tf.reset_default_graph()
sess = tf.InteractiveSession("")
a = tf.placeholder(dtype=tf.int32)
# get [9]
ashape = tf.shape(a)
# slice the list from 0th to 1st position
ashape0 = tf.slice(ashape, [0], [1])
# reshape list to scalar, ie from [9] to 9
ashape0_flat = tf.reshape(ashape0, ())
# tf.sqrt doesn't support int, so cast to float
ashape0_flat_float = tf.to_float(ashape0_flat)
newshape0 = tf.sqrt(ashape0_flat_float)
# convert [3, 3] Python list into [3, 3] Tensor
newshape = tf.pack([newshape0, newshape0])
# tf.reshape doesn't accept float, so convert back to int
newshape_int = tf.to_int32(newshape)
a_reshaped = tf.reshape(a, newshape_int)
sess.run(a_reshaped, feed_dict={a: np.ones((9))})

您应该会看到

代码语言:javascript
复制
array([[1, 1, 1],
       [1, 1, 1],
       [1, 1, 1]], dtype=int32)
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/36668542

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档