前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >张量拼接_调整维度_切片

张量拼接_调整维度_切片

作者头像
狼啸风云
修改2022-09-04 21:08:12
1.3K0
修改2022-09-04 21:08:12
举报
文章被收录于专栏:计算机视觉理论及其实现

1、tf.concat

tf.concat的作用主要是将向量按指定维连起来,其余维度不变;而1.0版本以后,函数的用法变成:

代码语言:javascript
复制
import tensorflow as tf

t1 = [[1, 2, 3], [4, 5, 6]]

t2 = [[7, 8, 9], [10, 11, 12]]

#按照第0维连接

t12_0 = tf.concat( [t1, t2],0)

#按照第1维连接

t12_1 = tf.concat([t1, t2],1)

with tf.Session() as sess:
    print(sess.run(t12_0))
    print(sess.run(t12_1))

输出:
---------------------
[[ 1  2  3]
 [ 4  5  6]
 [ 7  8  9]
 [10 11 12]]
[[ 1  2  3  7  8  9]
 [ 4  5  6 10 11 12]]
----------------------

作为参考合成神经网络输出的时候在深度方向(inception_v3)是数字3,[batch,heigh,width,depth]。

2、tf.stack

用法:stack(values, axis=0, name=”stack”): “”“Stacks a list of rank-R tensors into one rank-(R+1) tensor.

代码语言:javascript
复制
x = tf.constant([1, 4])

y = tf.constant([2, 5])

z = tf.constant([3, 6])

tf.stack([x,y,z]) ==> [[1,4],[2,5],[3,6]]

tf.stack([x,y,z],axis=0) ==> [[1,4],[2,5],[3,6]]

tf.stack([x,y,z],axis=1) ==> [[1, 2, 3], [4, 5, 6]]

tf.stack将一组R维张量变为R+1维张量。注意:tf.pack已经变成了tf.stack\3、tf.squeeze。数据降维,只裁剪等于1的维度。不指定维度则裁剪所有长度为1的维度。

代码语言:javascript
复制
import tensorflow as tf

arr = tf.Variable(tf.truncated_normal([3,4,1,6,1], stddev=0.1))

sess = tf.Session()

sess.run(tf.global_variables_initializer())


print(sess.run(arr).shape)

print(sess.run(tf.squeeze(arr,[2,])).shape)

print(sess.run(tf.squeeze(arr,[2,4])).shape)

print(sess.run(tf.squeeze(arr)).shape)


输出:
----------------
(3, 4, 1, 6, 1)
(3, 4, 6, 1)
(3, 4, 6)
(3, 4, 6)
----------------

3、张量切片

tf.slice

slice(input_, begin, size, name=None)

从张量中提取一个切片。

代码语言:javascript
复制
import tensorflow as tf

input = [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]]

slice_1 = tf.slice(input, [1, 0, 0], [1, 1, 3])

slice_2 = tf.slice(input, [1, 0, 0], [1, 2, 3])

slice_3 = tf.slice(input, [1, 0, 0], [2, 1, 3])

with tf.Session() as sess:
    print(sess.run(slice_1))
    print(sess.run(slice_2))
    print(sess.run(slice_3))

输出:
------------
[[[3 3 3]]]


[[[3 3 3]
  [4 4 4]]]


[[[3 3 3]]
 [[5 5 5]]]
------------

tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)

代码语言:javascript
复制
def strided_slice(
input_, begin, 
end, 
strides=None, 
begin_mask=0, 
end_mask=0, 
ellipsis_mask=0, 
new_axis_mask=0, 
shrink_axis_mask=0, 
var=None, 
name=None
):

例:

代码语言:javascript
复制
import tensorflow as tf

# 来把输入变个型,可以看成3维的tensor,从外向为1,2,3维。
input= \
[
    [[1, 1, 1], 
     [2, 2, 2]
    ],
             
    [[3, 3, 3], 
     [4, 4, 4]
    ],
    
    [[5, 5, 5], 
     [6, 6, 6]
     ]
]

slice = tf.strided_slice(input, [0,0,0], [2,2,2], [1,2,1])

with tf.Session() as sess:
 print(sess.run(slice))

# start = [0,0,0] , end = [2,2,2], stride = [1,2,1]
# 求一个[start, end)的一个片段,注意end为开区间

# 第1维 start = 0 , end = 2, stride = 1, 所以取 0 , 1行,此时的输出

# [[[1, 1, 1], [2, 2, 2]],
# [[3, 3, 3], [4, 4, 4]]]

# 第2维时, start = 0 , end = 2 , stride = 2, 所以只能取0行,此时的输出

# [[[1, 1, 1]],
# [[3, 3, 3]]]

# 第3维的时候,start = 0, end = 2, stride = 1, 可以取0,1行,此时得到的就是最后的输出

# [[[1, 1]],
# [[3, 3]]]

# 整理之后最终的输出为:

# [[[1,1],[3,3]]]
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2019年06月29日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1、tf.concat
  • 2、tf.stack
  • 3、张量切片
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档