首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

cs231n之Assignment2全连接网络上

cs231n之Assignment2全连接网络上

——光城

0.说在前面

在上次作业中,已经实现了两层神经网络,但是有些问题,比如程序不够模块化,耦合度不高等问题,所以本节引出神经网络的层与层结构。本节主要实现一种模块化的神经网络架构,将各个功能封装为一个对象,包括全连接层对象,仿射层,Relu层等,在各层对象的前向传播函数中,将由上一层传来的数据和本层的相关参数,经过本层的激活函数,生成输出值,并将在后面反向传播需要的额外参数,进行缓存处理,将根据后面层次的提取与缓存值计算本层各参数的梯度,从而实现反向传播。

1.仿射层

仿射层前向传播

目标:

- 计算实现一个仿射层的前向传播

输入:

- x: (N, d_1, ..., d_k)

- w: (D, M)

- b: (M,)

返回:

- out: (N, M)

- cache: (x, w, b)

实现

仿射层反向传播

目标:

计算仿射层的后向传播

输入:

- dout: (N, M)

- cache:

x: (N, d_1, ... d_k)

w: (D, M)

b: (M,)

返回:

- dx: (N, d1, ..., d_k)

- dw: (D, M)

- db: (M,)

实现:

首先获得上面前向传播的输出值与cache,紧接着计算反向传播。

cache解开得到前面仿射层的前向传播参数,接着计算梯度即可!

实现

2.RELU层

Relu层前向传播

目标:

计算Relu的前向传播

输入:

- x: 任意shape的输入

返回:

- out: 输出同x一样的shape

- cache: x

实现:

上面目标很明确,这里直接来实现,不多解释,这里用到了一个布尔矩阵运算,如果觉得疑惑,请看作业详解knn中的解释!

实现

Relu层后向传播

目标:

计算Relu的后向传播

输入:

- dout: 任何shape的前向输出(这里疑惑的看英文原文)

- cache:同dout相同shape的x

返回:

- dx: x的梯度

实现:

Relu只有矩阵中大于0的数有效,所以x>0筛选得出一个布尔矩阵,直接相乘就是最后的结果。因为如果x0的结果!

实现

3.两层组合

组合前向传播

目标:

完成仿射层与Relu层组合

输入:

- x: 仿射层的输入

- w, b: 仿射层的权重

返回:

- out: ReLU层输出

- cache: 后向传播的缓存

实现

组合反向传播

目标:

实现反向传播

输入:

- dout

- cache

返回:

- dx: x梯度

- dw: w梯度

- db: b梯度

实现:

直接调用刚才的方法。

4.两层神经网络

类封装

目标:

实现affine - relu - affine - softmax架构

输入:

- input_dim: 输入层尺寸

- hidden_dim: 隐藏层尺寸

- num_classes: 类别数

- dropout: 随机失活强度 0~1

- weight_scale: 权重范围

- reg: 正规化

实现:

封装全局参数

实现

损失函数

输入:

- X: (N, d_1, ..., d_k)

- y: (N,)

返回:

If y is None

test-time forward

- scores: (N, C)

If y is not None,

training-time forward

backward pass

- loss

- grads

实现

5.Solver训练

概要

使用这个训练之前,需要补充optim.py!

此文件实现了常用的各种一阶更新规则用于训练神经网络。每个更新规则接受当前权重和相对于那些权重的损失梯度并产生下一组权重!

SGD

SGD公式

这个被称为朴素sgd(Vanilla SGD)

公式中四个参数分别对应为:下一次的权重w,当前权重w,学习率,当前权重的梯度!

实现

Momentum

Momentum公式

这个被称为结合动量的sgd(最常用)。

阿尔法代表学习率!

实现

RMSProp

RMSProp公式

实现

Adam

Adam公式

实现

训练

实现

训练结果

6.作者的话

最后,如果您觉得本公众号对您有帮助,欢迎关注及转发,有关更多内容请关注本公众号深度学习系列!

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20181130G00F9E00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券