scipy.optimize作为优化模块可以实现任意曲线拟合,方程求根、非线性方程组求解、自定义代价函数求解等功能,下面给出了optimize中常用的几个子模块:
minimize:需要自己构建代价函数(有时也称损失函数,目标函数等),理论上可以求解任意最优化问题
curve_fit:可以拟合任意的显式函数曲线,对于隐式函数曲线不能拟合
fsolve:方程根、求解适定方程组,需要满足未知数数量等于方程数量
least_squares、leastsq:这两个可以用于求解最小二乘问题
root:求解方程根
minimize模块灵活度最多,功能最为强大,理论上上述子模块都可以看作minimize的特例:下面分别给出使用案例:
使用minimize拟合直线、圆、椭圆,使用fsolve、least_squares求解方程组,使用curve_fit拟合抛物线
from scipy.optimize import minimize
from scipy.optimize import curve_fit
from scipy.optimize import fsolve
from scipy.optimize import least_squares
from scipy.optimize import leastsq
from matplotlib import pyplot as plt
import numpy as np
import math
def line_fun(params,x,y):a,b=paramsreturn np.sum(np.square(x*a+b-y))def circle_fun(params,x,y):#圆的一般方程:x^2+y^2+ax+by+c=0a,b,c=paramsreturn np.sum(np.square(x*x+y*y+a*x+b*y+c))def ellipse_model2d(params, x, y):#椭圆一般方程:A * x ** 2 + B * y ** 2 + C * x * y + D * x + E * y + F=0A, B, C, D, E, F= paramserror=A * x ** 2 + B * y ** 2 + C * x * y + D * x + E * y + Fsquares = np.square(error) # 将数组a每个元素平方sum_of_squares = np.sum(squares) # 对平方后的数组求和return sum_of_squares#超定方程组
def non_linear_equations(var):x = var[0]y = var[1]Func= np.empty((3))Func[0] = x**2+y-5Func[1] = x+y-3Func[2] = 4*x+y**2-9return Func#适定方程组
def non_linear_equations2(var):x = var[0]y = var[1]Func= np.empty((2))Func[0] = x**2+y-5Func[1] = x+y-3return Funcdef curve_line(x, a, b):return a*x**2+b*x+10# 按间距中的绿色按钮以运行脚本。
if __name__ == '__main__':#生成直线数据x = np.linspace(-20, 20, 100)y = x * 2 + 1# plt.scatter(x,y)# plt.show()x0 = np.array([1.0, 0.8])#求解直线方程,当然这里也可以用curve_fit模块result = minimize(line_fun, x0, args=(x, y))print(r"line ab:", result.x)# 生成圆数据m = np.linspace(0, 2 * math.pi, 100)a = 3b = -6r = 5x = a + r * np.cos(m)y = b + r * np.sin(m)figure, axes = plt.subplots(1)axes.plot(x, y)axes.set_aspect(1)plt.show()#拟合圆result = minimize(circle_fun, np.ones(3), args=(x, y))print(r"circle abc:", result.x)#ellipse#datatheta_samples = np.linspace(0, 20, 100)# 椭圆方位角alpha_samples = -45.0 / 180.0 * np.pi# 长轴长度a_samples = 1.0# 短轴长度b_samples = 2.0# 样本x 序列,并叠加正态分布的随机值x_samples = a_samples * np.cos(theta_samples) * np.cos(alpha_samples) \- b_samples * np.sin(theta_samples) * np.sin(alpha_samples) +1\# + np.random.randn(100) * 0.05 * a_samples# 样本y 序列 ,并叠加正态分布的随机值y_samples = b_samples * np.sin(theta_samples) * np.cos(alpha_samples) \+ a_samples * np.cos(theta_samples) * np.sin(alpha_samples) +2\# + np.random.randn(100) * 0.05 * b_samplesz_samples=np.zeros(100)plt.axes([0.16, 0.15, 0.75, 0.75])plt.scatter(x_samples, y_samples, color="magenta", marker="+",zorder=1, s=80, label="samples")plt.show()#fit ellipseresult=minimize(ellipse_model2d,np.ones(6),args=(x_samples,y_samples))print("ellipse abcdef:",result.x)#solve non_linear_equationsa = np.array([0.0,0])b = least_squares(non_linear_equations, a)print("non_linear_equations x0,x1:", b.x)a = np.array([0.0,0])b = fsolve(non_linear_equations2, a)#方程数量和未知数数量要保持一致print("non_linear_equations x0,x1:",b)#fit curve_line# 这部分生成样本点,对函数值加上高斯噪声作为样本点xdata = np.linspace(-5, 5, 50)## a=2.5, b=1.3y = curve_line(xdata, 2.5, 1.3)np.random.seed(1)err_stdev = 0.2# 生成均值为0,标准差为err_stdev为0.2的高斯噪声y_noise = err_stdev * np.random.normal(size=len(xdata))ydata = y + y_noiseplt.scatter(xdata, ydata, label='data')# 利用curve_fit作简单的拟合,popt为拟合得到的参数,pcov是参数的协方差矩阵popt_1, pcov = curve_fit(curve_line, xdata, ydata)print("curve_line_result:",popt_1)plt.plot(xdata, curve_line(xdata, *popt_1), 'r-', label='fit_1')plt.show()
结果:
line ab: [1.99999999 0.99999999]
circle abc: [-6.00000001 11.99999997 19.99999983]
ellipse abcdef: [-0.04833563 -0.04833556 0.05800266 -0.01933401 0.13533947 -0.04833538]
non_linear_equations x0,x1: [2. 1.]
non_linear_equations x0,x1: [-1. 4.]
curve_line_result: [2.50137612 1.30813455]
椭圆拟合结果:
抛物线拟合结果:
参考:
1
2
3
4
5
6