基于前文对梯度下降法的理解,本文用 Python 实现梯度下降求解。不过本文的实现不具有通用性:求导是针对具体函数手工完成的,梯度计算也未考虑很多因素。从运行结果可以看到:学习率很低,则收敛较慢,需要更多的迭代次数;学习率较高,则收敛很快,但学习率过高时有可能越过极小值而找不到它。本文只是浅尝辄止。
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# _ooOoo_
# o8888888o
# 88" . "88
# ( | - _ - | )
# O\ = /O
# ____/`---'\____
# .' \\| |// `.
# / \\|||:|||// \
# / _|||||-:- |||||- \
# | | \\\ - /// | |
# | \_| ''\---/'' | _/ |
# \ .-\__ `-` ___/-. /
# ___`. .' /--.--\ `. . __
# ."" '< `.___\_<|>_/___.' >'"".
# | | : `- \`.;`\ _ /`;.`/ - ` : | |
# \ \ `-. \_ __\ /__ _/ .-` / /
# ==`-.____`-.___\_____/___.-`____.-'==
# `=---='
'''
@Project :pythonalgorithms
@File :Gradientsolution.py
@Author :不胜人生一场醉
@Date :2021/8/3 1:17
'''
import matplotlib.pyplot as plt
import numpy as np
import math
# 函数z=x^2+y^2,用梯度下降法求解,使函数取得最小值
# 首先求梯度 (∂f/∂x,∂f/∂y)=(2x,2y)
# 设定初始值位置 (x0,y0)=(3,2)
# 设定学习率η= 0.1
# 设定学习次数 t=20
# z为当前位置的求解值
# z=x0^2+y0^2
# (Δx,Δy)=-η(∂f/∂x,∂f/∂y)(x0,y0)
# (x1,y1)= (x0,y0)+(Δx,Δy)
# z=x1^2+y1^2
# (Δx,Δy)=-η(∂f/∂x,∂f/∂y)(x1,y1)
# (x2,y2)= (x1,y1)+(Δx,Δy)
# z=x2^2+y2^2
# Repeat the same update for every following iteration.


def solution1(u=0.1, x0=3, y0=2, steps=20):
    """Minimise z = x**2 + y**2 by plain gradient descent, logging each step.

    The gradient of the objective is (2x, 2y); every iteration moves the
    current position by ``-u * gradient``.

    Args:
        u: Learning rate, i.e. the factor applied to the negative gradient.
        x0: Starting x coordinate. Defaults to 3.
        y0: Starting y coordinate. Defaults to 2.
        steps: Number of iterations to run. Defaults to 20.

    Returns:
        A tuple ``(xdata, ydata, tdata)`` of equal-length lists: the x and y
        coordinates of the position held at the start of each iteration, and
        the matching iteration indices ``0 .. steps - 1``.
    """
    line_format = (
        'loop:{},current position=({:+.4f},{:+.4f}),Gradient({:+.4f},{:+.4f}),=step=({:+.4f},{:+.4f}),function value={:+.4f}'
    )
    xdata = []
    ydata = []
    tdata = []
    print('---------------当前学习率为{}----------------'.format(u))

    x, y = x0, y0
    for t in range(steps):
        z = x ** 2 + y ** 2  # objective value at the current position

        # Record the position *before* it is updated.
        xdata.append(x)
        ydata.append(y)
        tdata.append(t)

        grad_x, grad_y = x * 2, y * 2  # analytic gradient (2x, 2y)
        step_x, step_y = -u * grad_x, -u * grad_y  # move against the gradient

        print(line_format.format(t, x, y, grad_x, grad_y, step_x, step_y, z))

        x, y = x + step_x, y + step_y  # position for the next iteration
    return xdata, ydata, tdata
def drawtrack(xdata, ydata, tdata):
    """Plot the visited (x, y) positions and label each point.

    Args:
        xdata: x coordinates of the points to plot.
        ydata: y coordinates of the points to plot.
        tdata: Label drawn next to each point (the iteration index when fed
            from ``solution1``).

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    plt.figure(figsize=(10, 5))
    ax = plt.gca()  # gca = "get current axes"

    # Font settings so the Chinese title and minus signs render correctly.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False

    plt.plot(xdata, ydata, "ob")
    # Put each label slightly above its point so it does not cover the marker.
    for px, py, label in zip(xdata, ydata, tdata):
        ax.text(px, py + 0.1, label)

    # Hide the right and top borders of the plot.
    ax.spines['right'].set_color('none')
    ax.spines['top'].set_color('none')

    # Move the remaining borders so they act as coordinate axes through the
    # origin: ('data', 0) pins a spine at data value 0 on the other axis.
    # (An ('axes', fraction) position would pin it at a fraction of the axes
    # instead.)
    ax.spines['bottom'].set_position(('data', 0))
    ax.spines['left'].set_position(('data', 0))

    plt.title("求梯度")
    plt.show()
if __name__ == '__main__':
    # Run the same descent with three learning rates and plot each track:
    #   0.4  - descends very quickly,
    #   0.1  - steady descent,
    #   0.01 - converges poorly; needs far more iterations than are run here.
    for learning_rate in (0.4, 0.1, 0.01):
        xdata, ydata, tdata = solution1(learning_rate)
        drawtrack(xdata, ydata, tdata)
C:\Python\Python37\python.exe C:/Python/Pycharm/system_api_test/Gradientsolution.py
---------------当前学习率为0.4----------------
loop:0,current position=(+3.0000,+2.0000),Gradient(+6.0000,+4.0000),=step=(-2.4000,-1.6000),function value=+13.0000
loop:1,current position=(+0.6000,+0.4000),Gradient(+1.2000,+0.8000),=step=(-0.4800,-0.3200),function value=+0.5200
loop:2,current position=(+0.1200,+0.0800),Gradient(+0.2400,+0.1600),=step=(-0.0960,-0.0640),function value=+0.0208
loop:3,current position=(+0.0240,+0.0160),Gradient(+0.0480,+0.0320),=step=(-0.0192,-0.0128),function value=+0.0008
loop:4,current position=(+0.0048,+0.0032),Gradient(+0.0096,+0.0064),=step=(-0.0038,-0.0026),function value=+0.0000
loop:5,current position=(+0.0010,+0.0006),Gradient(+0.0019,+0.0013),=step=(-0.0008,-0.0005),function value=+0.0000
loop:6,current position=(+0.0002,+0.0001),Gradient(+0.0004,+0.0003),=step=(-0.0002,-0.0001),function value=+0.0000
loop:7,current position=(+0.0000,+0.0000),Gradient(+0.0001,+0.0001),=step=(-0.0000,-0.0000),function value=+0.0000
loop:8,current position=(+0.0000,+0.0000),Gradient(+0.0000,+0.0000),=step=(-0.0000,-0.0000),function value=+0.0000
loop:9,current position=(+0.0000,+0.0000),Gradient(+0.0000,+0.0000),=step=(-0.0000,-0.0000),function value=+0.0000
loop:10,current position=(+0.0000,+0.0000),Gradient(+0.0000,+0.0000),=step=(-0.0000,-0.0000),function value=+0.0000
loop:11,current position=(+0.0000,+0.0000),Gradient(+0.0000,+0.0000),=step=(-0.0000,-0.0000),function value=+0.0000
loop:12,current position=(+0.0000,+0.0000),Gradient(+0.0000,+0.0000),=step=(-0.0000,-0.0000),function value=+0.0000
loop:13,current position=(+0.0000,+0.0000),Gradient(+0.0000,+0.0000),=step=(-0.0000,-0.0000),function value=+0.0000
loop:14,current position=(+0.0000,+0.0000),Gradient(+0.0000,+0.0000),=step=(-0.0000,-0.0000),function value=+0.0000
loop:15,current position=(+0.0000,+0.0000),Gradient(+0.0000,+0.0000),=step=(-0.0000,-0.0000),function value=+0.0000
loop:16,current position=(+0.0000,+0.0000),Gradient(+0.0000,+0.0000),=step=(-0.0000,-0.0000),function value=+0.0000
loop:17,current position=(+0.0000,+0.0000),Gradient(+0.0000,+0.0000),=step=(-0.0000,-0.0000),function value=+0.0000
loop:18,current position=(+0.0000,+0.0000),Gradient(+0.0000,+0.0000),=step=(-0.0000,-0.0000),function value=+0.0000
loop:19,current position=(+0.0000,+0.0000),Gradient(+0.0000,+0.0000),=step=(-0.0000,-0.0000),function value=+0.0000
---------------当前学习率为0.1----------------
loop:0,current position=(+3.0000,+2.0000),Gradient(+6.0000,+4.0000),=step=(-0.6000,-0.4000),function value=+13.0000
loop:1,current position=(+2.4000,+1.6000),Gradient(+4.8000,+3.2000),=step=(-0.4800,-0.3200),function value=+8.3200
loop:2,current position=(+1.9200,+1.2800),Gradient(+3.8400,+2.5600),=step=(-0.3840,-0.2560),function value=+5.3248
loop:3,current position=(+1.5360,+1.0240),Gradient(+3.0720,+2.0480),=step=(-0.3072,-0.2048),function value=+3.4079
loop:4,current position=(+1.2288,+0.8192),Gradient(+2.4576,+1.6384),=step=(-0.2458,-0.1638),function value=+2.1810
loop:5,current position=(+0.9830,+0.6554),Gradient(+1.9661,+1.3107),=step=(-0.1966,-0.1311),function value=+1.3959
loop:6,current position=(+0.7864,+0.5243),Gradient(+1.5729,+1.0486),=step=(-0.1573,-0.1049),function value=+0.8934
loop:7,current position=(+0.6291,+0.4194),Gradient(+1.2583,+0.8389),=step=(-0.1258,-0.0839),function value=+0.5717
loop:8,current position=(+0.5033,+0.3355),Gradient(+1.0066,+0.6711),=step=(-0.1007,-0.0671),function value=+0.3659
loop:9,current position=(+0.4027,+0.2684),Gradient(+0.8053,+0.5369),=step=(-0.0805,-0.0537),function value=+0.2342
loop:10,current position=(+0.3221,+0.2147),Gradient(+0.6442,+0.4295),=step=(-0.0644,-0.0429),function value=+0.1499
loop:11,current position=(+0.2577,+0.1718),Gradient(+0.5154,+0.3436),=step=(-0.0515,-0.0344),function value=+0.0959
loop:12,current position=(+0.2062,+0.1374),Gradient(+0.4123,+0.2749),=step=(-0.0412,-0.0275),function value=+0.0614
loop:13,current position=(+0.1649,+0.1100),Gradient(+0.3299,+0.2199),=step=(-0.0330,-0.0220),function value=+0.0393
loop:14,current position=(+0.1319,+0.0880),Gradient(+0.2639,+0.1759),=step=(-0.0264,-0.0176),function value=+0.0251
loop:15,current position=(+0.1056,+0.0704),Gradient(+0.2111,+0.1407),=step=(-0.0211,-0.0141),function value=+0.0161
loop:16,current position=(+0.0844,+0.0563),Gradient(+0.1689,+0.1126),=step=(-0.0169,-0.0113),function value=+0.0103
loop:17,current position=(+0.0676,+0.0450),Gradient(+0.1351,+0.0901),=step=(-0.0135,-0.0090),function value=+0.0066
loop:18,current position=(+0.0540,+0.0360),Gradient(+0.1081,+0.0721),=step=(-0.0108,-0.0072),function value=+0.0042
loop:19,current position=(+0.0432,+0.0288),Gradient(+0.0865,+0.0576),=step=(-0.0086,-0.0058),function value=+0.0027
---------------当前学习率为0.01----------------
loop:0,current position=(+3.0000,+2.0000),Gradient(+6.0000,+4.0000),=step=(-0.0600,-0.0400),function value=+13.0000
loop:1,current position=(+2.9400,+1.9600),Gradient(+5.8800,+3.9200),=step=(-0.0588,-0.0392),function value=+12.4852
loop:2,current position=(+2.8812,+1.9208),Gradient(+5.7624,+3.8416),=step=(-0.0576,-0.0384),function value=+11.9908
loop:3,current position=(+2.8236,+1.8824),Gradient(+5.6472,+3.7648),=step=(-0.0565,-0.0376),function value=+11.5160
loop:4,current position=(+2.7671,+1.8447),Gradient(+5.5342,+3.6895),=step=(-0.0553,-0.0369),function value=+11.0599
loop:5,current position=(+2.7118,+1.8078),Gradient(+5.4235,+3.6157),=step=(-0.0542,-0.0362),function value=+10.6219
loop:6,current position=(+2.6575,+1.7717),Gradient(+5.3151,+3.5434),=step=(-0.0532,-0.0354),function value=+10.2013
loop:7,current position=(+2.6044,+1.7363),Gradient(+5.2088,+3.4725),=step=(-0.0521,-0.0347),function value=+9.7973
loop:8,current position=(+2.5523,+1.7015),Gradient(+5.1046,+3.4031),=step=(-0.0510,-0.0340),function value=+9.4094
loop:9,current position=(+2.5012,+1.6675),Gradient(+5.0025,+3.3350),=step=(-0.0500,-0.0333),function value=+9.0368
loop:10,current position=(+2.4512,+1.6341),Gradient(+4.9024,+3.2683),=step=(-0.0490,-0.0327),function value=+8.6789
loop:11,current position=(+2.4022,+1.6015),Gradient(+4.8044,+3.2029),=step=(-0.0480,-0.0320),function value=+8.3352
loop:12,current position=(+2.3542,+1.5694),Gradient(+4.7083,+3.1389),=step=(-0.0471,-0.0314),function value=+8.0051
loop:13,current position=(+2.3071,+1.5380),Gradient(+4.6141,+3.0761),=step=(-0.0461,-0.0308),function value=+7.6881
loop:14,current position=(+2.2609,+1.5073),Gradient(+4.5219,+3.0146),=step=(-0.0452,-0.0301),function value=+7.3837
loop:15,current position=(+2.2157,+1.4771),Gradient(+4.4314,+2.9543),=step=(-0.0443,-0.0295),function value=+7.0913
loop:16,current position=(+2.1714,+1.4476),Gradient(+4.3428,+2.8952),=step=(-0.0434,-0.0290),function value=+6.8105
loop:17,current position=(+2.1280,+1.4186),Gradient(+4.2559,+2.8373),=step=(-0.0426,-0.0284),function value=+6.5408
loop:18,current position=(+2.0854,+1.3903),Gradient(+4.1708,+2.7805),=step=(-0.0417,-0.0278),function value=+6.2818
loop:19,current position=(+2.0437,+1.3625),Gradient(+4.0874,+2.7249),=step=(-0.0409,-0.0272),function value=+6.0330
Process finished with exit code 0
原创不易,转载请注明!请多多关注,谢谢!