首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

tensorflow: Shapes and Shaping 探究

定义

Tensor Transformations - Shapes and Shaping: TensorFlow provides several operations that you can use to determine the shape of a tensor and change the shape of a tensor.

tensorflow提供了一些操作,让用户可以定义和修改tensor的形状


常用API

tf.shape

  以tensor形式,返回tensor形状。

tf.shape(input, name=None, out_type=tf.int32)

代码语言:javascript
复制
import tensorflow as tf

t = tf.constant(value=[[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]], dtype=tf.int32)
with tf.Session() as sess:
    print (sess.run(tf.shape(t)))
代码语言:javascript
复制
[2 2 3]

  另一种方法也可以的到类似答案:

代码语言:javascript
复制
import tensorflow as tf

t = tf.constant(value=[[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]], dtype=tf.int32)
print (t.shape)
代码语言:javascript
复制
(2, 2, 3)

tf.size

  以tensor形式,返回tensor元素总数。

tf.size(input, name=None, out_type=tf.int32)

代码语言:javascript
复制
import tensorflow as tf

t = tf.ones(shape=[2, 5, 10], dtype=tf.int32)

with tf.Session() as sess:
    print (sess.run(tf.size(t)))
代码语言:javascript
复制
100

tf.rank

  以tensor形式,返回tensor阶数。

tf.rank(input, name=None)

代码语言:javascript
复制
import tensorflow as tf

t = tf.constant(value=[[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]], dtype=tf.int32)
with tf.Session() as sess:
    print (sess.run(tf.rank(t)))
代码语言:javascript
复制
3

tf.reshape

  以tensor形式,返回重新被塑形的tensor。

tf.reshape(tensor, shape, name=None)

代码语言:javascript
复制
import tensorflow as tf

t = tf.constant(value=[1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=tf.int32)
with tf.Session() as sess:
    print (sess.run(t))
    print
    print (sess.run(tf.reshape(t, [3, 3])))
代码语言:javascript
复制
[1 2 3 4 5 6 7 8 9]

[[1 2 3]
 [4 5 6]
 [7 8 9]]

tf.squeeze

  以tensor形式,返回移除指定维后的tensor。

tf.squeeze(input, axis=None, name=None, squeeze_dims=None)

  • axis=None 时: Removes dimensions of size 1 from the shape of a tensor. 将tensor中 维度为1所有维 全部移除
  • axis=[2, 4] 时: 将tensor中 维度为1第2维第4维 移除
代码语言:javascript
复制
import tensorflow as tf

t = tf.ones(shape=[1, 2, 1, 3, 1, 1], dtype=tf.int32)

with tf.Session() as sess:
    print (sess.run(tf.shape(t)))
    print
    print (sess.run(tf.shape(tf.squeeze(t))))
    print
    print (sess.run(tf.shape(tf.squeeze(t, axis=[2, 4]))))
代码语言:javascript
复制
[1 2 1 3 1 1]

[2 3]

[1 2 3 1]

tf.expand_dims

  以tensor形式,返回插入指定维后的tensor。

tf.expand_dims(input, axis=None, name=None, dim=None)

代码语言:javascript
复制
import tensorflow as tf

t = tf.ones(shape=[2, 3, 5], dtype=tf.int32)

with tf.Session() as sess:
    print (sess.run(tf.shape(t)))
    print
    print (sess.run(tf.shape(tf.expand_dims(t, 0))))
    print
    print (sess.run(tf.shape(tf.expand_dims(t, 1))))
    print
    print (sess.run(tf.shape(tf.expand_dims(t, 2))))
    print
    print (sess.run(tf.shape(tf.expand_dims(t, 3))))
    print
    print (sess.run(tf.shape(tf.expand_dims(t, -1))))
代码语言:javascript
复制
[2 3 5]

[1 2 3 5]

[2 1 3 5]

[2 3 1 5]

[2 3 5 1]

[2 3 5 1]


下一篇
举报
领券