首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >问答首页 >如何在tensorflow数组中按批进行切片?

如何在tensorflow数组中按批进行切片?
EN

Stack Overflow用户
提问于 2022-01-02 15:27:52
回答 2查看 336关注 0票数 0

我有一个数组output和一个id subject_ids

代码语言:javascript
代码运行次数:0
运行
复制
output = [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]]]

subject_ids = [[0, 1], [1, 2], [0, 2]]

ID中的数字分别表示起始位置和结束位置,然后根据开始和结束位置得到它们之间的向量。

例如,在本例中,我应该得到[[1, 2, 3], [4, 5, 6]][[4, 5, 6], [7, 8, 9]][[1, 2, 3], [4, 5, 6], [7, 8, 9]]

我该怎么办?我试过tf.slicetf.gather,但似乎行不通。

EN

回答 2

Stack Overflow用户

发布于 2022-01-02 19:00:39

如果您只想使用Tensorflow,请尝试将tf.gathertf.rangetf.ragged.stack相结合

代码语言:javascript
代码运行次数:0
运行
复制
import tensorflow as tf

output = tf.constant([
                      [[1, 2, 3]], 
                      [[4, 5, 6]], 
                      [[7, 8, 9]]
                      ])

subject_ids = tf.constant([[0, 1], [1, 2], [0, 2]])

ragged_ouput = tf.ragged.stack([tf.gather(output, tf.range(subject_ids[i, 0], subject_ids[i, 1] + 1)) for i in tf.range(0, tf.shape(subject_ids)[0])], axis=0)
ragged_ouput = tf.squeeze(ragged_ouput, axis=2)
print(ragged_ouput)
代码语言:javascript
代码运行次数:0
运行
复制
[[[1, 2, 3], [4, 5, 6]], [[4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]]

更新1:

代码语言:javascript
代码运行次数:0
运行
复制
import tensorflow as tf
tf.config.run_functions_eagerly(True)

output = tf.constant([
                      [[1, 2, 3]], 
                      [[4, 5, 6]], 
                      [[7, 8, 9]]
                      ])

subject_ids = tf.constant([[0, 1], [1, 2], [0, 2]])

def slice_tensor(x):
  return tf.ragged.stack([tf.gather(output, tf.range(x[0], x[1] + 1))], axis=0)

ragged_ouput = tf.map_fn(slice_tensor, subject_ids, fn_output_signature=tf.RaggedTensorSpec(shape=[1, None, None, 3],
                                                                    dtype=tf.int32,
                                                                    ragged_rank=2,
                                                                    row_splits_dtype=tf.int64))
ragged_ouput = tf.squeeze(ragged_ouput, axis=1)
tf.print(ragged_ouput, summarize=-1)
代码语言:javascript
代码运行次数:0
运行
复制
[[[[1, 2, 3]], [[4, 5, 6]]], [[[4, 5, 6]], [[7, 8, 9]]], [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]]]]
票数 1
EN

Stack Overflow用户

发布于 2022-01-02 15:32:44

不如就这样

代码语言:javascript
代码运行次数:0
运行
复制
>>> [output[np.arange(x, y+1)] for x, y in subject_ids] 
[array([[[1, 2, 3]],
        [[4, 5, 6]]]),
        
 array([[[4, 5, 6]],
        [[7, 8, 9]]]),
        
 array([[[1, 2, 3]],
        [[4, 5, 6]],
        [[7, 8, 9]]])]
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70557245

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档