假设data是m行两列的训练样本,labels是m行一列的类标签,类标签一共有3类,分别用1、2、3表示,现将data用散点图表示出来,且不同类的样本有不同的颜色:
import matplotlib.pyplot as pltfig = plt.figure()
ax = fig.add_subplot(111) # 创建一个一行一列的图
ax.scatter(data[:, 0], data[:, 1], 15.0*np.array(labels), 15.0*np.array(labels)) # 15.0是散点的大小
plt.show()
为了得到更好的效果,并以红色的'*'表示类标签1、蓝色的'o'表示表示类标签2、绿色的'+'表示类标签3,修改参数如下:
import numpy as np
import matplotlib.pyplot as pltfig = plt.figure()
ax = fig.add_subplot(111)
labels = np.array(labels)
idx_1 = np.where(labels == 1) # 找出第一类
p1 = ax.scatter(data[idx_1, 0], data[idx_1, 1], marker='*', color='r', label='1',s=20)
idx_2 = np.where(labels == 2) # 找出第二类
p2 = ax.scatter(data[idx_2, 0], data[idx_2, 1], marker = 'o',color ='b',label='2',s=10)
idx_3 = np.where(labels == 3) # 找出第三类
p3 = ax.scatter(data[idx_3, 0], data[idx_3, 1], marker = '+',color ='g',label='3',s=30)
plt.legend(loc='upper right')
plt.show()