首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

随机梯度下降法和牛顿法的理论以及Python实现

梯度:对应一个可微函数(以二元函数为例)f(x,y),

梯度的几何意义:梯度的方向是函数增长最快的方向

梯度下降法:对函数做一阶逼近寻找函数下降最快的方向

牛顿法:对函数做二阶逼近,并找到函数的极小值点

梯度下降法的困难

1在机器学习和统计参数估计问题时,目标函数是求和函数

当样本量极大时,梯度的计算耗时耗力

2 学习率的选择

过小导致收敛太慢,过大容易发散

为了解决第一个问题,工程师们提出了随机梯度下降的算法

梯度下降法分为三种类型:

1. 批梯度下降法(GD)

原始的梯度下降法

2. 随机梯度下降法(SGD)

每次梯度计算只使用一个样本

• 避免在类似样本上计算梯度造成的冗余计算

• 增加了跳出当前的局部最小值的潜力

• 在逐渐缩小学习率的情况下,有与批梯度下降法类似的收敛速度

3. 小批量随机梯度下降法(Mini Batch SGD)

每次梯度计算使用一个小批量样本

• 梯度计算比单样本更加稳定

• 可以很好的利用现成的高度优化的矩阵运算工具

我们现在经常用到的是小批量随机梯度下降法

例子(随机梯度下降算法,基于神经网络的分类问题)

import matplotlib.pyplot as plt

import numpy as np

import argparse

def sigmoid_activation(x):

return 1.0 / (1 + np.exp(-x))

def next_batch(X, y, batchSize):

for i in np.arange(0, X.shape[0], batchSize):

#生成器,节省内存空间

yield (X[i:i + batchSize], y[i:i + batchSize])

#初始化训练的三个参数:迭代次数,学习率,批次大小,常用的有:32,64,128,256

ap = argparse.ArgumentParser()

ap.add_argument("-e", "--epochs", type=float, default=100,

help="# of epochs")

ap.add_argument("-a", "--alpha", type=float, default=0.01,

help="learning rate")

ap.add_argument("-b", "--batch-size", type=int, default=32,

help="size of SGD mini-batches")

#变成一个字典

args = vars(ap.parse_args())

#随机生成一些样本

#n_samples:样本数,n_features:样本的特征数,centers:中心数,cluster_std:每个类别的方差

(X, y) = make_blobs(n_samples=400, n_features=2, centers=2,

cluster_std=2.5, random_state=95)

X = np.c_[np.ones((X.shape[0])), X]

lossHistory = []

for epoch in np.arange(0, args["epochs"]):

epochLoss = []

for (batchX, batchY) in next_batch(X, y, args["batch_size"]):

preds = sigmoid_activation(batchX.dot(W))

error = preds - batchY

loss = np.sum(error ** 2)

epochLoss.append(loss)

gradient = batchX.T.dot(error) / batchX.shape[0]

W += -args["alpha"] * gradient

lossHistory.append(np.average(epochLoss))

Y = (-W[0] - (W[1] * X)) / W[2]

plt.figure()

plt.scatter(X[:, 1], X[:, 2], marker="o", c=y)

plt.plot(X, Y, "r-")

fig = plt.figure()

plt.plot(np.arange(0, args["epochs"]), lossHistory)

fig.suptitle("Training Loss")

plt.xlabel("Epoch #")

plt.ylabel("Loss")

plt.show()

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20180222G13VW600?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券