前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >用 Java 实现梯度下降

用 Java 实现梯度下降

作者头像
用户5224393
修改2020-06-01 17:24:02
1.4K0
修改2020-06-01 17:24:02
举报
文章被收录于专栏:Java研发军团Java研发军团

来自作者投稿  作者:覃佑桦

www.baeldung.com/java-gradient-descent

1.引言

文本会学习梯度下降算法。我们将分步对算法实现过程进行说明并用Java实现。

2.什么是梯度下降?

梯度下降是一种优化算法,用于查找给定函数的局部最小值。它被广泛用于高级机器学习算法中,最小化损失函数。

梯度(gradient)是坡度(slope)的另一种表达,下降(descent)表示降低。顾名思义,梯度下降随着函数的斜率下降直到抵达终点。

3.梯度下降特性

梯度下降可找到局部最小值,该局部最小值有可能与全局最小值不同。起始局部点会作为算法的一个参数给出。

这是一种迭代算法。每一步都会尝试沿斜率向下移动并接近局部最小值。

实践中,算法采用的是回溯(backtrack)。接下来我们将采用回溯实现梯度下降。

4.分步说明

梯度下降需要一个函数和一个起点作为输入。让我们定义并绘制一个函数:

可以从任何期望的点开始。让我们从 x=1 开始:

第一步,梯度下降以预定的步长沿斜率下降:

接下来以相同的步长继续前进。但是,这次结束时的y 值比上次大:

这就表明算法已超过了局部最小值,因此用较小的步长后退:

随后,只要当前y 大于前一次 y,就会减小步长并取反。迭代会一直进行直到满足所需的精度。

如我们看到的那样,梯度下降在这里处找到了局部最小值,但不是全局最小值。如果我们从 x=-1 而非 x=1 开始,则能找到全局最小值。

5.Java实现

有几种方法能够实现梯度下降。这里没有采用计算函数的导数来确定斜率的方向,因此我们的实现也适用于不可微函数。

定义 precision 和 stepCoefficient 并给它赋上初值:

代码语言:javascript
复制
double precision = 0.000001;
double stepCoefficient = 0.1;

进行第一步时,没有之前的 y 作比较。我们可以增加或减少 x 值确认 y 值是减少或增加。stepCoefficient 为正数表明正在增加 x 值。

现在让我们执行第一步:

代码语言:javascript
复制
double previousX = initialX;
double previousY = f.apply(previousX);
currentX += stepCoefficient * previousY;

上面的代码中,f 是 Function<Double, Double>,initialX 的类型是 double,二者都作为输入。

另一个需要考虑的关键点,梯度下降并不保证收敛。为了避免陷入死循环,需要限制迭代次数:

代码语言:javascript
复制
int iter = 100;

每次迭代都把 iter 减1。因此,最多循环100次。

现在有了一个 previousX,我们可以设置循环了:

代码语言:javascript
复制
while (previousStep > precision && iter > 0) {
    iter--;
    double currentY = f.apply(currentX);
    if (currentY > previousY) {
        stepCoefficient = -stepCoefficient/2;
    }
    previousX = currentX;
    currentX += stepCoefficient * previousY;
    previousY = currentY;
    previousStep = StrictMath.abs(currentX - previousX);
}

每次迭代,我们都会计算新的 y 值并将其与之前的 y 比较。如果 currentY 大于 previousY,将改变方向并减小步长。

循环会一直进行直到步长小于期望的precision 为止。最后,返回 currentX 作为本地最小值:

代码语言:javascript
复制
return currentX;

6.总结

本文分步骤介绍了梯度下降算法。

还用Java对算法进行了实现,完整源代码可以从 GitHub 下载。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-05-25,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Java研发军团 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档