前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >13 | PyTorch全连接网络识别飞机、全连接在图像分类上的缺陷

13 | PyTorch全连接网络识别飞机、全连接在图像分类上的缺陷

作者头像
机器学习之禅
发布2022-07-11 15:43:15
5760
发布2022-07-11 15:43:15
举报
文章被收录于专栏:机器学习之禅机器学习之禅

接着上一小节说,我们已经把全连接网络建好了,接下来就需要去训练网络,找到合适的参数来拟合我们的训练数据,那么第一个事情就看损失函数。

损失函数

回忆我们之前用的MSE损失函数,当结果偏离实际结果,不管是正向的偏离还是反向的偏离,损失都会上升,我们在分类中当然也可以使用这样的损失,但是效果并不太好,因为我们不是想让最终的概率结果一个精确的值,[1,0]或者[0,1],我们希望的是如果一张图是鸟,那么鸟的概率比飞机高就可以了,而不是绞尽脑汁研究怎么把这张图的概率优化到识别“它一定是一只鸟”。

由于我们给出的结果是一个在0-1区间的概率,而实际结果是0或者1,那么计算标准差的结果也不会很大,我们很难看出来效果好坏。

这里给出了一个新的损失函数,叫做负对数似然损失函数(Negative Log Likelihood, NLL),它先对结果求log,然后求和并取负数。NLL=-sum(log(out_i[c_i]))它的图像如下:

当预测结果的概率较低时,NLL会趋近于无穷大,当预测结果概率大于0.5的时候,NLL有缓慢的下降,随着我们预测概率越大,损失越低。也就是说当某个类别概率大于0.5的时候,我们就可以认为它已经差不多符合我们的要求了。

这里给出了一个关于交叉熵损失和MSE损失的直观对比,需要说的是我们虽然前面说的损失是NLL损失,但是我们在输出的时候使用了softmax,也就是进行了-sum(log(softmax(out_i[c_i])))这样一个变换,这个叫做交叉熵损失,在nn模块当然也有它的实现,可以通过调用nn.CrossEntropyLoss()来使用。从图上可以看出来,如果使用MSE损失,首先在预测结果上有很大一块都是平的,也就是很难看出来给了4和2有什么差距,而且损失有一个峰值,就是大概到2的时候就封顶了。

这时候稍微改动一下我们的模型,把输出改成LogSoftmax,并实例化我们的NLL损失

代码语言:javascript
复制
model = nn.Sequential(
            nn.Linear(3072, 512),
            nn.Tanh(),
            nn.Linear(512, 2),
            nn.LogSoftmax(dim=1))

loss = nn.NLLLoss()

说了这么多,我们就是需要对分类问题换一个损失函数,接下来开始正经训练。

训练分类器

下面是一个完整的代码,我们这里用的是全连接,跑起来很慢,不妨让我们先看一下这个代码跟之前有什么不同。

在模型设定方面是一样的讨论,只是模型内部的变换多了一些,紧接着是学习率、优化器、损失函数,这里跟以前都一样的,有区别的是在循环里面,我们原来是一个大循环,现在里面又套了一个小循环,在每一个小循环里,我们只取出一个图像样本进行评估,然后计算损失,反向传播并迭代参数。小批量数据有助于防止陷入局部最优值,但是也会导致训练很不稳定,所以要使用比较小的学习率。从下面的输出可以看到,在前10个epoch中,epoch5损失最小,其他的都在波动中,并不像我们之前看到的是持续下降的情况。

代码语言:javascript
复制
import torchimport torch.nn as nnimport torch.optim as optim
model = nn.Sequential(
                       nn.Linear(3072,512),
                       nn.Tanh(),
                       nn.Linear(512,2),
                       nn.LogSoftmax(dim=1))learning_rate = 1e-2optimizer = optim.SGD(model.parameters(), lr=learning_rate)loss_fn = nn.NLLLoss()n_epochs = 100for epoch in range(n_epochs):
    for img, label in cifar2:
        out  =  model(img.view(-1).unsqueeze(0))
        loss = loss_fn(out, torch.tensor([label]))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print("Epoch: %d, Loss: %f" % (epoch, float(loss)))outs:Epoch: 0, Loss: 5.763190Epoch: 1, Loss: 7.964352Epoch: 2, Loss: 5.558451Epoch: 3, Loss: 4.321349Epoch: 4, Loss: 8.145464Epoch: 5, Loss: 1.852705Epoch: 6, Loss: 6.117345Epoch: 7, Loss: 16.624195Epoch: 8, Loss: 3.632586Epoch: 9, Loss: 6.123263

作者给出了一个形象的图来展示小批量更新的效果,可以看到在整个图上的损失情况基本上是从左下角到右上角是一个下降的趋势,其中的黄色曲线是全数据集计算梯度下降的理想曲线,而黑色曲线是在小批量数据上进行梯度下降的状况。

先不等上面那个跑完了,我们接着往下看。因为刚才那个相当于每个小批量只用了1个样本,运算起来很慢。这里我们考虑选择64张图作为一个批次的数据,使用一个叫DataLoader()的方法来获取数据。

代码语言:javascript
复制
import torchimport torch.nn as nnimport torch.optim as optim

train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)  #使用DataLoader加载数据,设置shuffle表示打乱数据
model = nn.Sequential(
                       nn.Linear(3072,512),
                       nn.Tanh(),
                       nn.Linear(512,2),
                       nn.LogSoftmax(dim=1))learning_rate = 1e-2optimizer = optim.SGD(model.parameters(), lr=learning_rate)loss_fn = nn.NLLLoss()n_epochs = 100for epoch in range(n_epochs):
    for imgs, labels in train_loader:
        out  =  model(img.view(-1).unsqueeze(0))
        loss = loss_fn(out, torch.tensor([label]))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print("Epoch: %d, Loss: %f" % (epoch, float(loss)))outs:Epoch: 0, Loss: 0.001268Epoch: 1, Loss: 0.000632Epoch: 2, Loss: 0.000420···
Epoch: 99, Loss: 0.000012

从上面的结果看到,使用了一批64个数据之后,我们的损失顿时就小了很多,而且非常稳定,训练速度也快了很多,到了最后一代基本上可以认为超级精准了,loss降到了0.000012。

既然损失这么低了,我们可以来检测一下我们模型的准确率了,这时候掏出我们的验证集。

代码语言:javascript
复制
val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)correct = 0total = 0with torch.no_grad():
    for imgs, labels in val_loader:
        batch_size = imgs.shape[0]
        outputs = model(imgs.view(batch_size, -1)) #预测验证集数据
        _, predicted = torch.max(outputs, dim=1) #取较大值作为预测结果
        total += labels.shape[0] #计算验证集的大小
        correct +=  int((predicted == labels).sum()) #计算预测正确的标签数量
        print("Accuracy: ", correct / total)outs:Accuracy:  0.676

然而,我们的得到的验证集准确率竟然只有67.6%,这有问题啊xdm,这明显过拟合了,在我们的训练集上那么低的损失,在验证集上的准确率只有这么一丢丢,只比随机好一点。

神经网络的好处就是,我们可以不断的拓展我们的神经元,一层不够就来两层,我前面说过深而窄的网络效果一般比宽而浅的网络效果好,一种表面上的理解就是深层次的思考可以习得更多有用的特征,让我们再加两层隐含层,这里损失也不再那么麻烦了,直接用一步到位的交叉熵损失。

代码语言:javascript
复制
#主要改模型部分
model = nn.Sequential(
                       nn.Linear(3072,1024),
                       nn.Tanh(),
                       nn.Linear(1024,512),
                       nn.Tanh(),
                       nn.Linear(512,128),
                       nn.Tanh(),
                       nn.Linear(128,2)) #注意这里去掉了softmax,因为在交叉熵损失里面已经包含了softmax部分
#然后是损失
loss_fn = nn.CrossEntropyLoss()
outs:# 这里省略了大部分输出
Epoch: 99, Loss: 0.000036

从上面的结果看,在100代损失稍微大了一点,再跑一下验证集Accuracy: 0.6925,效果提升了一点点,但是不多。

全连接网络的局限

看来这个模型效果就这样了,我们先不再改进它,转头思考一下,这个模型有什么问题。

第一个问题可能是参数太多导致训练太慢。 当然这个问题可能跟结果没什么关系,但是如果训练能够快很多的话,我们每天可以训更多次,也可能优化更多的地方,另外当参数特别大的时候,我们的电脑承受不来,可能会导致内存溢出,根本就没办法训练了。

要查看我们到底有多少参数,nn.Model也提供了parameters()方法,我们可以用它来获取参数数量

numel()函数:返回[数组]中元素的个数

代码语言:javascript
复制
numel_list = [p.numel() 
             for p in model.parameters()
             if p.requires_grad == True]sum(numel_list), numel_list
outs:(3737474, [3145728, 1024, 524288, 512, 65536, 128, 256, 2])

可以看到,就我们用的这么简单的小模型,都有高达370w的参数,如果图像再大点我们的电脑就直接崩溃了。

在第一个问题的基础上,第二个问题就是不具有平移不变性。 考虑我们对图像做的预处理,我们把它的三个通道都摊平了,并且塞到了一个一维向量中,那么我们的模型只能学到一个顺序排列的数组的特性,但是图片实际上并不是这样的是不是,一个图像像素跟他上下左右的像素都有关系。

所以这里有一个概念叫做平移不变性,就是在一个图片上,同样一架飞机出现在图片左上角和右下角并不影响这是一张跟飞机相关的图片,但是我们把它拉成一个一维向量这个特征就丢了,如下图所示

比如说左上角的图是飞机在左上角,把它拉成一维向量之后,与我们的权重向量进行计算得到A; 左下角的图像是飞机在右上角,把它拉成一维向量后与我们的权重计算得到B。

如果让我们看图片,这两张图肯定都是飞机啊,但是经过同样的权重矩阵计算,它俩却得到了不一样的结果,所以如果图像发生的变化,我们的模型就没办法很好的给出结果了,哪怕只是把飞机从左边移到右边。当然我们可以考虑增加样本量,比如给图像做镜像变换,上下左右翻转,各种裁剪旋转等等,但是有一个更好的方案就是使用卷积层,下一节我们看一下卷积层如何解决这个问题。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-06-02,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 机器学习之禅 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 损失函数
  • 训练分类器
  • 全连接网络的局限
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档