前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >CNN的Flatten操作 | Pytorch系列(七)

CNN的Flatten操作 | Pytorch系列(七)

作者头像
AI算法与图像处理
发布2020-04-26 10:06:11
6.2K0
发布2020-04-26 10:06:11
举报

文 |AI_study

欢迎回到这个关于神经网络编程的系列。在这篇文章中,我们将可视化一个单一灰度图像的张量flatten 操作,我们将展示如何flatten 特定的张量轴,这是CNNs经常需要的,因为我们处理的是批量输入而不是单个输入。

张量的flatten

张量flatten操作是卷积神经网络中的一种常见操作。这是因为传递给全连接层的卷积层的输出必须在全连接层接受输入之前进行flatten。

在以前的文章中,我们学习了一个张量的形状,然后学习了reshape操作。flatten操作是一种特殊类型的reshape操作,其中所有的轴都被平滑或压扁在一起。

为了使一个张量扁平化,我们需要至少有两个轴。这使得我们开始的时候不是扁平的。现在让我们来看一幅来自MNIST数据集的手写图像。这个图像有两个不同的维度,高度和宽度。

高度和宽度分别为18 x 18。这些尺寸告诉我们这是裁剪过的图像,因为MNIST数据集是包含28 x 28的图像。现在让我们看看如何将这两个高度轴和宽度轴展平为单个长度为324的轴。

上图显示了我们的扁平化输出,其单轴长度为324。边缘上的白色对应于图像顶部和底部的白色。

在此示例中,我们将展平整个张量图像,但是如果我们只想展平张量内的特定轴怎么办?这是使用CNN时通常需要的操作。

让我们看看如何使用PyTorch展平代码中的张量的特定轴。

展平张量的特定轴

在CNN输入张量形状的文章中《深度学习中关于张量的阶、轴和形状的解释 | Pytorch系列(二)》,我们了解了一个卷积神经网络的张量输入通常有4个轴,一个用于批量大小,一个用于颜色通道,另外两个个用于高度和宽度。

(Batch Size, Channels, Height, Width)

让我们通过构造一个满足这些规格的张量来开始。首先,假设我们有以下三个张量。

一、创建一个张量表示一批图片

代码语言:javascript
复制
t1 = torch.tensor([
    [1,1,1,1],
    [1,1,1,1],
    [1,1,1,1],
    [1,1,1,1]
])

t2 = torch.tensor([
    [2,2,2,2],
    [2,2,2,2],
    [2,2,2,2],
    [2,2,2,2]
])

t3 = torch.tensor([
    [3,3,3,3],
    [3,3,3,3],
    [3,3,3,3],
    [3,3,3,3]
])

每一个的形状都是4x4,所以我们有3个2阶张量。出于我们的目的,我们将这些看作是3张4×4的图片,它们可以用来创建一批可以传递给CNN的图片。

记住,batches (多张图片)以是用一个张量表示的,所以我们需要把这三个张量合并成一个更大的张量,它有三个轴而不是两个。

代码语言:javascript
复制
> t = torch.stack((t1, t2, t3))
> t.shape

torch.Size([3, 4, 4])

在这里,我们使用stack() 方法将我们的三个张量序列连接到一个新的轴上。因为我们沿着一个新的轴有三个张量,我们知道这个轴的长度应该是3,实际上,我们可以从形状中看到我们有3个高和宽都是4的张量。

想知道stack() 方法是如何工作的吗?stack()方法的解释将在本系列的后面介绍。

https://deeplizard.com/learn/video/kF2AlpykJGY

长为3的轴表示批大小,长为4的轴分别表示高度和宽度。这就是这个批处理的张量表示的输出。

代码语言:javascript
复制
> t
tensor([[[1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1]],

        [[2, 2, 2, 2],
         [2, 2, 2, 2],
         [2, 2, 2, 2],
         [2, 2, 2, 2]],

        [[3, 3, 3, 3],
         [3, 3, 3, 3],
         [3, 3, 3, 3],
         [3, 3, 3, 3]]])

在这一点上,我们有一个包含3张4×4张图片的3阶张量。我们现在要做的就是把这个张量变成CNN所期望的形式,就是为颜色通道添加一个轴。我们基本上对每个图像张量都有一个隐式的单色通道,所以在实践中,这些是灰度图像。

一个CNN会期望看到一个显式的颜色通道轴,所以让我们通过重构这个张量来增加一个。

代码语言:javascript
复制
> t = t.reshape(3,1,4,4)
> t
tensor(
[
    [
        [
            [1, 1, 1, 1],
            [1, 1, 1, 1],
            [1, 1, 1, 1],
            [1, 1, 1, 1]
        ]
    ],
    [
        [
            [2, 2, 2, 2],
            [2, 2, 2, 2],
            [2, 2, 2, 2],
            [2, 2, 2, 2]
        ]
    ],
    [
        [
            [3, 3, 3, 3],
            [3, 3, 3, 3],
            [3, 3, 3, 3],
            [3, 3, 3, 3]
        ]
    ]
])

注意,我们如何在batch size 轴之后指定长度为1的轴。然后,附上高度和宽度轴的长度4。另外,注意长度为1的额外轴是如何不改变张量中元素的数量的。这是因为当我们乘以1时,这些分量的乘积值不变。

第一个轴有3个元素。第一个轴的每个元素表示一个图像。对于每个图像,通道轴上都有一个单色通道。每个通道包含4个数组,其中包含4个数字或标量组件。

让我们通过这个张量的下标来看看这个。

这是第一个图像。

代码语言:javascript
复制
> t[0]
tensor([[[1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]])

我们在第一幅图像中有第一个颜色通道。

代码语言:javascript
复制
> t[0][0]
tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])

我们在第一幅图像的第一颜色通道中有第一行像素。

代码语言:javascript
复制
> t[0][0][0]
tensor([1, 1, 1, 1])

我们在第一图像的第一颜色通道的第一行中有第一个像素值。

代码语言:javascript
复制
> t[0][0][0][0]
tensor(1)
二、扁平化张量

好。让我们看看如何扁平化这批图像。记住,整个批是一个单独的张量,它将被传递给CNN,所以我们不想把整个东西拉平。我们只想在张量内展平每一张图像张量。

我们先把它压平,看看会是什么样子。另外,我还想说一下在上一篇文章中提供flatten()函数的其他实现方法。

代码语言:javascript
复制
> t.reshape(1,-1)[0] # Thank you Mick!
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])

> t.reshape(-1) # Thank you Aamir!
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])

> t.view(t.numel()) # Thank you Ulm!
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])

> t.flatten() # Thank you PyTorch!
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])

在底部,你会注意到作为张量对象的方法内置的另一种方式,称为flatten() 。这种方法产生的输出与其他替代品完全相同。

关于此输出,我想让您注意的是,我们已经将整个批次展平了,这会将所有图像糅合到一个轴上。请记住,这些像素值 1 代表第一个图像的像素,第二个图像则是像素值 2,第三个图像则是像素值 3。

由于我们需要对批处理张量中的每个图像进行单独的预测,因此此扁平化的批次在我们的CNN中无法很好地起作用,现在我们一团糟。

解决方案是在保持batch 轴不变的情况下使每个图像变平。这意味着我们只想拉平张量的一部分。我们要使用高度和宽度轴和颜色通道轴展平。

These axes need to be flattened: (C,H,W)

这可以通过PyTorch的内置flatten() 方法来完成。

三、扁平化张量的特定轴

运行下面的代码:

代码语言:javascript
复制
> t.flatten(start_dim=1).shape
torch.Size([3, 16])

> t.flatten(start_dim=1)
tensor(
[
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
    [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
]
)

注意在调用中我们如何指定start_dim参数。这将告诉flatten() 方法应从哪个轴开始展开操作。这里的 1 是索引,因此它是第二个轴,即颜色通道轴。可以这么说,我们跳过了batch 轴,使其保持原样。

检查形状,我们可以看到我们有一个2级张量,其中三个单色通道图像被展平为16个像素。

四、扁平化一个RGB图

如果我们将RGB图像展平,那么颜色会怎样?

What happens to the Color Channels?

每个颜色通道将首先被展平。然后,展平后的通道将在张量的单个轴上并排排列。让我们来看一个代码示例。

我们将构建一个示例RGB图像张量,高度为2,宽度为2。

代码语言:javascript
复制
r = torch.ones(1,2,2)
g = torch.ones(1,2,2) + 1
b = torch.ones(1,2,2) + 2

img = torch.cat(
    (r,g,b)
    ,dim=0
)

上面给了我们所需的张量。我们可以通过如下检查形状来验证这一点:

代码语言:javascript
复制
> img.shape
torch.Size([3, 2, 2])

我们有三个高度和宽度为两个的颜色通道。我们还可以像这样检查该张量的数据:

代码语言:javascript
复制
> img
tensor([
    [
        [1., 1.]
        ,[1., 1.]
    ]
    ,[
        [2., 2.]
        , [2., 2.]
    ],
    [
        [3., 3.]
        ,[3., 3.]
    ]
])

现在,我们可以通过展平图像张量来查看。

代码语言:javascript
复制
> img.flatten(start_dim=0)
tensor([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.])

请注意,这里的start_dim参数告诉flatten() 方法从何处开始展平。在这种情况下,我们将使整个图像变平。但是,我们也只能像这样展平通道:

代码语言:javascript
复制
> img.flatten(start_dim=1)
tensor([
    [1., 1., 1., 1.],
    [2., 2., 2., 2.],
    [3., 3., 3., 3.]
])

总结:

现在,我们应该对张量的展平操作有了一个很好的了解。我们知道如何展平整个张量,并且我们知道展平特定张量尺寸/轴。我们将在构建CNN时看到将其投入使用。

文章中内容都是经过仔细研究的,本人水平有限,翻译无法做到完美,但是真的是费了很大功夫,希望小伙伴能动动你性感的小手,分享朋友圈或点个“在看”,支持一下我 ^_^

英文原文链接是:

https://deeplizard.com/learn/video/fCVuiW9AFzY

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-04-23,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI算法与图像处理 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 张量的flatten
    • 二、扁平化张量
    • 三、扁平化张量的特定轴
    • 四、扁平化一个RGB图
    • 总结:
    领券
    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档