首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >在Keras中的SGD优化不会垂直于水平曲线?

在Keras中的SGD优化不会垂直于水平曲线?
EN

Stack Overflow用户
提问于 2017-09-04 13:42:48
回答 4查看 639关注 0票数 3

我正在使用Keras进行线性回归。我的数据集由50个一维输入点和50个1D输出点组成。为了进行线性回归,我在训练一个单层和一个神经元的神经网络,没有激活函数。神经网络被定义为

代码语言:javascript
运行
复制
model = Sequential()
model.add(Dense(1, input_dim=1, kernel_initializer='zeros', 
bias_initializer='zeros'))

并要求Keras以SGD为优化器,均方误差为损失函数,求出w和b的最优值。

代码语言:javascript
运行
复制
model.compile(loss='mean_squared_error', optimizer=SGD(lr=0.01))
model.fit(x,y,epochs=100, callbacks=[history], verbose=0, batch_size=50);

其中,history是我创建的回调,用于在优化的每一步保存当前权重和偏差。

然后,我继续绘制损失函数的水平曲线,以及w x b空间中的优化轨迹。输出如下。

优化轨迹用红色圆表示,全局最优表示为蓝色'x‘。这似乎是合理的,因为我们从[0,0]开始,每次迭代之后我们都接近全局最优。最终,梯度变得如此之小,以至于我们停止了改进。

,但是,,我理解,通过使用梯度下降,在当前点(即垂直于水平曲线)的梯度方向总是移动的。这个优化轨道看起来并不是这样的。Keras SGD优化器是在做其他事情吗?还是我漏掉了什么?

编辑:虽然图表似乎说明水平曲线是平行线,但它们实际上是椭球,但非常拉长。选择一个不同的范围来绘制它们就会发现这一点。

编辑2:为了避免与如何绘制这个问题中所显示的图像有关的任何混淆,我现在创建了一个有代码的要点

EN

回答 4

Stack Overflow用户

回答已采纳

发布于 2017-09-04 16:36:04

它是正交的(0.2比-5斜率),但是你的图的x/y单位不一样。在给定的方向上缩放并不能保持正交性。

票数 1
EN

Stack Overflow用户

发布于 2017-09-04 13:55:08

您必须记住,您使用的是一个SGD,它是一个Stochastic Gradient Descent。在下面的图像中可以看到使用SGD与香草GD相比所获得的轨迹差异的可视化:

(来源)

您可以看到,SGD轨迹不是垂直于水平线,而是不同的移动方式。也许这已经解释了你的轨迹的形式。

票数 2
EN

Stack Overflow用户

发布于 2017-09-04 13:55:07

首先,您应该认识到,由于您没有使用激活函数,您的神经网络只能表示线性系统(相当于矩阵乘法)。非线性激活函数带来了神经网络的表征能力。

实际上,您没有执行线性回归。如果您想这样做,例如2次多项式,您应该添加平方参数作为输入。由于http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.PolynomialFeatures.html,Scikit提供了这种转换。

让我们假设您有一个由两个输入x和y组成的函数,执行线性回归,就像使用x, x^2, xy, y, y^2和一个输出神经元的输入层一样。

编辑:然而,在(w,b)空间中,您实际上应该能够达到全局最小值。然而,对于收敛的速度还没有结果。如果你看看你的损失函数,你会注意到它在一个方向上被拉伸了很多:这相当于说Hessian矩阵有两个不同大小的特征值。这意味着你将能够快速地在一个方向(最大的价值之一)学习,但在另一个方向上慢慢地学习。

在神经网络优化中,计算Hessian矩阵是不可能的,因为它需要在每一步进行大量的计算。然而,一些学习模式可以避免马鞍点和糟糕的条件(如你的)优化问题。SGD的性能一般很差,而且已经很少被使用了。看看http://ruder.io/optimizing-gradient-descent/,知道所有这些优化器都包含在Keras中。对于你们来说,我首先要尝试增加动量来提高收敛的速度,就像你们说的,如果你等的时间足够长,它实际上可以收敛。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/46038386

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档