专栏首页python pytorch AI机器学习实践Pytorch-ResNet(残差网络)-下

Pytorch-ResNet(残差网络)-下

ResNet具有诸多优异性能,如下所示

在左图(准确率)的比较中,从AlexNet到GoogleNet再到ResNet,准确率逐渐提高。20层结构是很多网络结构性能提升的分水岭,在20层之前,模型性能提升较容易。但在20层之后,继续添加层数对性能的提升不是很明显。但ResNet很好地解决了高层数带来的误差叠加问题,因此性能也随着层数的增加而提升。

而在右图计算量对比图中,性能最完美的是ResNet-101、Inception-v4等,计算量不大且性能很好。而VGG的运算量较大、AlexNet虽然计算量较小,但性能不佳。

那么在具体代码中,卷积层是如何实现的?

如图我们想构建一个如下图所示得神经网络

首先要明确ResNet本质上是由多个基本单元堆叠实现的,写法与之前所讲的类似。

import torch
import torch.nn as nn
import torch.nn.functional as F

# 先明确ResNet是由conv1+bn+ReLU+conv2+bn+ReLU构成

class ResBlk(nn.Module):
    def __init__(self, ch_in, ch_out):
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        # 依次对kernel_size、stride、padding进行定义
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()

        if ch_out != ch_in:
            # 该处即为short cut结构,若input_channel与该单元输出channel不一致
            # 即将ch_in作为输入、ch_out作为输出
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch_out)
        )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.extra(x) + out
        return out

由此我们看出ResNet本质上是在每一层结构上都加了一个short cut。

若将该思路扩展,在中间的每一层均让其可能与之前层接触。这样就成了连接很密集的DenseNet。

如下所示

Densenet是各个channel上的累加,有时会使后面的计算量contact的很大。因此在DenseNet上的channel选择必须要非常的精妙。

本文分享自微信公众号 - python pytorch AI机器学习实践(gh_a7878fd5de90),作者:王某某搞AI

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-11-29

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • CIFAR10数据集实战-ResNet网络构建(上)

    之前讲到过,ResNet包含了短接模块(short cut)。本节主要介绍如何实现这个模块。

    用户6719124
  • CIFAR10数据集实战-ResNet网络构建(中)

    用户6719124
  • Pytorch-nn.Module

    (1)nn.Module在pytorch中是基本的复类,继承它后会很方便的使用nn.linear、nn.normalize等。

    用户6719124
  • 计算机网络原理梳理丨链路层

    香农信道编码定理:理论上可以通过编码使得数据传输过程不发生错误,或者将错误概率控制在很小的数值之下

    码脑
  • 微博订阅评论

    参考:http://open.weibo.com/wiki/%E7%A4%BA%E4%BE%8B%E4%BB%A3%E7%A0%81

    week
  • Greenplum数据库使用总结--目录部分

    小徐
  • 【OCP最新题库解析(052)--题48】When would you use memory advisors?

    该系列专题为2018年4月OCP-052考题变革后的最新题库。题库为小麦苗解答,若解答有不对之处,可留言,也可联系小麦苗进行修改。

    小麦苗DBA宝典
  • Java设计模式--单例模式

    Java高级架构
  • 8.1 VR扫描:Steam 7月数据:Rift连续6个月最大份额,Vive Pro连续四个月增长

    近日,深圳增强现实技术有限公司(0glass)宣布完成数千万B轮融资,本轮投资由清研新一代人工智能基金(珠海、嘉善)、永柏领中资本领投,第十区VRAR孵化基金跟...

    VRPinea
  • HTTP笔记_04_网络请求过程中发生了什么

    我们搭建一个本地服务,通过浏览器来访问本地服务,使用Wireshark来抓取本机127.0.0.1的网络请求数据。启动本地服务,并在浏览器中访问127.0.0....

    码农帮派

扫码关注云+社区

领取腾讯云代金券