梯度下降法是一个一阶最优化算法,通常也称为最陡下降法,要使用梯度下降法找到一个函数的局部极小值,必须向函数上当前点对应梯度的反方向的规定步长距离点进行迭代搜索。如果相反地向梯度正方向迭代进行搜索,则会接近函数的局部极大值点;这个过程则被称为梯度上升法。
介绍梯度下降法之前首先先介绍一下梯度。梯度的本意是一个向量(矢量),表示某一函数在该点处的方向导数沿着该方向取得最大值,即函数在该点处沿着该方向(此梯度的方向)变化最快,变化率最大(为该梯度的模)。
首先介绍一元函数:
梯度下降法的公式:
其中
是学习率(learn rate),
就是梯度。
接下来我们来介绍一下这个公式是怎么来的
首先从泰勒展开式入手
一阶泰勒展开式为
现在我们令
其中
为步进长度,是标量。而
表示的是
方向的单位向量。现在
可以表示成
移项后得到
因为梯度下降法每次更新
希望
可以减小,那么得到的是
。其中
和
都是矢量。现在我们来讨论如何保证
最小。
此时我们需要引入向量的知识, 看下面的三幅图
图1
图2
图3
可以看到图1的锐角和图2的钝角都不是最小的结果,只有图3的
时
使得
有最小值,即沿着导数
的反方向,函数下降的最快。
到此为止,我们分析了为什么导数沿着函数的反方向会下降最快,现在在把公式整理一下。
因为
是单位向量,且方向是沿着
的反方向,因此
,现在我们将
代入到
中得:
,我们将
都令为
因此可以写成
,不断迭代此公式即
。其中
被称为learn rate学习率
上述就是梯度下降法的推导过程。
求此函数
实现代码
#实现f = x**2+2*x - 0.2的最小值
import numpy as np
#定义fx
def f(x):
f = x**2+2*x - 0.2
return f
#定义导数
def df(x):
df = 2*x+2
return df
# f:函数
# df:导数
# x:zip变量
# err:误差,用来判断是否满足求解要求
def gradient_discent(f, df, x, err, learn_rate):
loop = 1
x_i = float(x)
e_tmp = err + 1
while e_tmp > err:
#超出循环次数
if loop > 1000:
print('cycles exceeded')
break
print('######loop'+str(loop))
f_tmp = f(x_i)
df_tmp = df(x_i)
print('xi = ' + str(x_i) + ',f = ' + str(f_tmp) + ',df = ' + str(df_tmp))
x_i = x_i - learn_rate*df_tmp
e_tmp = abs(f(x_i) - f_tmp)
print('err = ' + str(e_tmp))
loop = loop + 1
return x_i
x = gradient_discent(f, df, 3, 0.00000001, 0.01)
print(x)
多元函数的代码
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
def Fun(x,y):#原函数
return x-y+2*x*x+2*x*y+y*y
def PxFun(x,y):#偏x导
return 1+4*x+2*y
def PyFun(x,y):#偏y导
return -1+2*x+2*y
#初始化
fig=plt.figure()#figure对象
ax=Axes3D(fig)#Axes3D对象
X,Y=np.mgrid[-2:2:40j,-2:2:40j]#取样并作满射联合
Z=Fun(X,Y)#取样点Z坐标打表
ax.plot_surface(X,Y,Z,rstride=1,cstride=1,cmap="rainbow")
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
#梯度下降
step=0.0008#下降系数
x=0
y=0#初始选取一个点
tag_x=[x]
tag_y=[y]
tag_z=[Fun(x,y)]#三个坐标分别打入表中,该表用于绘制点
new_x=x
new_y=y
Over=False
while Over==False:
new_x-=step*PxFun(x,y)
new_y-=step*PyFun(x,y)#分别作梯度下降
if Fun(x,y)-Fun(new_x,new_y)<7e-9:#精度
Over=True
x=new_x
y=new_y#更新旧点
tag_x.append(x)
tag_y.append(y)
tag_z.append(Fun(x,y))#新点三个坐标打入表中
#绘制点/输出坐标
ax.plot(tag_x,tag_y,tag_z,'r.')
plt.title('(x,y)~('+str(x)+","+str(y)+')')
plt.show()
Caption