Improving Deep Neural Networks: Study Notes (Part 1)

1. Setting up your Machine Learning Application

1.1 Train/Dev/Test sets

Make sure that the dev and test sets come from the same distribution.

Not having a test set might be okay (only a dev set).

Setting up train, dev, and test sets lets you iterate more quickly. It also lets you measure the bias and variance of your algorithm more efficiently, so you can select ways to improve it more effectively.
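As a minimal sketch of such a split (the dataset, labels, and 60/20/20 ratios here are purely illustrative assumptions; very large datasets often use much smaller dev/test fractions):

```python
import numpy as np

# Hypothetical dataset: 10,000 examples with 20 features each.
m, n_x = 10_000, 20
X = np.random.randn(m, n_x)
Y = (np.random.rand(m) > 0.5).astype(int)

# Shuffle once so all three sets come from the same distribution.
perm = np.random.permutation(m)
X, Y = X[perm], Y[perm]

# 60/20/20 split; for millions of examples, dev/test can shrink to ~1% each.
n_train, n_dev = int(0.6 * m), int(0.2 * m)
X_train, Y_train = X[:n_train], Y[:n_train]
X_dev, Y_dev = X[n_train:n_train + n_dev], Y[n_train:n_train + n_dev]
X_test, Y_test = X[n_train + n_dev:], Y[n_train + n_dev:]
```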

1.2 Bias/Variance

  • High bias: underfitting
  • High variance: overfitting

Assumption: human-level error is about 0% (the optimal/Bayes error), and the train and dev sets are drawn from the same distribution.

| Train set error | Dev set error | Result |
| --- | --- | --- |
| 1% | 11% | high variance |
| 15% | 16% | high bias |
| 15% | 30% | high bias and high variance |
| 0.5% | 1% | low bias and low variance |

1.3 Basic Recipe for Machine Learning

High bias -> bigger network, train longer, try advanced optimization algorithms, or try a different network architecture.

High variance -> get more data, try regularization, or find a more appropriate neural network architecture.
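As a toy illustration of this recipe (the diagnose helper and the 2% tolerance are made up for this sketch, not part of the course), the decision can be read as a simple function of train/dev errors:

```python
def diagnose(train_err, dev_err, bayes_err=0.0, tol=0.02):
    """Map train/dev errors to the next action in the basic recipe."""
    if train_err - bayes_err > tol:
        return "high bias: bigger network, train longer, better optimizer, new architecture"
    if dev_err - train_err > tol:
        return "high variance: more data, regularization, new architecture"
    return "done: both bias and variance look acceptable"

print(diagnose(0.15, 0.16))   # -> high bias (matches the table above)
print(diagnose(0.01, 0.11))   # -> high variance
print(diagnose(0.005, 0.01))  # -> done
```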

2. Regularizing your neural network

2.1 Regularization

In logistic regression,

w \in R^{n_x}, \quad b \in R

J(w, b) = \frac {1} {m} \sum _{i=1} ^m L(\hat y^{(i)}, y^{(i)}) + \frac {\lambda} {2m} ||w||_2^2

||w||_2^2 = \sum _{j=1} ^{n_x} w_j^2 = w^T w

This is called L2 regularization.

J(w, b) = \frac {1} {m} \sum _{i=1} ^m L(\hat y^{(i)}, y^{(i)}) + \frac {\lambda} {2m} ||w||_1

This is called L1 regularization; w will end up being sparse. \lambda is called the regularization parameter.

In neural network, the formula is

J(w^{[1]}, b^{[1]}, ..., w^{[L]}, b^{[L]}) = \frac {1} {m} \sum _{i=1} ^m L(\hat y^{(i)}, y^{(i)}) + \frac {\lambda} {2m} \sum _{l=1}^L ||w^{[l]}||^2

||w^{[l]}||^2 = \sum_{i=1}^{n^{[l-1]}} \sum _{j=1}^{n^{[l]}} (w_{ij}^{[l]})^2, \quad w^{[l]}: (n^{[l-1]}, n^{[l]})

This matrix norm is called the Frobenius norm, denoted with an F in the subscript.

L2 norm regularization is also called weight decay.
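The name comes from the gradient descent update: with the L2 term included, each step multiplies the weights by a factor slightly less than 1. Writing out the standard update (a well-known derivation, not shown in the original note):

dw^{[l]} = (\text{term from backprop}) + \frac {\lambda} {m} w^{[l]}

w^{[l]} := w^{[l]} - \alpha \, dw^{[l]} = \left(1 - \frac {\alpha \lambda} {m}\right) w^{[l]} - \alpha \, (\text{term from backprop})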

2.2 Why regularization reduces overfitting?

If \lambda is set too large, the weight matrices W are pushed reasonably close to zero, which zeroes out the impact of many hidden units. In that case the simplified network behaves like a much smaller network, taking you from overfitting toward underfitting, but there is a just-right value somewhere in the middle.

2.3 Dropout regularization

Dropout goes through each layer of the network and sets some probability of eliminating each node. By far the most common implementation of dropout today is inverted dropout.

Inverted dropout (kp stands for keep-prob). After the random mask has zeroed out some units, the surviving activations are scaled up by 1/kp before feeding the next layer:

a^{[i]} = a^{[i]} / kp

z^{[i + 1]} = w^{[i + 1]} a^{[i]} + b^{[i + 1]}

At test time, we don't use dropout or the keep-prob scaling.
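A minimal numpy sketch of inverted dropout for one layer during training (the layer shapes and keep_prob value are illustrative); dividing by keep_prob keeps the expected value of the activations unchanged, which is why nothing extra is needed at test time:

```python
import numpy as np

def inverted_dropout_forward(a_l, keep_prob):
    # Random mask: each unit is kept with probability keep_prob.
    d_l = np.random.rand(*a_l.shape) < keep_prob
    a_l = a_l * d_l          # zero out the dropped units
    a_l = a_l / keep_prob    # scale up so E[a_l] stays the same
    return a_l, d_l

# Example usage on hypothetical activations of shape (units, examples).
a3 = np.random.randn(50, 1000)
a3, d3 = inverted_dropout_forward(a3, keep_prob=0.8)
# z4 = np.dot(w4, a3) + b4   # next layer's pre-activation, as in the formula above
```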

2.4 Understanding dropout

Why does dropout work? Intuition: a unit can't rely on any one feature, so it has to spread out its weights.

Spreading out the weights tends to shrink their squared norm, similar to the effect of L2 regularization.

2.5 Other regularization methods

  • Data augmentation.
  • Early stopping.

3. Setting up your optimization problem

3.1 Normalizing inputs

Normalizing inputs can speed up training. It corresponds to two steps: first, subtract out (zero out) the mean; second, normalize the variances.
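A small numpy sketch of the two steps (the example data are made up); the same mu and sigma2 computed on the training set should be reused for the dev and test sets:

```python
import numpy as np

# Hypothetical training inputs, one example per column: shape (n_x, m).
X_train = np.random.randn(2, 1000) * np.array([[5.0], [0.1]]) + np.array([[3.0], [-1.0]])

mu = np.mean(X_train, axis=1, keepdims=True)      # step 1: zero out the mean
sigma2 = np.var(X_train, axis=1, keepdims=True)   # step 2: normalize the variance
X_train_norm = (X_train - mu) / np.sqrt(sigma2 + 1e-8)

# Reuse the *training* statistics for dev/test data:
# X_dev_norm = (X_dev - mu) / np.sqrt(sigma2 + 1e-8)
```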

3.2 Vanishing/Exploding gradients

If the network is very deep, it can suffer from the problem of vanishing or exploding gradients.

3.3 Weight initialization for deep networks

If the activation function is ReLU, a common initialization is

w^{[l]} = np.random.randn(shape) * np.sqrt(\frac {2} {n^{[l-1]}})

(often called He initialization). For tanh, np.sqrt(\frac {1} {n^{[l-1]}}) is used instead, which is called Xavier initialization.

Another formula is

w^{[l]} = np.random.randn(shape) * np.sqrt(\frac {2} {n^{[l-1]} + n^{[l]}})
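Putting these formulas into a layer-initialization loop might look like the sketch below (the layer sizes are just an example, not from the note):

```python
import numpy as np

layer_dims = [784, 128, 64, 10]   # example sizes n^[0] ... n^[L]

params = {}
for l in range(1, len(layer_dims)):
    n_prev, n_curr = layer_dims[l - 1], layer_dims[l]
    # He initialization, sqrt(2 / n^[l-1]), commonly used with ReLU:
    params["W" + str(l)] = np.random.randn(n_curr, n_prev) * np.sqrt(2.0 / n_prev)
    # Alternative variant: np.sqrt(2.0 / (n_prev + n_curr))
    params["b" + str(l)] = np.zeros((n_curr, 1))
```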

3.4 Numerical approximation of gradients

In order to build up to gradient checking, you need to numerically approximate computations of gradients.

g(\theta) \approx \frac {f(\theta + \epsilon) - f(\theta - \epsilon)} {2 \epsilon}
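A quick check of the two-sided difference against a known derivative (using f(\theta) = \theta^3, so g(\theta) = 3\theta^2; the helper name is made up):

```python
def two_sided_approx(f, theta, eps=1e-7):
    # Two-sided difference: O(eps^2) error, much better than the one-sided O(eps).
    return (f(theta + eps) - f(theta - eps)) / (2 * eps)

f = lambda t: t ** 3
theta = 1.0
print(two_sided_approx(f, theta))   # ~3.0000000
print(3 * theta ** 2)               # exact derivative: 3.0
```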

3.5 Gradient checking

Take the matrices W^{[l]} and vectors b^{[l]}, reshape them into vectors, and concatenate them; you get one giant vector \theta. For each i:

d\theta_{approx}[i] = \frac {J(\theta_1, ..., \theta_i + \epsilon, ...) - J(\theta_1, ..., \theta_i - \epsilon, ...)} {2\epsilon} \approx d\theta_i = \frac {\partial J} {\partial \theta_i}

If

\frac {||d\theta_{approx} - d\theta||_2} {||d\theta_{approx}||_2 + ||d\theta||_2} \approx 10^{-7}

that's great. If it is around 10^{-5}, double check the components; if it is around 10^{-3}, there may well be a bug.
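A sketch of computing that ratio with numpy (the vectors here are a toy example; in practice dtheta_approx comes from the loop above and dtheta from backprop):

```python
import numpy as np

def grad_check_ratio(dtheta_approx, dtheta):
    # Relative difference between the numerical and backprop gradients.
    numerator = np.linalg.norm(dtheta_approx - dtheta)
    denominator = np.linalg.norm(dtheta_approx) + np.linalg.norm(dtheta)
    return numerator / denominator

# Toy example: a backprop gradient vs. a slightly perturbed numerical estimate.
dtheta = np.array([0.3, -1.2, 0.7])
dtheta_approx = dtheta + 1e-9
print(grad_check_ratio(dtheta_approx, dtheta))
# Tiny here; in practice: ~1e-7 great, ~1e-5 double check, ~1e-3 likely a bug.
```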

3.6 Gradient checking implementation notes

  • Don’t use gradient check in training, only to debug.
  • If algorithm fails gradient check, look at components to try to identify bug.
  • Remember regularization.
  • Doesn’t work with dropout.
  • Run at random initialization; perhaps again after some training.
