目录
k-近邻算法概述
k-近邻算法的一般流程
kNN算法伪代码
k-近邻算法概述
优点:精度高、对异常值不敏感、无数据输入假定
缺点:计算复杂度高、空间复杂度高
适用数据范围:数值型和标称型
k-近邻算法的一般流程
(1)收集数据
(2)准备数据
(3)分析数据
(4)训练算法(不需要)
(5)测试算法
(6)使用算法
from numpy import *
import operator
def createDataSet():group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])labels = ['A', 'A', 'B', 'B']return group, labels
group, labels = createDataSet()
group
array([[1. , 1.1],[1. , 1. ],[0. , 0. ],[0. , 0.1]])
labels
['A', 'A', 'B', 'B']
import matplotlib.pyplot as plt
x = group[:, 0]
y = group[:, 1]
plt.scatter(x, y)
plt.xlim(-0.2, 1.2)
plt.ylim(-0.2, 1.2)
for i, pos in enumerate(zip(x, y)):plt.text(pos[0]-0.01, pos[1], f'{labels[i]}', ha='right')
plt.show()
kNN算法伪代码
对未知类别属性的数据集中的每个点依次执行以下操作:
(1)计算已知类别数据集中的点与当前点之间的距离
(2)按照距离递增的次序排列
(3)选取与当前点距离最小的k个点
(4)确定前k个点所在类别的出现频率
(5)返回前k个点出现频率最高的类别作为当前点的预测分类
def classify0(inX, dataSet, labels, k):dataSetSize = dataSet.shape[0]diffMat = tile(inX, (dataSetSize, 1)) - dataSetsqDiffMat = diffMat ** 2sqDistances = sqDiffMat.sum(axis=1)distances = sqDistances**0.5sortedDistIndicies = distances.argsort()classCount = {}for i in range(k):voteIlabel = labels[sortedDistIndicies[i]]classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)return sortedClassCount[0][0]
classify0([0, 0], group, labels, 3)
'B'
这段代码实现了k近邻算法中的分类函数,用于根据输入的数据点inX
,在数据集dataSet
中找到距离最近的k个邻居,并统计它们的类别标签,最终返回频率最高的类别。
现在让我们逐步分析这段代码:
-
dataSetSize = dataSet.shape[0]
: 获取数据集的行数,即数据点的数量。 -
diffMat = tile(inX, (dataSetSize, 1)) - dataSet
: 将输入数据点inX
复制成与数据集相同大小的矩阵,然后计算与数据集中每个点的差值。 -
sqDiffMat = diffMat ** 2
: 对差值矩阵的每个元素进行平方操作。 -
sqDistances = sqDiffMat.sum(axis=1)
: 沿着列的方向对平方差值矩阵进行求和,得到每个数据点与输入点的平方距离。 -
distances = sqDistances**0.5
: 对平方距离进行开方,得到真实距离。 -
sortedDistIndicies = distances.argsort()
: 对距离进行排序,返回排序后的索引值。 -
classCount = {}
: 初始化一个空字典,用于存储每个类别的投票数。 -
for i in range(k):
: 遍历前k个最小距离的索引。 -
voteIlabel = labels[sortedDistIndicies[i]]
: 获取对应索引的类别标签。 -
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
: 统计每个类别的投票数,使用get
方法获取字典中的值,如果键不存在则返回默认值0。 -
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
: 对字典按照值进行排序,items()
方法返回字典的键值对,key=operator.itemgetter(1)
表示按照值排序,reverse=True
表示降序排列。 -
return sortedClassCount[0][0]
: 返回排序后的字典中频率最高的类别标签,即k个邻居中出现最多的类别。
这个函数的核心思想是通过计算输入点与数据集中每个点的距离,找到距离最近的k个邻居,然后通过投票机制确定输入点的类别。