Java教程

梯段下降算法

本文主要是介绍梯段下降算法,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

1 线性回归

1.1 引子

引子-示例

比如上面这个图,可以感觉到是存在这样一条直线L:

  • (1)这条直线尽可能反映出数据点的整体走向、趋势
  • (2)给定x,代入这条直线中求解出来的y,我们称之为预测值ypredict;该x实际的取值y,我们称之为真实值。易知,图中每个点的x代入直线L求解出的ypredict,与该点实际的y值之间存在差距,并且所有点的差距总和会最小(或者说差距的均值最小)

因为上面的例子是二维的,我们可以去设这条直线:\(f(x)=kx\),简单起见,我们暂时先不考虑b
有了方程后,我们就可以表示(2)中的内容了:

\[CostFunction =\frac{1}{2n}\sum_{i=1}^{n}\left ( f(x_{i})-y_{i}\right )^{2}=\frac{1}{2n}\sum_{i=1}^{n}\left ( kx_{i}-y_{i}\right )^{2}① \]

须知:这里的\(\frac{1}{2n}\)不是必须的

由①式可以知晓,CostFunction是k的二次函数,通过给定不同的k,CostFunction会得到不同的取值,如下图:
CostFunctionValue-k

如果你要考虑b的话,可以得到下图,如同一张弯曲的纸,或者峡谷,存在最低点

image

1.2 梯度下降

易知,沿着梯度(导数)的方向,函数值增加最快,换个角度,沿着梯度(导数)的反方向,函数值降低最快。
我们可以令CostFunction对k求导,得:

\[{CostFunction}'_{k}=\frac{1}{2n}\sum_{i=1}^{n}2x_{i}(kx_{i}-y_{i})=\frac{1}{n}\sum_{i=1}^{n}x_{i}(kx_{i}-y_{i})② \]

在更新k时,我们可以:

\[k=k-②=k-\frac{1}{n}\sum_{i=1}^{n}x_{i}(kx_{i}-y_{i})③ \]

通过③可知,每次k的变化的步长是②,但实际操作而言,可能会存在一定的问题,假设当k=k0时函数值取得局部极小值。若k每次变化的步长很大,可能会导致k无法达到k0的位置,可能就是在k0这个点左右来回震荡,无法收敛;或者经过数次更新后会离k0越来越远。
所以,为了避免这种情况,我们会增加一个参数\(\alpha\)对③进行修改:\(k=k-\alpha\frac{1}{n}\sum_{i=1}^{n}x_{i}(kx_{i}-y_{i})\)
这个\(\alpha\),我们称之为学习率Learning Rate

  • 学习率越大,k每次更新的步长就越大,出现无法取得局部极小值点、无法收敛的概率也会增大
  • 学习率越小,k每次更新的步长就越小,达到局部极小值点的时间也会变长(至少不会错过)

须知,上述对k0的描述,使用的是“局部极小值”,而非“全局极小值”。
梯度下降法得到的是“局部最小值”而非“全局最小值”
对于某个山路十八弯的函数,在进行梯度下降时,很容易陷入局部最小值,而无法达到全局极小值
对于凸函数而言(比如上面的开口向上的抛物线),局部极小,即为全局极小

image

从上图可以看到

  • 当lr=0.01时,k每次变化的幅度较小,所以整体曲线较为平滑
  • 当lr=0.1时,k的变化波动比较大,在极值点左右反复横跳,最后落到极值点
  • 当lr=0.05时,可以看到到达极小值点的速度比较快,那么相对于lr=0.01与lr=0.1而言,0.05会更好

2 未完待续

这篇关于梯段下降算法的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!