import numpy as np
import matplotlib.pyplot as plt
%matplotlib inlinef = lambda x :(x-3)**2+2.5*x-7.5
f2 = lambda x :-(x-3)**2+2.5*x-7.5
求解导数 导数为0 取最小值
x = np.linspace(-2,5,100)
y = f(x)
plt.plot(x,y)
梯度下降求最小值
#导数函数
d = lambda x:2*(x-3)*1+2.5#学习率 需调节 每次改变数值的时候,改变多少
learning_rate = 0.1min_value = np.random.randint(-3,5,size =1)[0]
print('-'*30,min_value)
#记录数据更新了,原来的值,上一步的值
min_value_last = min_value+0.1#tollerence容忍度,误差,在万分之一,任务结束
tol = 0.0001
count = 0
while True:if np.abs(min_value - min_value_last)<tol:break
#梯度下降min_value_last = min_value
#更新值min_value = min_value - learning_rate*d(min_value)print("++++++++++%d"%count,min_value)count+=1
print("*"*30,min_value)
------------------------------ -2
++++++++++0 -1.25
++++++++++1 -0.6499999999999999
++++++++++2 -0.16999999999999993
++++++++++3 0.21400000000000008
++++++++++4 0.5212000000000001
++++++++++5 0.7669600000000001
++++++++++6 0.9635680000000001
++++++++++7 1.1208544
++++++++++8 1.24668352
++++++++++9 1.347346816
++++++++++10 1.4278774528
++++++++++11 1.49230196224
++++++++++12 1.543841569792
++++++++++13 1.5850732558336
++++++++++14 1.6180586046668801
++++++++++15 1.644446883733504
++++++++++16 1.6655575069868032
++++++++++17 1.6824460055894426
++++++++++18 1.695956804471554
++++++++++19 1.7067654435772432
++++++++++20 1.7154123548617946
++++++++++21 1.7223298838894356
++++++++++22 1.7278639071115485
++++++++++23 1.7322911256892388
++++++++++24 1.735832900551391
++++++++++25 1.7386663204411128
++++++++++26 1.7409330563528902
++++++++++27 1.7427464450823122
++++++++++28 1.7441971560658498
++++++++++29 1.74535772485268
++++++++++30 1.7462861798821439
++++++++++31 1.7470289439057152
++++++++++32 1.7476231551245722
++++++++++33 1.7480985240996578
++++++++++34 1.7484788192797263
++++++++++35 1.748783055423781
++++++++++36 1.7490264443390249
++++++++++37 1.7492211554712198
++++++++++38 1.749376924376976
++++++++++39 1.7495015395015807
++++++++++40 1.7496012316012646
****************************** 1.7496012316012646
更新值learning_rate*d(max_value) 最大/最小值导数为0
就可能满足np.abs(max_value - max_value_last)<precision:
d2 = lambda x:-2*(x-3)*1+2.5
#学习率 需调节 每次改变数值的时候,改变多少
learning_rate = 0.1
max_value = np.random.randint(-3,5,size =1)[0]
print('-'*30,min_value)
#记录数据更新了,原来的值,上一步的值
max_value_last = max_value+0.1
result =[]
#tollerence容忍度,误差,在万分之一,任务结束
#precision精确度, 误差,在万分之一,任务结束
precision = 0.0001
count = 0
while True:if count>3000:
# 避免梯度消失 rate =1
# 避免梯度爆炸 导数更新值有问题时 或 rate =10breakif np.abs(max_value - max_value_last)<precision:break
#梯度下降max_value_last = max_value#更新值learning_rate*d(max_value) 最大/最小值导数为0
# 就可能满足np.abs(max_value - max_value_last)<precision:max_value = max_value + learning_rate*d2(max_value)result.append(max_value)print("++++++++++%d"%count,max_value)count+=1
print("*"*30,max_value)
------------------------------ 1.7496012316012646
++++++++++0 0.050000000000000044
++++++++++1 0.8900000000000001
++++++++++2 1.5620000000000003
++++++++++3 2.0996
++++++++++4 2.52968
++++++++++5 2.873744
++++++++++6 3.1489952
++++++++++7 3.36919616
++++++++++8 3.545356928
++++++++++9 3.6862855424
++++++++++10 3.79902843392
++++++++++11 3.889222747136
++++++++++12 3.9613781977088
++++++++++13 4.01910255816704
++++++++++14 4.065282046533632
++++++++++15 4.102225637226906
++++++++++16 4.131780509781525
++++++++++17 4.15542440782522
++++++++++18 4.174339526260176
++++++++++19 4.18947162100814
++++++++++20 4.201577296806512
++++++++++21 4.2112618374452095
++++++++++22 4.219009469956168
++++++++++23 4.225207575964935
++++++++++24 4.230166060771948
++++++++++25 4.234132848617558
++++++++++26 4.237306278894047
++++++++++27 4.239845023115238
++++++++++28 4.24187601849219
++++++++++29 4.2435008147937525
++++++++++30 4.244800651835002
++++++++++31 4.2458405214680015
++++++++++32 4.246672417174401
++++++++++33 4.247337933739521
++++++++++34 4.247870346991617
++++++++++35 4.248296277593293
++++++++++36 4.248637022074634
++++++++++37 4.248909617659708
++++++++++38 4.2491276941277665
++++++++++39 4.249302155302213
++++++++++40 4.249441724241771
++++++++++41 4.249553379393417
++++++++++42 4.249642703514733
****************************** 4.249642703514733
ret = ret= ret*stepx = np.linspace(0,6,100)
y = f2(x)
plt.plot(x,y)result = np.asanyarray(result)
plt.plot(result,f2(result),'*')