前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >卷积_ResNet

卷积_ResNet

作者头像
火星娃统计
发布2021-11-02 15:02:31
2990
发布2021-11-02 15:02:31
举报
文章被收录于专栏:火星娃统计

Resnet

概述

刚才边写这个,别跑程序,偏偏没有选择自动保存,因此没得了,一个字也没有给我留下来,消耗了我所有的耐心。

因此跑程序的时候,记得保存,毕竟这破电脑什么水平自己知道 残差神经网络(ResNet)是由微软研究院的何恺明等人提出,当年的预测准确率很高

理论

就是说随着卷积的深度增加,会导致梯度爆炸和梯度消失,因此增加深度不会提高预测的准确性

image-20211025193250594

  • 为了避免梯度消失,加入的跳连的线,用于保留信息
  • 两种跳连方式,一种实线的,一种是虚线的,虚线的代表下采样,步长为2
  • 同时还有另一种类型的残差块,主要用于深度较深的网络

image-20211025193752889

image-20211025194043888

代码

实现resnet18

代码语言:javascript
复制
import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
from tensorflow.keras import Model

np.set_printoptions(threshold=np.inf)
# cifar10共有60000张彩色图像,这些图像是32*32,分为10个类,每类6000张图。
# 这里面有50000张用于训练,另外10000用于测试,单独构成一批。
# 主要是鸟啊,飞机啊,猫啊之类的东西,就是10个类吧
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# 将像素点规整为0-1的数值
x_train, x_test = x_train / 255.0, x_test / 255.0

# resnet结构块
# 第一个卷积为3*3 步长不定
# 第二个卷积为3*3 步长为1
# 最后按照是否下采样进行判定添加下采样的跳连
# 最后过激活函数relu
class ResnetBlock(Model):

    def __init__(self, filters, strides=1, residual_path=False):
        super(ResnetBlock, self).__init__()
        self.filters = filters
        self.strides = strides
        self.residual_path = residual_path # 是否跳连下采样
        # 3*3的卷积 填充,使用偏倚
        self.c1 = Conv2D(filters, (3, 3), strides=strides, padding='same', use_bias=False)
        self.b1 = BatchNormalization() # 批标准化,各个卷积结构进行一个标准化运行,比较简单理解
        self.a1 = Activation('relu')# 激活函数

        self.c2 = Conv2D(filters, (3, 3), strides=1, padding='same', use_bias=False)
        self.b2 = BatchNormalization()

        # residual_path为True时,对输入进行下采样,即用1x1的卷积核做卷积操作,保证x能和F(x)维度相同,顺利相加
        if residual_path:
            self.down_c1 = Conv2D(filters, (1, 1), strides=strides, padding='same', use_bias=False)
            self.down_b1 = BatchNormalization()

        self.a2 = Activation('relu')

    def call(self, inputs):
        residual = inputs  # residual等于输入值本身,即residual=x
        # 将输入通过卷积、BN层、激活层,计算F(x)
        x = self.c1(inputs)
        x = self.b1(x)
        x = self.a1(x)

        x = self.c2(x)
        y = self.b2(x)

        if self.residual_path:
            residual = self.down_c1(inputs)
            residual = self.down_b1(residual)
        # 如果residual_path为ture那么上述的residual为下采样之后的,如果为false,那么为原始矩阵相加
        out = self.a2(y + residual)  # 最后输出的是两部分的和,即F(x)+x或F(x)+Wx,再过激活函数
        return out

# resnet18的搭建
class ResNet18(Model):

    def __init__(self, block_list, initial_filters=64):  # block_list表示每个block有几个卷积层
        super(ResNet18, self).__init__()
        self.num_blocks = len(block_list)  # 共有几个block
        self.block_list = block_list
        self.out_filters = initial_filters
        # 第一层卷积 这里为3*3 ,实际上是7*7
        self.c1 = Conv2D(self.out_filters, (3, 3), strides=1, padding='same', use_bias=False)
        self.b1 = BatchNormalization()
        self.a1 = Activation('relu')
        self.blocks = tf.keras.models.Sequential()# 使用Sequential方式添加层
        # 构建ResNet网络结构,使用一个循环进行生成
        for block_id in range(len(block_list)):  # 第几个resnet block,这里为4
            for layer_id in range(block_list[block_id]):  # 第几个卷积层

                if block_id != 0 and layer_id == 0:  # 对除第一个block以外的每个block的输入进行下采样
                    block = ResnetBlock(self.out_filters, strides=2, residual_path=True)
                else:
                    block = ResnetBlock(self.out_filters, residual_path=False)
                self.blocks.add(block)  # 将构建好的block加入resnet
            self.out_filters *= 2  # 下一个block的卷积核数是上一个block的2倍
        self.p1 = tf.keras.layers.GlobalAveragePooling2D() # 平均池化
        # 全连接过softmax函数
        self.f1 = tf.keras.layers.Dense(10, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())

    def call(self, inputs):
        x = self.c1(inputs)
        x = self.b1(x)
        x = self.a1(x)
        x = self.blocks(x) #这里会生成16层
        x = self.p1(x)
        y = self.f1(x)
        return y


model = ResNet18([2, 2, 2, 2])#[2, 2, 2, 2]为块的list,共4个大块,每个大块里有2个小块,1个小块有2层卷积
# 优化器,损失函数 测量指标
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])
# 断电续训
# 为上次模型的结果,目的是为了提高模型收敛的速度
checkpoint_save_path = "./checkpoint/ResNet18.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)
# cp_callback 保存模型的结果,也就是各个卷积核、w的参数
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)
# 模型的训练,批次,循环,验证等内容
history = model.fit(x_train, y_train, batch_size=128, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()
# 保存模型的参数,各个w和卷积核的数值
# print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

###############################################    show   ###############################################

# 显示训练集和验证集的acc和loss曲线
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

结束语

一直以为我的台式机是1050ti,后来发现才是fake的,所以电脑死机黑屏是常有的事儿,

没有gpu就别跑了,都是血泪,各种丢

love &peace

佛祖保运

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-10-25,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 火星娃统计 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • Resnet
    • 概述
      • 理论
        • 代码
          • 结束语
          领券
          问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档