前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >详解 tf.slice 函数

详解 tf.slice 函数

作者头像
触摸壹缕阳光
发布2022-05-25 14:02:57
5170
发布2022-05-25 14:02:57
举报

TensorFlow 张量的索引切片方式和 NumPy 模块差不多。与此同时,TensorFlow2.X 也提供了一些比较高级的切片函数,比如:

  • 对张量进行不规则切片提取的 tf.gathertf.gather_ndtf.boolean_mask
  • 对张量的连续子区域进行切片提取的 tf.slice

相比于对张量进行不规则的切片提取的三个函数,tf.slice 的实现方式比较特殊,所以本文来详细的介绍 tf.slice 函数。

代码语言:javascript
复制
tf.slice(
    input_, begin, size, name=None
)

tf.slice 函数主要有三个参数:

  • input_: 待切片提取的张量
  • begin: 张量每个维度进行切片操作的起始位置
  • size: 张量每个维度取出的元素个数

为了理解 tf.slice 函数的实现方式,首先创建一个形状为 (3, 2, 3) 的三维的张量 X。

代码语言:javascript
复制
import tensorflow as tf

X = tf.constant([[[1, 1, 1], [2, 2, 2]],
                 [[3, 3, 3], [4, 4, 4]],
                 [[5, 5, 5], [6, 6, 6]]])

print(X.shape) # (3, 2, 3)

我们知道 n 维数组可以看成是每个元素是 n - 1 维数组的一维数组,有点类似复合函数,多维张量同样如此。我们用类似复合函数的方式将形状为 (3, 2, 3) 的三维张量进行分解。

  1. 第一个维度有 3 个元素,用 A, B, C 表示,即 X = [[A], [B], [C]]
  2. 第二个维度有 2 个元素,第一个维度的 3 个元素,每个元素都有 2 个元素,用 i, j, k, l, m, n 表示,即 A = [i, j]B =[k, l]C = [m, n]
  3. 第三个维度有 3 个元素,第二个维度的 2 个元素,每个元素都有 3 个元素,即 i = [1, 1, 1]j = [2, 2, 2]k = [3, 3, 3]l = [4, 4, 4]m = [5, 5, 5]n = [6, 6, 6]

为了直观,我们可以将其绘制成层次结构:

有了这些准备,我们直接在 X 上使用 tf.slice 函数:

代码语言:javascript
复制
print(tf.slice(X, [1, 0, 0], [1, 1, 3]))
'''
[[[3, 3, 3]]]
'''

此时 begin 和 size 两个参数分别是 [1, 0, 0][1, 1, 3],begin 参数为张量每个维度进行切片操作的起始位置,对于 [1, 0, 0],我们可以理解为:

  • 第一个维度从位置 1 开始
  • 第二个维度从位置 0 开始
  • 第三个维度从位置 0 开始

size 参数为张量每个维度取出元素的个数,对于 [1, 1, 3],我们可以理解为:

  • 第一个维度取出 1 个元素
  • 第二个维度取出 1 个元素
  • 第三个维度取出 3 个元素

我们按照维度整合 begin 和 size 参数:

  • 第一个维度,从位置 1 开始,并且取出 1 个元素
  • 第二个维度,从位置 0 开始,并且取出 1 个元素
  • 第三个维度,从位置 0 开始,并且取出 3 个元素

不过这里有个需要注意的地方,按照上面的说法,此时可能有两种选取方式:

  1. 第一种方式:每次选取都是独立的;
  2. 第二种方式:按照层次结构逐层进行选取。

比如,按照第一种方式,第一个维度选择 B,第二个维度选择 i, j,第三个维度选择 [5, 5, 5],这种每次选取都独立的方式显然是不合理的。tf.slice 显然使用第二种方式,这也是为什么说 tf.slice 能够对张量的连续子区域进行切片。

接下来,就可以将上面对 tf.slice 的理解对应到三维张量 X 中,为了更直观的理解,我们使用上面的层次结构图,图中红色的部分表示已经被选中的元素。对于 begin 和 size 两个参数分别是 [1, 0, 0][1, 1, 3]

  • 第一个维度,从位置 1 开始,并且取出 1 个元素(Python 的索引从 0 开始)
  • 在选中的基础上,我们继续在第二个维度,从位置 0 开始,并且取出 1 个元素
  • 在选中的基础上,我们继续在第三个维度,从位置 0 开始,并且取出 3 个元素

明白了 tf.slice 函数,下面再来几个例子。

代码语言:javascript
复制
print(tf.slice(t, [1, 0, 0], [1, 2, 3]))
'''
tf.Tensor(
[[[3 3 3]
  [4 4 4]]], shape=(1, 2, 3), dtype=int32)
'''
代码语言:javascript
复制
print(tf.slice(X, [1, 0, 0], [2, 1, 3]))
'''
tf.Tensor(
[[[3 3 3]]
 [[5 5 5]]], shape=(2, 1, 3), dtype=int32)
'''

References:

  1. https://www.quora.com/How-does-tf-slice-work-in-TensorFlow
  2. https://www.tensorflow.org/api_docs/python/tf/slice
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2022-05-20,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档