TensorFlow 张量的索引切片方式和 NumPy 模块差不多。与此同时,TensorFlow2.X 也提供了一些比较高级的切片函数,比如:
tf.gather
、tf.gather_nd
和 tf.boolean_mask
;tf.slice
。相比于对张量进行不规则的切片提取的三个函数,tf.slice
的实现方式比较特殊,所以本文来详细的介绍 tf.slice
函数。
tf.slice(
input_, begin, size, name=None
)
tf.slice
函数主要有三个参数:
为了理解 tf.slice
函数的实现方式,首先创建一个形状为 (3, 2, 3) 的三维的张量 X。
import tensorflow as tf
X = tf.constant([[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]])
print(X.shape) # (3, 2, 3)
我们知道 n 维数组可以看成是每个元素是 n - 1 维数组的一维数组,有点类似复合函数,多维张量同样如此。我们用类似复合函数的方式将形状为 (3, 2, 3) 的三维张量进行分解。
X = [[A], [B], [C]]
;A = [i, j]
、B =[k, l]
、C = [m, n]
;i = [1, 1, 1]
、j = [2, 2, 2]
、k = [3, 3, 3]
、l = [4, 4, 4]
、m = [5, 5, 5]
、n = [6, 6, 6]
。为了直观,我们可以将其绘制成层次结构:
有了这些准备,我们直接在 X 上使用 tf.slice
函数:
print(tf.slice(X, [1, 0, 0], [1, 1, 3]))
'''
[[[3, 3, 3]]]
'''
此时 begin 和 size 两个参数分别是 [1, 0, 0]
和 [1, 1, 3]
,begin 参数为张量每个维度进行切片操作的起始位置,对于 [1, 0, 0]
,我们可以理解为:
size 参数为张量每个维度取出元素的个数,对于 [1, 1, 3]
,我们可以理解为:
我们按照维度整合 begin 和 size 参数:
不过这里有个需要注意的地方,按照上面的说法,此时可能有两种选取方式:
比如,按照第一种方式,第一个维度选择 B,第二个维度选择 i, j,第三个维度选择 [5, 5, 5]
,这种每次选取都独立的方式显然是不合理的。tf.slice
显然使用第二种方式,这也是为什么说 tf.slice
能够对张量的连续子区域进行切片。
接下来,就可以将上面对 tf.slice
的理解对应到三维张量 X 中,为了更直观的理解,我们使用上面的层次结构图,图中红色的部分表示已经被选中的元素。对于 begin 和 size 两个参数分别是 [1, 0, 0]
和 [1, 1, 3]
:
明白了 tf.slice
函数,下面再来几个例子。
print(tf.slice(t, [1, 0, 0], [1, 2, 3]))
'''
tf.Tensor(
[[[3 3 3]
[4 4 4]]], shape=(1, 2, 3), dtype=int32)
'''
print(tf.slice(X, [1, 0, 0], [2, 1, 3]))
'''
tf.Tensor(
[[[3 3 3]]
[[5 5 5]]], shape=(2, 1, 3), dtype=int32)
'''
References:
本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!