专栏首页光城(guangcity)Softmax及两层神经网络

Softmax及两层神经网络

Softmax及两层神经网络

0.说在前面1.Softmax向量化1.1 Softmax梯度推导1.2 Softmax向量化实现2.两层神经网络2.1 反向传播推导2.2 两层神经网络实现3.作者的话

0.说在前面

今天是cs231n Assignment1的最后一块,也就是继上次的softmax及两层神经网络!今天在学习神经网络反向传播的时候,觉得很有意思,就仔细琢磨了一下,结果很有帮助,对于矩阵的求导有了更深的认识,下面给出手推神经网络反向传播的求导以及softmax向量化推导及实现!Assignment2等后续内容,正在撰写中,一起来期待!下面一起来研究吧。 下期预告,链表一道题多种解法!

1.Softmax向量化

1.1 Softmax梯度推导

首先来给出Loss的公式

data loss+regularization!

推导:

X矩阵是(N,D),W矩阵是(D,C),S矩阵是(N,C),S矩阵中每一行是Li,那么XW=S表示如下公式(1)所示:

L对W求导,最后的矩阵维度为W的维度,那么L对W求导维度为(D,C),而L对S的求导维度为(N,C),S对W的求导维度为(N,D)或者(D,N),根据维度相容来选择,如果X与W均是一维的那么就是X,否则就是X转置!下面的式子记作(2)式:

X转置后维度为(D,N),而L对S求导的维度为(N,C),此时可以相乘,否则不能乘!

L对Si求导,我们知道L1只与S1有关,推出Li只与Si有关!下面的式子记作(3)式:

紧接着,我们将Li对Si求导拆分成对q求导,在由q对S求导,这里的推论结果,直接使用上次推出的结果,带入就是下面的额式子(记作(4)式):

完成(2)式得,记作(5)式:

1.2 Softmax向量化实现

具体实现的流程解释看代码注释!

def softmax_loss_vectorized(W, X, y, reg):
    loss = 0.0
    dW = np.zeros_like(W)
    num_train = X.shape[0]
    num_class = W.shape[1]
    scores = X.dot(W)  # N*C
    # np.max后会变成一维,可设置keepdims=True变为二维(N,1)
    # 防止指数爆炸
    scores-=np.max(scores,axis=1,keepdims=True)
    # 取指数
    scores=np.exp(scores)
    # 计算softmax
    scores/=np.sum(scores,axis=1,keepdims=True)
    # ds表示L对S求导
    ds = np.copy(scores)
    # qiyi - 1
    ds[np.arange(num_train), y] -= 1
    dW = np.dot(X.T, ds)
    loss = scores[np.arange(num_train), y]
    # 计算Li
    loss =-np.log(loss).sum()
    # 计算所有loss除以N
    loss /= num_train
    # ds矩阵没有除以N,所以在这里补上,最后除以N,具体看(5)式
    dW /= num_train
    # 计算最终的大L
    loss += reg * np.sum(W * W)
    dW += 2 * reg * W
    return loss, dW

2.两层神经网络

2.1 反向传播推导

2.2 两层神经网络实现

计算前向传播

前向传播可以看上面手推图结构!

scores = None
s1 = np.dot(X, W1) + b1
# (N,H)
s1_relu = (s1 > 0) * s1
scores = np.dot(s1_relu, W2) + b2
if y is None:
    return scores

计算损失函数

这里计算损失与softmax一致,可以参看上面的!

# Compute the loss
loss = None
# 防止指数爆炸
scores -= np.max(scores, axis=1, keepdims=True)
# 取指数
scores = np.exp(scores)
# 计算softmax
scores /= np.sum(scores, axis=1, keepdims=True)
loss = -np.log(scores[np.arange(N), y]).sum()
loss /= N
loss += reg * np.sum(W1 * W1)
loss += reg * np.sum(W2 * W2)

计算反向传播

具体推导看上面手推图!

这里将上面的关键点提出来,ds2表示的是dl对ds2求导,ds1表示dl对ds1求导!其余的一致!

grads = {}
ds2 = np.copy(scores)
# qiyi - 1
ds2[np.arange(N), y] -= 1
grads['W2'] = np.dot(s1_relu.T, ds2) / N + 2 * reg * W2
# b2的shape=(N,C)广播机制
# (1,C)
# 这里除以N是因为ds的时候没有除以N,所以最后就得除以N,后面相同!
grads['b2'] = np.sum(ds2, axis=0) / N
ds1 = np.dot(ds2, W2.T)
# relu函数
ds1 = (s1 > 0) * ds1
grads['W1'] = np.dot(X.T, ds1) / N + 2 * reg * W1
grads['b1'] = np.sum(ds1, axis=0) / N

随机选择数据集batch_size大小

train方法中添加:

num_random = np.random.choice(np.arange(num_train), batch_size)
X_batch = X[num_random, :]
y_batch = y[num_random]

计算损失与梯度

train方法中添加:

loss, grads = self.loss(X_batch, y=y_batch, reg=reg)
loss_history.append(loss)

更新w与b

train方法中添加:

self.params['W1'] -= learning_rate * grads['W1']
self.params['W2'] -= learning_rate * grads['W2']
self.params['b1'] -= learning_rate * grads['b1']
self.params['b2'] -= learning_rate * grads['b2']

预测结果

output = np.maximum(X.dot(self.params['W1']) + self.params['b1'], 0).dot(self.params['W2'])+self.params['b2']
y_pred = np.argmax(output, axis=1)

本文分享自微信公众号 - 光城(guangcity),作者:lightcity

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2018-11-24

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • SVM梯度求导及实现

    昨晚看了一部电影,叫做我是马布里,非常正能量,推荐给各位,看完这部电影的总结话是:

    公众号guangcity
  • cs231n之KNN、SVM

    最近在学习cs231n,觉得有点困难,今天抽了一晚上时间来写这篇文章,作为总结。下面一起来看任务一的题目,由于篇幅长,故分成两部分,下节重点softmax!

    公众号guangcity
  • matlibplot绘制各种图形

    0.导语1.预备知识1.1 np.arange()1.2 numpy.random.uniform()1.3 zip()2.bar绘制3.散点图4.3D图5.参...

    公众号guangcity
  • 机器学习(二十一) 异常检测算法之IsolationForest

    IsolationForest指孤立森林,是一种高效的异常检测算法。在所有样本数据中,异常数据具有数量少并且与大多数数据不同的特点,利用这一特性分割样本,那些异...

    致Great
  • 测试开发:校招面试实录

    本人女,非985,211硕士,找的工作都是和测试开发相关的。因为一些原因没有参加过校招,只有17年8月份的实习经历和9月份百度云的提前批经历,以及12月底的几次...

    牛客网
  • SpringBoot开发案例之Nacos注册中心管理

    在之前的 Dubbo 服务开发中,我们一般使用 Zookeeper 作为注册中心,同时还需要部署 Dubbo 监控中心和管理后台。

    小柒2012
  • 基于UDP/IP协议的电口通信(三)

    有些生命自然而来的缘份,是约定俗成好了的。无力改变。只能精心的筹划痴心的遥望耐心的守候动心的注目。

    碎碎思
  • Springboot 系列(九)使用 Spring JDBC 和 Druid 数据源监控

    作为一名 Java 开发者,相信对 JDBC(Java Data Base Connectivity)是不会陌生的,JDBC作为 Java 基础内容,它提供了一...

    未读代码
  • 一种海量日志存储、分析解决方案V1.1 原

    针对上一个版本https://my.oschina.net/shyloveliyi/blog/786337,有如下更新:

    尚浩宇
  • 云+社区技术沙龙:音视频技术开发实战 报名开启

    近年来,随着移动互联网的普及和智能终端设备的广泛应用,短视频、直播、在线教学等各类形式的音视频形式的应用越来越广泛,然而,音视频技术使用起来虽然便捷,但是在技术...

    云加社区技术沙龙

扫码关注云+社区

领取腾讯云代金券