如果输入数据长度为2,上一章的方程就无法满足需求了,需要修改方程: z = w 1 x + w 2 y + b z=w_1x+w_2y+b z=w1x+w2y+b
数据产生器:
import matplotlib.pyplot as plt
import numpy as npclass DataGenerator2Input:"""线性回归数据产生器, 方程:z = w1 * x + w2 * y + b"""def __init__(self, w1, w2, b):self.w1 = w1self.w2 = w2self.b = bdef __call__(self, data_len):input_data = np.random.uniform(-50, 50, [data_len, 2]) # 生成 x, ylabels = self.w1 * input_data[:, 0] + self.w2 * input_data[:, 1] + self.b # 生成 z# 加随机误差noise = np.random.uniform(-20, 20, data_len)labels += noisereturn input_data, labelsw1, w2, b = 3.5, 7.1, 17
input_datas, labels = DataGenerator2Input(w1, w2, b)(5000)# 可视化
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.scatter(labels, input_datas[:, 0], input_datas[:, 0])
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')
plt.show()
分段函数问题