前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >pytorch基础知识 切片与索引-上

pytorch基础知识 切片与索引-上

作者头像
用户6719124
发布2019-11-17 23:26:20
9840
发布2019-11-17 23:26:20
举报

切片和索引是pytorch中经常使用的操作

为后续讲解方便,这里先介绍CNN的基本图片的概念,一般将图片设定为[batch_size, channel, height, width]的四维矩阵。

这里先随机建立一个矩阵

代码语言:javascript
复制
import torch
a = torch.rand(4, 3, 28, 28)
print(a.size())

输出size为:

代码语言:javascript
复制
torch.Size([4, 3, 28, 28])

再对第一维进行索引

代码语言:javascript
复制
# 对第一维进行索引
print(a[0].size())
代码语言:javascript
复制
torch.Size([3, 28, 28])

这里的输出可以认为是第一个图片的三个维度通道的28*28的像素点。

代码语言:javascript
复制
print(a[0, 0].size())
代码语言:javascript
复制
torch.Size([28, 28])

这里的输出可以认为是第一个图片的第一个维度通道的28*28的像素点。

当具体到某一个像素点时

代码语言:javascript
复制
print(a[0, 0, 2, 3])
代码语言:javascript
复制
tensor(0.4736)

这里的输出代表第一个图片的第一个维度通道的[2,3]的像素点张量为(0.4736)。

若想取连续的索引,

需要用到:

代码语言:javascript
复制
# 取连续索引
print(a.shape)
print(a[:2].shape)
代码语言:javascript
复制
torch.Size([2, 3, 28, 28])
# 这里的:相当于→(箭头),表明batch从第一个到第二个,不写默认写全部

同理

代码语言:javascript
复制
print(a[:2, 1:, :, :].shape)
# 1写在:前面,表明从1个通道开始到末尾,,不包括1
代码语言:javascript
复制
torch.Size([2, 2, 28, 28])

另外

当索引出现-1时,要提到一个知识点

代码语言:javascript
复制
print(a[:2, -1:, :, :].shape)
# 默认索引的顺序为[0, 1, 2],当倒着写时变为[-3, -2, -1]。由于这里取-1,因此为最后一位。

此时输出

代码语言:javascript
复制
torch.Size([2, 1, 28, 28])

当想隔点取样输出时

代码语言:javascript
复制
print(a[:, :, 0:28:2, 0:28:2].shape)
# 输出全部batch和channel,对每个高和宽间隔2个点采样
代码语言:javascript
复制
torch.Size([4, 3, 14, 14])

也可简化为

代码语言:javascript
复制
print(a[:, :, ::2, ::2].shape)

同样输出为

代码语言:javascript
复制
torch.Size([4, 3, 14, 14])

这里需要注意 当写为[0:28:]则等同于[0:28:1]因此可以认为[start:end:steps]

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-10-02,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档